This documentation is for scikit-learn version 0.11-git.

8.21.6. sklearn.neighbors.BallTree

class sklearn.neighbors.BallTree

Ball Tree for fast nearest-neighbor searches:

BallTree(X, leaf_size=20, p=2.0)

Parameters :

X : array-like, shape = [n_samples, n_features]

n_samples is the number of points in the data set, and n_features is the dimension of the parameter space. Note: if X is a C-contiguous array of doubles then data will not be copied. Otherwise, an internal copy will be made.

leaf_size : positive integer (default = 20)

Number of points at which to switch to brute-force. Changing leaf_size will not affect the results of a query, but can significantly impact the speed of a query and the memory required to store the built ball tree. The amount of memory needed to store the tree scales as:

2 ** (1 + floor(log2((n_samples - 1) / leaf_size))) - 1

For a specified leaf_size, a leaf node is guaranteed to satisfy leaf_size <= n_points <= 2 * leaf_size, except in the case that n_samples < leaf_size.

p : distance metric for the BallTree (default = 2.0)

p encodes the Minkowski p-distance:

D = sum(abs(X[i] - X[j]) ** p) ** (1. / p)

p must be greater than or equal to 1, so that the triangle inequality will hold. If p == np.inf, then the distance is equivalent to:

D = max(abs(X[i] - X[j]))
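The two formulas in the parameter descriptions above can be evaluated directly. The following is a minimal sketch in plain Python, not part of scikit-learn. The helper names tree_size_estimate and minkowski_distance are hypothetical. The sketch assumes true division in the size expression, clamps the ratio to at least 1 to cover the n_samples < leaf_size case, and takes the absolute value of each coordinate difference before raising it to the power p.

```python
import math


def tree_size_estimate(n_samples, leaf_size=20):
    """Evaluate 2 ** (1 + floor(log2((n_samples - 1) / leaf_size))) - 1.

    Hypothetical helper. Assumes true division, and clamps the ratio to 1
    so that n_samples < leaf_size gives a single node (an assumption).
    """
    ratio = max(1.0, (n_samples - 1) / float(leaf_size))
    # For ratio >= 1, frexp returns (m, e) with ratio == m * 2 ** e and
    # 0.5 <= m < 1, so floor(log2(ratio)) is exactly e - 1.
    floor_log2 = math.frexp(ratio)[1] - 1
    return 2 ** (1 + floor_log2) - 1


def minkowski_distance(a, b, p=2.0):
    """Minkowski p-distance between two equal-length sequences (hypothetical helper)."""
    if p < 1:
        raise ValueError("p must be >= 1 for the triangle inequality to hold")
    diffs = [abs(x - y) for x, y in zip(a, b)]
    if math.isinf(p):
        # Limit p -> infinity: the largest absolute coordinate difference
        return max(diffs)
    return sum(d ** p for d in diffs) ** (1.0 / p)
```

For instance, tree_size_estimate(10, 2) evaluates to 7, and minkowski_distance([0, 0], [3, 4]) evaluates to 5.0 with the default p=2.0.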

Examples

Query for k-nearest neighbors

>>> import numpy as np
>>> from sklearn.neighbors import BallTree
>>> np.random.seed(0)
>>> X = np.random.random((10,3))  # 10 points in 3 dimensions
>>> ball_tree = BallTree(X, leaf_size=2)
>>> dist, ind = ball_tree.query(X[0], k=3)
>>> print ind  # indices of 3 closest neighbors
[0 3 1]
>>> print dist  # distances to 3 closest neighbors
[ 0.          0.19662693  0.29473397]

Pickle and unpickle a ball tree (using protocol=2). Note that the state of the tree is saved in the pickle operation: the tree is not rebuilt on unpickling.

>>> import numpy as np
>>> import pickle
>>> from sklearn.neighbors import BallTree
>>> np.random.seed(0)
>>> X = np.random.random((10,3))  # 10 points in 3 dimensions
>>> ball_tree = BallTree(X, leaf_size=2)
>>> s = pickle.dumps(ball_tree, protocol=2)  # serialize the built tree
>>> ball_tree_copy = pickle.loads(s)  # restore it without rebuilding
>>> dist, ind = ball_tree_copy.query(X[0], k=3)
>>> print ind  # indices of 3 closest neighbors
[0 3 1]
>>> print dist  # distances to 3 closest neighbors
[ 0.          0.19662693  0.29473397]

Attributes

data
warning_flag

Methods

query(X[, k, return_distance])    query the Ball Tree for the k nearest neighbors
query_radius(X, r[, return_distance, count_only, sort_results])    query the Ball Tree for neighbors within a ball of size r

__init__()

x.__init__(...) initializes x; see help(type(x)) for signature

query(X, k=1, return_distance=True)

query the Ball Tree for the k nearest neighbors

Parameters :

X : array-like, last dimension self.dim

An array of points to query

k : integer (default = 1)

The number of nearest neighbors to return

return_distance : boolean (default = True)

If True, return a tuple (d, i). If False, return array i.

Returns :

i : if return_distance == False

(d,i) : if return_distance == True

d : array of doubles - shape: X.shape[:-1] + (k,)

each entry gives the list of distances to the neighbors of the corresponding point (the examples on this page return them in order of increasing distance)

i : array of integers - shape: X.shape[:-1] + (k,)

each entry gives the list of indices of neighbors of the corresponding point (the examples on this page return them ordered from nearest to farthest)
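The shape rule X.shape[:-1] + (k,) means that the leading dimensions of the query array are preserved and the last one (the feature dimension) is replaced by k. A small sketch of that arithmetic follows. The helper name query_result_shape is hypothetical and not part of scikit-learn.

```python
def query_result_shape(x_shape, k=1):
    """Shape of d and i for a query array of shape x_shape (hypothetical helper)."""
    return tuple(x_shape[:-1]) + (k,)


# A single 3-dimensional point, as in ball_tree.query(X[0], k=3)
single = query_result_shape((3,), k=3)     # (3,)

# A batch of 10 points in 3 dimensions
batch = query_result_shape((10, 3), k=3)   # (10, 3)

# A 4 x 5 grid of 3-dimensional points
grid = query_result_shape((4, 5, 3), k=2)  # (4, 5, 2)
```

The single-point case matches the doctest output on this page, where ind and dist are flat arrays of length 3.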

Examples

Query for k-nearest neighbors

>>> import numpy as np
>>> from sklearn.neighbors import BallTree
>>> np.random.seed(0)
>>> X = np.random.random((10,3))  # 10 points in 3 dimensions
>>> ball_tree = BallTree(X, leaf_size=2)
>>> dist, ind = ball_tree.query(X[0], k=3)
>>> print ind  # indices of 3 closest neighbors
[0 3 1]
>>> print dist  # distances to 3 closest neighbors
[ 0.          0.19662693  0.29473397]

query_radius()

query_radius(X, r, return_distance=False, count_only=False, sort_results=False)

query the Ball Tree for neighbors within a ball of size r

Parameters :

X : array-like, last dimension self.dim

An array of points to query

r : distance within which neighbors are returned

r can be a single value, or an array of values of shape X.shape[:-1] if different radii are desired for each point.

return_distance : boolean (default = False)

If True, return distances to neighbors of each point. If False, return only neighbors. Note that unlike BallTree.query(), setting return_distance=True adds to the computation time: not all distances need to be calculated explicitly for return_distance=False. Results are not sorted by default: see the sort_results keyword.

count_only : boolean (default = False)

If True, return only the count of points within distance r. If False, return the indices of all points within distance r. If return_distance == True, setting count_only = True will result in an error.

sort_results : boolean (default = False)

If True, the distances and indices will be sorted before being returned. If False, the results will not be sorted. If return_distance == False, setting sort_results = True will result in an error.

Returns :

count : if count_only == True

ind : if count_only == False and return_distance == False

(ind, dist) : if count_only == False and return_distance == True

count : array of integers, shape = X.shape[:-1]

each entry gives the number of neighbors within a distance r of the corresponding point.

ind : array of objects, shape = X.shape[:-1]

each element is a numpy integer array listing the indices of neighbors of the corresponding point. Note that unlike the results of BallTree.query(), the returned neighbors are not sorted by distance.

dist : array of objects, shape = X.shape[:-1]

each element is a numpy double array listing the distances corresponding to indices in ind.
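How the three keywords interact can be illustrated with a brute-force reference for a single query point. This is only a sketch under stated assumptions, not the tree-based algorithm and not scikit-learn code. The function name brute_force_query_radius is hypothetical, the Euclidean distance (p = 2) is assumed, "within distance r" is taken to mean d <= r, and ValueError is an assumed choice for the error cases described above.

```python
import math


def brute_force_query_radius(data, point, r, return_distance=False,
                             count_only=False, sort_results=False):
    """Linear-scan radius query for one point (hypothetical reference helper)."""
    # The two invalid keyword combinations named in the parameter descriptions
    if count_only and return_distance:
        raise ValueError("count_only and return_distance cannot both be True")
    if sort_results and not return_distance:
        raise ValueError("sort_results requires return_distance=True")

    hits = []  # (distance, index) pairs for rows with distance <= r
    for i, row in enumerate(data):
        d = math.sqrt(sum((a - b) ** 2 for a, b in zip(row, point)))
        if d <= r:
            hits.append((d, i))

    if count_only:
        return len(hits)
    if sort_results:
        hits.sort()  # nearest first
    ind = [i for _, i in hits]
    if return_distance:
        return ind, [d for d, _ in hits]  # (ind, dist) order, as documented
    return ind
```

A linear scan like this costs one distance evaluation per stored point for every query, which is the work the ball tree is designed to avoid. It can still serve as a cross-check on small data sets.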

Examples

Query for neighbors in a given radius

>>> import numpy as np
>>> from sklearn.neighbors import BallTree
>>> np.random.seed(0)
>>> X = np.random.random((10,3))  # 10 points in 3 dimensions
>>> ball_tree = BallTree(X, leaf_size=2)
>>> print ball_tree.query_radius(X[0], r=0.3, count_only=True)
3
>>> ind = ball_tree.query_radius(X[0], r=0.3)  
>>> print ind  # indices of neighbors within distance 0.3
[3 0 1]