9.4.8. sklearn.neighbors.BallTree

class sklearn.neighbors.BallTree

Ball Tree for fast nearest-neighbor searches:

BallTree(X, leaf_size=20, p=2.0)

Parameters :

X : array-like, shape = [n_samples, n_features]

n_samples is the number of points in the data set, and n_features is the dimension of the parameter space. Note: if X is a C-contiguous array of doubles then data will not be copied. Otherwise, an internal copy will be made.

leaf_size : positive integer (default = 20)

Number of points at which to switch to brute-force. Changing leaf_size will not affect the results of a query, but can significantly impact the speed of a query and the memory required to store the built ball tree. The amount of memory needed to store the tree scales as

2 ** (1 + floor(log2((n_samples - 1) / leaf_size))) - 1

For a specified leaf_size, a leaf node is guaranteed to satisfy leaf_size <= n_points <= 2 * leaf_size, except in the case that n_samples < leaf_size.
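To get a feel for how the scaling expression above grows as leaf_size shrinks, it can be evaluated directly. The sketch below is only an illustration: the function name `tree_size_scaling` is hypothetical and not part of scikit-learn, and it assumes n_samples > leaf_size so that the logarithm is non-negative.

```python
import math


def tree_size_scaling(n_samples, leaf_size):
    """Evaluate 2 ** (1 + floor(log2((n_samples - 1) / leaf_size))) - 1.

    Hypothetical helper for illustration only; assumes n_samples > leaf_size.
    """
    if n_samples <= leaf_size:
        raise ValueError("this sketch assumes n_samples > leaf_size")
    levels = 1 + math.floor(math.log2((n_samples - 1) / leaf_size))
    return 2 ** levels - 1


# Smaller leaves mean a deeper tree and a larger value of the scaling term.
for leaf_size in (40, 20, 2):
    print(leaf_size, tree_size_scaling(1000, leaf_size))
```

For 1000 samples, moving from leaf_size=20 to leaf_size=2 raises the term from 63 to 511, which is the memory side of the trade-off described above.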

p : float (default = 2.0)

Distance metric for the BallTree. p encodes the Minkowski p-distance:

D = sum(abs(X[i] - X[j]) ** p) ** (1. / p)

p must be greater than or equal to 1, so that the triangle inequality will hold. If p == np.inf, then the distance is equivalent to

D = max(abs(X[i] - X[j]))
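The two distance formulas above can be checked by hand with a small standalone function. This is a minimal sketch in plain Python, not the routine BallTree uses internally. The name `minkowski_distance` is chosen here for illustration, and the absolute value of each coordinate difference is taken before raising it to the power p.

```python
import math


def minkowski_distance(a, b, p=2.0):
    """Minkowski p-distance between two equal-length sequences (illustrative)."""
    if p < 1:
        raise ValueError("p must be >= 1 for the triangle inequality to hold")
    diffs = [abs(x - y) for x, y in zip(a, b)]
    if math.isinf(p):
        # Limiting case p == inf: the largest coordinate difference.
        return max(diffs)
    return sum(d ** p for d in diffs) ** (1.0 / p)


print(minkowski_distance([0, 0], [3, 4], p=1))         # Manhattan distance
print(minkowski_distance([0, 0], [3, 4], p=2))         # Euclidean distance
print(minkowski_distance([0, 0], [3, 4], p=math.inf))  # Chebyshev distance
```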

Examples

Query for k-nearest neighbors

    >>> import numpy as np
    >>> from sklearn.neighbors import BallTree
    >>> np.random.seed(0)
    >>> X = np.random.random((10, 3))  # 10 points in 3 dimensions
    >>> ball_tree = BallTree(X, leaf_size=2)
    >>> dist, ind = ball_tree.query(X[0], k=3)
    >>> print(ind)  # indices of 3 closest neighbors
    [0 3 1]
    >>> print(dist)  # distances to 3 closest neighbors
    [ 0. 0.19662693 0.29473397]

Pickle and unpickle a ball tree (using protocol = 2). Note that the state of the tree is saved in the pickle operation: the tree is not rebuilt on unpickling.

    >>> import numpy as np
    >>> import pickle
    >>> from sklearn.neighbors import BallTree
    >>> np.random.seed(0)
    >>> X = np.random.random((10, 3))  # 10 points in 3 dimensions
    >>> ball_tree = BallTree(X, leaf_size=2)
    >>> s = pickle.dumps(ball_tree, protocol=2)
    >>> ball_tree_copy = pickle.loads(s)
    >>> dist, ind = ball_tree_copy.query(X[0], k=3)
    >>> print(ind)  # indices of 3 closest neighbors
    [0 3 1]
    >>> print(dist)  # distances to 3 closest neighbors
    [ 0. 0.19662693 0.29473397]

Methods

query(X[, k, return_distance])    query the Ball Tree for the k nearest neighbors
query_radius(X, r[, return_distance, count_only, sort_results])    query the Ball Tree for neighbors within a ball of size r
__init__()

x.__init__(...) initializes x; see x.__class__.__doc__ for signature

query(X, k=1, return_distance=True)

query the Ball Tree for the k nearest neighbors

Parameters :

X : array-like, last dimension self.dim

An array of points to query

k : integer (default = 1)

The number of nearest neighbors to return

return_distance : boolean (default = True)

If True, return a tuple (d, i). If False, return array i.

Returns :

i : if return_distance == False

(d,i) : if return_distance == True

d : array of doubles - shape: X.shape[:-1] + (k,)

each entry gives the list of distances to the neighbors of the corresponding point (note that distances are not sorted)

i : array of integers - shape: X.shape[:-1] + (k,)

each entry gives the list of indices of neighbors of the corresponding point (note that neighbors are not sorted)

Examples

Query for k-nearest neighbors

    >>> import numpy as np
    >>> from sklearn.neighbors import BallTree
    >>> np.random.seed(0)
    >>> X = np.random.random((10, 3))  # 10 points in 3 dimensions
    >>> ball_tree = BallTree(X, leaf_size=2)
    >>> dist, ind = ball_tree.query(X[0], k=3)
    >>> print(ind)  # indices of 3 closest neighbors
    [0 3 1]
    >>> print(dist)  # distances to 3 closest neighbors
    [ 0. 0.19662693 0.29473397]
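As a reference point for what query() returns, the same kind of answer can be computed by exhaustive search. The sketch below is not the ball tree algorithm. It is a plain-Python brute-force search for a single query point under the Euclidean (p = 2) metric. The function name `brute_force_query` is hypothetical, and the neighbors are returned in ascending order of distance purely for readability.

```python
import math


def brute_force_query(data, point, k=1, return_distance=True):
    """Exhaustive k-nearest-neighbor search for one query point (illustrative)."""
    # Euclidean distance from the query point to every stored point.
    dists = [
        math.sqrt(sum((a - b) ** 2 for a, b in zip(row, point)))
        for row in data
    ]
    # Indices of the k smallest distances, closest first.
    order = sorted(range(len(data)), key=lambda idx: dists[idx])[:k]
    if return_distance:
        return [dists[idx] for idx in order], order
    return order


data = [[0.0, 0.0], [1.0, 0.0], [0.0, 3.0], [5.0, 5.0]]
d, i = brute_force_query(data, [0.0, 0.0], k=2)
print(i)  # indices of the 2 closest points
print(d)  # matching distances
```

This approach costs one distance computation per stored point for every query, which is the cost a ball tree is designed to avoid on larger data sets.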
query_radius()

query_radius(X, r, return_distance=False, count_only=False, sort_results=False)

query the Ball Tree for neighbors within a ball of size r

Parameters :

X : array-like, last dimension self.dim

An array of points to query

r : distance within which neighbors are returned

r can be a single value, or an array of values of shape X.shape[:-1] if different radii are desired for each point.

return_distance : boolean (default = False)

If True, return distances to neighbors of each point. If False, return only neighbors. Note that unlike BallTree.query(), setting return_distance=True adds to the computation time: not all distances need to be calculated explicitly for return_distance=False. Results are not sorted by default: see the sort_results keyword.

count_only : boolean (default = False)

If True, return only the count of points within distance r. If False, return the indices of all points within distance r. If return_distance == True, setting count_only = True will result in an error.

sort_results : boolean (default = False)

If True, the distances and indices will be sorted before being returned. If False, the results will not be sorted. If return_distance == False, setting sort_results = True will result in an error.

Returns :

count : if count_only == True

ind : if count_only == False and return_distance == False

(ind, dist) : if count_only == False and return_distance == True

count : array of integers, shape = X.shape[:-1]

each entry gives the number of neighbors within a distance r of the corresponding point.

ind : array of objects, shape = X.shape[:-1]

each element is a numpy integer array listing the indices of neighbors of the corresponding point. Note that unlike the results of BallTree.query(), the returned neighbors are not sorted by distance.

dist : array of objects, shape = X.shape[:-1]

each element is a numpy double array listing the distances corresponding to indices in ind.

Examples

Query for neighbors in a given radius

    >>> import numpy as np
    >>> from sklearn.neighbors import BallTree
    >>> np.random.seed(0)
    >>> X = np.random.random((10, 3))  # 10 points in 3 dimensions
    >>> ball_tree = BallTree(X, leaf_size=2)
    >>> print(ball_tree.query_radius(X[0], r=0.3, count_only=True))
    3
    >>> ind = ball_tree.query_radius(X[0], r=0.3)
    >>> print(ind)  # indices of neighbors within distance 0.3
    [3 0 1]
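The interaction between return_distance, count_only and sort_results can be easier to follow in a small brute-force model. The sketch below is plain Python and not the BallTree implementation. The function name `brute_force_query_radius` is hypothetical. Raising ValueError for the two invalid flag combinations and treating "within distance r" as d <= r are assumptions made for this illustration.

```python
import math


def brute_force_query_radius(data, point, r, return_distance=False,
                             count_only=False, sort_results=False):
    """Exhaustive radius search for one query point (illustrative only)."""
    # The two flag combinations documented as errors (exception type assumed).
    if count_only and return_distance:
        raise ValueError("count_only=True cannot be combined with return_distance=True")
    if sort_results and not return_distance:
        raise ValueError("sort_results=True requires return_distance=True")

    dists = [
        math.sqrt(sum((a - b) ** 2 for a, b in zip(row, point)))
        for row in data
    ]
    # Assumption: a point exactly at distance r counts as a neighbor.
    ind = [idx for idx, d in enumerate(dists) if d <= r]

    if count_only:
        return len(ind)
    if sort_results:
        ind.sort(key=lambda idx: dists[idx])
    if return_distance:
        return ind, [dists[idx] for idx in ind]
    return ind


data = [[0.0, 0.0], [1.0, 0.0], [0.0, 3.0], [5.0, 5.0]]
print(brute_force_query_radius(data, [0.0, 0.0], r=3.0, count_only=True))
print(brute_force_query_radius(data, [0.0, 0.0], r=3.0))
print(brute_force_query_radius(data, [0.0, 0.0], r=3.0,
                               return_distance=True, sort_results=True))
```

The return order (ind, dist) follows the Returns section above, which differs from the (d, i) order used by query().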