
knn

find_knn

find_knn(embedding_index: Tensor, query: Tensor, k: int, p: int) -> Tensor

Find the k-nearest neighbors of a query tensor in an embedding index.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| embedding_index | Tensor | A tensor of shape (n, d) containing n embeddings of dimension d. | required |
| query | Tensor | A tensor of shape (m, d) containing m queries of dimension d. | required |
| k | int | The number of nearest neighbors to find for each query. | required |
| p | int | The order p of the p-norm used to compute distances (for example, 1 for Manhattan distance, 2 for Euclidean distance). | required |
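The page does not say how the p-norm distance is computed internally. As an illustration only, the snippet below uses torch.cdist, which computes pairwise p-norm distances between two sets of vectors, to show that the choice of p can change which embedding counts as nearest. Whether find_knn itself relies on torch.cdist is an assumption.

```python
import torch

# One query (m=1) and two embeddings (n=2), all of dimension d=2.
query = torch.tensor([[0.0, 0.0]])
embedding_index = torch.tensor([[2.0, 2.0], [3.0, 0.0]])

# torch.cdist returns the (m, n) matrix of distances ||q_i - e_j||_p.
d1 = torch.cdist(query, embedding_index, p=1)  # Manhattan: [[4., 3.]]
d2 = torch.cdist(query, embedding_index, p=2)  # Euclidean: approx [[2.83, 3.00]]

# The nearest embedding depends on p: index 1 under p=1, index 0 under p=2.
print(d1.argmin(dim=1).tolist(), d2.argmin(dim=1).tolist())  # [1] [0]
```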

Returns:

| Type | Description |
|---|---|
| Tensor | A tensor of shape (m, k) containing, for each query, the indices (rows of embedding_index) of its k nearest neighbors. |
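A common next step is to look up the neighbor embeddings themselves. The sketch below assumes the returned indices are integer (int64) row indices into embedding_index; the indices tensor is a hand-written stand-in, not real find_knn output.

```python
import torch

# Hypothetical index of n=4 embeddings with dimension d=3.
embedding_index = torch.arange(12, dtype=torch.float32).reshape(4, 3)

# Stand-in for a result with m=2 queries and k=2 neighbors (values chosen by hand).
indices = torch.tensor([[2, 0], [1, 3]])

# Advanced indexing with an (m, k) integer tensor yields an (m, k, d) tensor:
# neighbors[i, j] is the embedding of the j-th neighbor of query i.
neighbors = embedding_index[indices]
print(neighbors.shape)  # torch.Size([2, 2, 3])
```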

Raises:

| Type | Description |
|---|---|
| ValueError | If the input tensors have dtype torch.bfloat16. |
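To see the computation end to end, here is a minimal brute-force sketch with the same parameters. It is not the library's implementation: find_knn_sketch is a hypothetical name, computing distances with torch.cdist and selecting neighbors with torch.topk is an assumption, and raising when either input is bfloat16 is one reading of the documented ValueError.

```python
import torch
from torch import Tensor


def find_knn_sketch(embedding_index: Tensor, query: Tensor, k: int, p: int) -> Tensor:
    """Hypothetical brute-force k-NN: indices of the k closest embeddings per query."""
    # Assumed reading of the documented restriction: reject bfloat16 on either input.
    if embedding_index.dtype == torch.bfloat16 or query.dtype == torch.bfloat16:
        raise ValueError("bfloat16 inputs are not supported")
    # Pairwise p-norm distances between every query and every embedding: shape (m, n).
    distances = torch.cdist(query, embedding_index, p=p)
    # Keep the k smallest distances per row; .indices has shape (m, k).
    return distances.topk(k, dim=1, largest=False).indices


# Usage: n=4 embeddings and m=2 queries in d=2.
embedding_index = torch.tensor([[0.0, 0.0], [1.0, 0.0], [5.0, 5.0], [0.0, 2.0]])
query = torch.tensor([[0.1, 0.0], [4.0, 4.0]])
idx = find_knn_sketch(embedding_index, query, k=2, p=2)
print(idx.tolist())  # [[0, 1], [2, 3]]
```

Brute force takes O(m·n·d) time and materializes the full (m, n) distance matrix, so it suits modest index sizes. Because torch.topk sorts its output by default, each row lists neighbors from nearest to farthest.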