knn
find_knn
Find the k nearest neighbors of each query in an embedding index.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
embedding_index | Tensor | A tensor of shape (n, d) containing n embeddings of dimension d. | required |
query | Tensor | A tensor of shape (m, d) containing m queries of dimension d. | required |
k | int | The number of nearest neighbors to find. | required |
p | int | The p-norm to use for the distance calculation. | required |
Returns:
Type | Description |
---|---|
torch.Tensor | A tensor of shape (m, k) containing the indices of the k nearest neighbors of each query. |
Raises:
Type | Description |
---|---|
ValueError | If the input tensors are of dtype BFloat16. |
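
Example:

The following is a minimal sketch of the documented contract, not the library's implementation. It assumes a brute-force approach built on `torch.cdist` and `torch.topk`. The name `find_knn_sketch` is hypothetical. Two behaviors here are assumptions that the reference above does not state: neighbors are returned nearest first, and a BFloat16 dtype on either input raises the error.

```python
import torch


def find_knn_sketch(
    embedding_index: torch.Tensor, query: torch.Tensor, k: int, p: int
) -> torch.Tensor:
    """Hypothetical stand-in for find_knn (brute force, for illustration only)."""
    # Assumption: a BFloat16 dtype on either input triggers the ValueError.
    if embedding_index.dtype == torch.bfloat16 or query.dtype == torch.bfloat16:
        raise ValueError("Input tensors of dtype BFloat16 are not supported.")

    # Pairwise p-norm distances between each query and each embedding: shape (m, n).
    distances = torch.cdist(query, embedding_index, p=float(p))

    # Indices of the k smallest distances per query, nearest first: shape (m, k).
    return torch.topk(distances, k, dim=1, largest=False).indices


# Three 2-dimensional embeddings (n=3, d=2) and two queries (m=2).
index = torch.tensor([[0.0, 0.0], [1.0, 0.0], [5.0, 5.0]])
queries = torch.tensor([[0.9, 0.1], [4.0, 4.0]])

neighbors = find_knn_sketch(index, queries, k=2, p=2)
print(neighbors.shape)  # torch.Size([2, 2])
```

With `p=2` the distance is Euclidean, so the first query is closest to row 1 of the index and the second query is closest to row 2. A brute-force search like this materializes an (m, n) distance matrix, which is fine for small indexes but may need batching for large ones.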