
knn

batched_knn

batched_knn(embedding_index: Tensor, query: Tensor, k: int, p: int, max_batch_size: int | None = None, max_sequence_length: int | None = None, max_num_embeddings: int | None = None) -> torch.Tensor

Find the k-nearest neighbors of a query tensor in an embedding index, processing at most a (max_batch_size, max_sequence_length) section of the query at a time.

Smaller values of max_batch_size, max_sequence_length, and max_num_embeddings reduce the memory needed to store the intermediate distance calculations, at the cost of longer runtimes.
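The memory/runtime trade-off can be illustrated with a minimal pure-Python sketch. This is not the library's implementation: it assumes plain lists instead of tensors, flattens the batch and sequence dimensions into a single list of query vectors, and uses a hypothetical max_chunk_size parameter in place of max_batch_size and max_sequence_length. Only one chunk of queries is compared against the index at a time, so the intermediate distance table never holds more than max_chunk_size * len(embedding_index) values, while the results stay the same for every chunk size.

```python
from typing import List, Sequence


def p_distance(x: Sequence[float], y: Sequence[float], p: int) -> float:
    """Minkowski (p-norm) distance between two vectors."""
    return sum(abs(a - b) ** p for a, b in zip(x, y)) ** (1.0 / p)


def chunked_knn(
    embedding_index: List[List[float]],
    queries: List[List[float]],
    k: int,
    p: int,
    max_chunk_size: int,  # hypothetical stand-in for max_batch_size/max_sequence_length
) -> List[List[int]]:
    """Return the indices of the k nearest embeddings for every query vector."""
    results: List[List[int]] = []
    for start in range(0, len(queries), max_chunk_size):
        chunk = queries[start:start + max_chunk_size]
        # Intermediate distance table for this chunk only:
        # len(chunk) rows by len(embedding_index) columns.
        distances = [[p_distance(q, e, p) for e in embedding_index] for q in chunk]
        for row in distances:
            # Sort embedding indices by distance and keep the k closest.
            nearest = sorted(range(len(row)), key=lambda i: row[i])[:k]
            results.append(nearest)
    return results


index = [[0.0, 0.0], [1.0, 0.0], [0.0, 2.0], [5.0, 5.0]]
queries = [[0.1, 0.0], [4.0, 4.0], [0.0, 1.8]]
print(chunked_knn(index, queries, k=2, p=2, max_chunk_size=1))
```

Smaller chunks mean more loop iterations (longer runtime) but a smaller distance table (less memory), which mirrors the behavior described above.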

Parameters:

Name Type Description Default
embedding_index Tensor

A tensor of shape (n_embeddings, embedding_dim).

required
query Tensor

A tensor of shape (batch_size, sequence_length, embedding_dim) or (sequence_length, embedding_dim).

required
k int

The number of nearest neighbors to find.

required
p int

The p-norm to use for the distance calculation.
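Assuming the standard Minkowski distance, the distance between a query vector q and an embedding e, both of dimension D (embedding_dim), is computed as below. For example, with q = (0, 0) and e = (3, 4), p = 1 gives 3 + 4 = 7 (Manhattan distance) and p = 2 gives sqrt(9 + 16) = 5 (Euclidean distance).

```latex
d_p(q, e) = \lVert q - e \rVert_p = \left( \sum_{i=1}^{D} \lvert q_i - e_i \rvert^{p} \right)^{1/p}
```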

required
max_batch_size int | None

The maximum number of batch elements for which distances are calculated at once.

None
max_sequence_length int | None

The maximum number of sequence positions for which distances are calculated at once.

None
max_num_embeddings int | None

The maximum number of embeddings against which distances are calculated at once. The results from each split are recursively merged together.
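Merging per-split results works because the global k nearest neighbors must lie within the union of each split's own k nearest neighbors. The sketch below illustrates that idea for a single query vector in pure Python; it is an assumption about the general technique, not the library's code. The recursive halving strategy and the knn_split helper are hypothetical, and results are kept as (distance, index) pairs so that a merge only needs to re-select the k smallest.

```python
import heapq
from typing import List, Sequence, Tuple


def p_distance(x: Sequence[float], y: Sequence[float], p: int) -> float:
    """Minkowski (p-norm) distance between two vectors."""
    return sum(abs(a - b) ** p for a, b in zip(x, y)) ** (1.0 / p)


def knn_split(
    embedding_index: List[List[float]],
    query: List[float],
    k: int,
    p: int,
    max_num_embeddings: int,
    offset: int = 0,  # position of this split within the full index
) -> List[Tuple[float, int]]:
    """Return (distance, index) pairs for the k embeddings nearest to `query`."""
    if max_num_embeddings < 1:
        raise ValueError("max_num_embeddings must be at least 1")
    n = len(embedding_index)
    if n <= max_num_embeddings:
        # Base case: the split is small enough to score directly.
        scored = [(p_distance(query, e, p), offset + i) for i, e in enumerate(embedding_index)]
        return heapq.nsmallest(k, scored)
    # Recursive case: halve the split and find the k nearest in each half.
    mid = n // 2
    left = knn_split(embedding_index[:mid], query, k, p, max_num_embeddings, offset)
    right = knn_split(embedding_index[mid:], query, k, p, max_num_embeddings, offset + mid)
    # Merge: keep the k smallest distances from the two partial results.
    return heapq.nsmallest(k, left + right)


index = [[0.0], [10.0], [2.0], [7.0], [1.0], [4.0]]
print([i for _, i in knn_split(index, [3.0], k=3, p=2, max_num_embeddings=2)])
```

Whatever value max_num_embeddings takes, the merged indices match a single pass over the whole index; only the peak number of distances computed per split changes.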

None

Returns:

Type Description
torch.Tensor

A tensor of shape (batch_size, sequence_length, k) or (sequence_length, k) containing the indices of the k-nearest neighbors of each query.

Raises:

Type Description
ValueError

If the input tensors are of dtype bfloat16.