8.20.8. sklearn.neighbors.BallTree
- class sklearn.neighbors.BallTree
Ball Tree for fast nearest-neighbor searches:
BallTree(X, leaf_size=20, p=2.0)
Parameters : X : array-like, shape = [n_samples, n_features]
n_samples is the number of points in the data set, and n_features is the dimension of the parameter space. Note: if X is a C-contiguous array of doubles then data will not be copied. Otherwise, an internal copy will be made.
leaf_size : positive integer (default = 20)
Number of points at which to switch to brute-force search. Changing leaf_size does not affect the results of a query, but it can significantly affect the speed of a query and the memory required to store the built ball tree. The amount of memory needed to store the tree scales as 2 ** (1 + floor(log2((n_samples - 1) / leaf_size))) - 1, which is the number of nodes in a complete binary tree with 1 + floor(log2((n_samples - 1) / leaf_size)) levels. For a given leaf_size, every leaf node is guaranteed to satisfy leaf_size <= n_points <= 2 * leaf_size, except when n_samples < leaf_size.
p : float, >= 1 (default = 2.0)
Distance metric for the BallTree. p encodes the Minkowski p-distance:
D = sum(abs(X[i] - X[j]) ** p) ** (1. / p)
p must be greater than or equal to 1 so that the triangle inequality holds. If p == np.inf, the distance is equivalent to:
D = max(abs(X[i] - X[j]))
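The leaf_size formula above can be checked with a little arithmetic. The sketch below is a hypothetical helper, `ball_tree_layout`, that evaluates the level and node counts from that formula and then splits the points level by level. It assumes every internal node divides its points as evenly as possible (n // 2 and n - n // 2); the real BallTree picks split points from the data, so this only illustrates the counting behind the memory estimate and the leaf_size <= n_points <= 2 * leaf_size guarantee, not the library's construction code.

```python
import math


def ball_tree_layout(n_samples, leaf_size):
    """Return (n_levels, n_nodes, leaf_sizes) for a hypothetical evenly split tree.

    Assumption: each internal node splits its points into n // 2 and
    n - n // 2; BallTree itself chooses splits from the data.
    """
    if n_samples < 1 or leaf_size < 1:
        raise ValueError("n_samples and leaf_size must be positive")
    if n_samples - 1 < leaf_size:
        # Small data set: the formula would give zero or negative levels,
        # so treat the whole tree as a single leaf node.
        n_levels = 1
    else:
        n_levels = 1 + int(math.floor(math.log2((n_samples - 1) / leaf_size)))
    # Node count of a complete binary tree with n_levels levels.
    n_nodes = 2 ** n_levels - 1

    # Halve every node's point count once per level until the leaves.
    sizes = [n_samples]
    for _ in range(n_levels - 1):
        sizes = [half for n in sizes for half in (n // 2, n - n // 2)]
    return n_levels, n_nodes, sizes


print(ball_tree_layout(10, 2))  # (3, 7, [2, 3, 2, 3])
```

With 10 points and leaf_size=2, as in the examples on this page, the formula gives 7 nodes and four leaves of 2 or 3 points each, within the documented bounds.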
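To see why p must be at least 1, the distance formula above can be evaluated directly. The sketch below uses a hypothetical helper, `p_distance`, written with NumPy; it deliberately skips the p >= 1 check so that a value such as p = 0.5 can be tried. On three points it shows the triangle inequality d(a, c) <= d(a, b) + d(b, c) failing for p = 0.5 while holding for p = 1, 2 and np.inf.

```python
import numpy as np


def p_distance(x, y, p=2.0):
    """Minkowski p-distance: sum(abs(x - y) ** p) ** (1 / p).

    For p == np.inf this is max(abs(x - y)). No p >= 1 check is made,
    so the effect of a smaller p can be observed.
    """
    diff = np.abs(np.asarray(x, dtype=float) - np.asarray(y, dtype=float))
    if np.isinf(p):
        return float(diff.max())
    return float((diff ** p).sum() ** (1.0 / p))


# Triangle inequality check on three points: d(a, c) <= d(a, b) + d(b, c)?
a, b, c = [0.0, 0.0], [1.0, 0.0], [1.0, 1.0]
for p in (0.5, 1.0, 2.0, np.inf):
    lhs = p_distance(a, c, p)
    rhs = p_distance(a, b, p) + p_distance(b, c, p)
    print(p, lhs, rhs, lhs <= rhs)
```

For p = 0.5 the direct distance from a to c (4.0) exceeds the detour through b (2.0), which is exactly the property a ball tree relies on to prune whole nodes during a search.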
Examples
Query for k-nearest neighbors
>>> import numpy as np
>>> from sklearn.neighbors import BallTree
>>> np.random.seed(0)
>>> X = np.random.random((10, 3))  # 10 points in 3 dimensions
>>> ball_tree = BallTree(X, leaf_size=2)
>>> dist, ind = ball_tree.query(X[0], k=3)
>>> print(ind)  # indices of 3 closest neighbors
[0 3 1]
>>> print(dist)  # distances to 3 closest neighbors
[ 0. 0.19662693 0.29473397]
Pickle and unpickle a ball tree (using protocol = 2). Note that the state of the tree is saved by the pickle operation: the tree is not rebuilt on unpickling.
>>> import numpy as np
>>> import pickle
>>> from sklearn.neighbors import BallTree
>>> np.random.seed(0)
>>> X = np.random.random((10, 3))  # 10 points in 3 dimensions
>>> ball_tree = BallTree(X, leaf_size=2)
>>> s = pickle.dumps(ball_tree, protocol=2)
>>> ball_tree_copy = pickle.loads(s)
>>> dist, ind = ball_tree_copy.query(X[0], k=3)
>>> print(ind)  # indices of 3 closest neighbors
[0 3 1]
>>> print(dist)  # distances to 3 closest neighbors
[ 0. 0.19662693 0.29473397]
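The note that the tree "is not rebuilt on unpickling" describes a general pickle pattern: save the already-built structure in the pickled state and restore it directly. The sketch below illustrates that pattern with a hypothetical `ToyTree` class (not scikit-learn's BallTree and not its actual implementation), where "building" is just sorting and a class-level counter records how often the build step runs. With protocol 2, unpickling creates the object without calling `__init__` and hands the saved state to `__setstate__`, so the counter does not increase.

```python
import pickle


class ToyTree:
    """Hypothetical stand-in for a tree index; 'building' just sorts the data."""

    n_builds = 0  # how many times the (expensive) build step has run

    def __init__(self, data):
        self.data = list(data)
        self._build()

    def _build(self):
        ToyTree.n_builds += 1
        # The "built structure": data indices in order of increasing value.
        self.order = sorted(range(len(self.data)), key=self.data.__getitem__)

    def __getstate__(self):
        # Save the built structure together with the data.
        return {"data": self.data, "order": self.order}

    def __setstate__(self, state):
        # Restore the saved structure; _build() is deliberately not called.
        self.data = state["data"]
        self.order = state["order"]
```

For example, building `ToyTree([3.0, 1.0, 2.0])` and round-tripping it through `pickle.dumps(tree, protocol=2)` and `pickle.loads` yields a copy with the same `order` while `ToyTree.n_builds` stays at one build.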
Attributes
data
warning_flag
Methods
query(X[, k, return_distance]) : query the Ball Tree for the k nearest neighbors
query_radius(X, r[, return_distance, count_only, sort_results]) : query the Ball Tree for neighbors within a ball of radius r
- __init__()
x.__init__(...) initializes x; see help(type(x)) for signature
- query(X, k=1, return_distance=True)
query the Ball Tree for the k nearest neighbors
Parameters : X : array-like, last dimension self.dim
An array of points to query
k : integer (default = 1)
The number of nearest neighbors to return
return_distance : boolean (default = True)
If True, return a tuple (d, i); if False, return only the array i.
Returns : i : if return_distance == False
(d,i) : if return_distance == True
d : array of doubles, shape X.shape[:-1] + (k,)
each entry gives the list of distances to the neighbors of the corresponding point (note that distances are not sorted)
i : array of integers, shape X.shape[:-1] + (k,)
each entry gives the list of indices of neighbors of the corresponding point (note that neighbors are not sorted)
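The shape rule above (X.shape[:-1] + (k,)) is easiest to see against an exhaustive search, which is also a convenient way to cross-check ball tree results on small data since a ball tree query is exact. The sketch below is a hypothetical helper, `brute_force_query`, restricted to Euclidean distance (p = 2). It accepts X with any number of leading dimensions and mirrors the return_distance switch. Note that this sketch returns neighbors sorted by increasing distance; that ordering is a choice made here, not a statement about BallTree's output.

```python
import numpy as np


def brute_force_query(data, X, k=1, return_distance=True):
    """Exhaustive k-nearest-neighbor search with Euclidean distance (p = 2).

    X may have any number of leading dimensions; its last dimension must
    equal data.shape[1]. Results have shape X.shape[:-1] + (k,).
    """
    data = np.asarray(data, dtype=float)
    X = np.asarray(X, dtype=float)
    if X.shape[-1] != data.shape[1]:
        raise ValueError("last dimension of X must equal the number of features")
    flat = X.reshape(-1, data.shape[1])  # (n_queries, n_features)
    # All pairwise distances, shape (n_queries, n_samples).
    dist = np.sqrt(((flat[:, None, :] - data[None, :, :]) ** 2).sum(axis=-1))
    # Indices of the k smallest distances per query, nearest first.
    i = np.argsort(dist, axis=1, kind="stable")[:, :k]
    d = np.take_along_axis(dist, i, axis=1)
    out_shape = X.shape[:-1] + (k,)
    if not return_distance:
        return i.reshape(out_shape)
    return d.reshape(out_shape), i.reshape(out_shape)


np.random.seed(0)
X = np.random.random((10, 3))  # 10 points in 3 dimensions
dist, ind = brute_force_query(X, X[0], k=3)
print(ind)  # [0 3 1]
```

A single point X[0] (shape (3,)) gives results of shape (3,), while a batch reshaped to (2, 5, 3) with k=2 gives results of shape (2, 5, 2).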
Examples
Query for k-nearest neighbors
>>> import numpy as np
>>> from sklearn.neighbors import BallTree
>>> np.random.seed(0)
>>> X = np.random.random((10, 3))  # 10 points in 3 dimensions
>>> ball_tree = BallTree(X, leaf_size=2)
>>> dist, ind = ball_tree.query(X[0], k=3)
>>> print(ind)  # indices of 3 closest neighbors
[0 3 1]
>>> print(dist)  # distances to 3 closest neighbors
[ 0. 0.19662693 0.29473397]
- query_radius()
query_radius(self, X, r, return_distance=False, count_only=False, sort_results=False)
query the Ball Tree for neighbors within a ball of radius r
Parameters : X : array-like, last dimension self.dim
An array of points to query
r : distance within which neighbors are returned
r can be a single value, or an array of values of shape X.shape[:-1] if a different radius is desired for each point.
return_distance : boolean (default = False)
If True, return the distances to the neighbors of each point; if False, return only the neighbors. Note that, unlike BallTree.query(), setting return_distance=True adds to the computation time: with return_distance=False, not all distances need to be calculated explicitly. Results are not sorted by default; see the sort_results keyword.
count_only : boolean (default = False)
If True, return only the count of points within distance r; if False, return the indices of all points within distance r. If return_distance == True, setting count_only=True results in an error.
sort_results : boolean (default = False)
If True, the distances and indices are sorted by distance before being returned. If False, the results are not sorted. If return_distance == False, setting sort_results=True results in an error.
Returns : count : if count_only == True
ind : if count_only == False and return_distance == False
(ind, dist) : if count_only == False and return_distance == True
count : array of integers, shape = X.shape[:-1]
each entry gives the number of neighbors within a distance r of the corresponding point.
ind : array of objects, shape = X.shape[:-1]
each element is a numpy integer array listing the indices of neighbors of the corresponding point. Note that, unlike the results of BallTree.query(), the returned neighbors are not sorted by distance unless sort_results=True.
dist : array of objects, shape = X.shape[:-1]
each element is a numpy double array listing the distances corresponding to the indices in ind.
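The options above interact in ways that are easier to follow in code. The sketch below is a hypothetical exhaustive helper, `brute_force_query_radius`, that follows the documented behavior: a scalar or per-point r, counts or ragged object arrays of indices and distances with shape X.shape[:-1], unsorted results unless sort_results=True, and errors for the two forbidden flag combinations. It assumes Euclidean distance and an inclusive boundary (distance <= r); neither assumption, nor the error messages, is taken from BallTree itself.

```python
import numpy as np


def brute_force_query_radius(data, X, r, return_distance=False,
                             count_only=False, sort_results=False):
    """Exhaustive radius search mirroring the documented query_radius() options.

    Sketch only: Euclidean distance, a point counts as a neighbor when
    its distance is <= r.
    """
    # Flag combinations that the documentation says raise an error.
    if count_only and return_distance:
        raise ValueError("count_only=True cannot be used with return_distance=True")
    if sort_results and not return_distance:
        raise ValueError("sort_results=True requires return_distance=True")

    data = np.asarray(data, dtype=float)
    X = np.asarray(X, dtype=float)
    out_shape = X.shape[:-1]  # one result per query point
    flat = X.reshape(-1, data.shape[1])  # (n_queries, n_features)
    # A scalar r is broadcast; an array r gives one radius per query point.
    radii = np.broadcast_to(np.asarray(r, dtype=float), out_shape).ravel()

    count = np.empty(len(flat), dtype=int)
    ind = np.empty(len(flat), dtype=object)   # ragged: arrays of indices
    dist = np.empty(len(flat), dtype=object)  # ragged: arrays of distances
    for q, (x, radius) in enumerate(zip(flat, radii)):
        d = np.sqrt(((data - x) ** 2).sum(axis=1))
        hits = np.nonzero(d <= radius)[0]  # in index order, not by distance
        if sort_results:
            hits = hits[np.argsort(d[hits], kind="stable")]
        count[q] = len(hits)
        ind[q] = hits
        dist[q] = d[hits]

    if count_only:
        return count.reshape(out_shape)
    if return_distance:
        return ind.reshape(out_shape), dist.reshape(out_shape)
    return ind.reshape(out_shape)
```

On the 10-point data set used in the examples, querying X[0] with r=0.3 counts three neighbors (indices 0, 1 and 3), and passing r=[0.3, 0.0] for the two queries X[:2] applies a separate radius to each point.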
Examples
Query for neighbors in a given radius
>>> import numpy as np
>>> from sklearn.neighbors import BallTree
>>> np.random.seed(0)
>>> X = np.random.random((10, 3))  # 10 points in 3 dimensions
>>> ball_tree = BallTree(X, leaf_size=2)
>>> print(ball_tree.query_radius(X[0], r=0.3, count_only=True))
3
>>> ind = ball_tree.query_radius(X[0], r=0.3)
>>> print(ind)  # indices of neighbors within distance 0.3
[3 0 1]