knn

Functions:

Name Description
batched_knn

Find the k-nearest neighbors of a query tensor in an embedding index, processing at most a (max_batch_size, max_sequence_length) section of the query at a time.

batched_knn

batched_knn(
    embedding_index: Tensor,
    query: Tensor,
    k: int,
    p: int,
    max_batch_size: int | None = None,
    max_sequence_length: int | None = None,
    max_num_embeddings: int | None = None,
) -> torch.Tensor

Find the k-nearest neighbors of a query tensor in an embedding index, processing at most a (max_batch_size, max_sequence_length) section of the query at a time.

Smaller values of max_batch_size, max_sequence_length, and max_num_embeddings reduce the memory needed to store the intermediate distance calculations but increase runtime.
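The memory/runtime trade-off above can be illustrated with a minimal sketch of tiling over the query. This is not the library's implementation: the function name `knn_query_tiles` is hypothetical, and the use of `torch.cdist` followed by `topk` is an assumption about how a brute-force search could be written. Only one tile's distance matrix exists at any time, so smaller tiles lower peak memory at the cost of more loop iterations.

```python
import torch


def knn_query_tiles(embedding_index, query, k, p,
                    max_batch_size=None, max_sequence_length=None):
    """Hypothetical sketch: brute-force k-NN computed one query tile at a time."""
    squeeze = query.dim() == 2
    if squeeze:
        query = query.unsqueeze(0)  # treat (seq, dim) as a batch of one
    b, s, d = query.shape
    bs = max_batch_size or b        # None means "no limit"
    sl = max_sequence_length or s
    out = torch.empty(b, s, k, dtype=torch.long, device=query.device)
    for i in range(0, b, bs):
        for j in range(0, s, sl):
            tile = query[i:i + bs, j:j + sl]  # at most (bs, sl, d)
            # Distance matrix for this tile only: (rows_in_tile, n_embeddings).
            dists = torch.cdist(tile.reshape(-1, d), embedding_index, p=float(p))
            idx = dists.topk(k, dim=-1, largest=False).indices
            out[i:i + bs, j:j + sl] = idx.reshape(tile.shape[0], tile.shape[1], k)
    return out[0] if squeeze else out
```

Under these assumptions, calling the sketch with and without tile limits should return the same indices; only the size of the largest intermediate tensor changes.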

Parameters:

Name Type Description Default

embedding_index

Tensor

A tensor of shape (n_embeddings, embedding_dim).

required

query

Tensor

A tensor of shape (batch_size, sequence_length, embedding_dim) or (sequence_length, embedding_dim).

required

k

int

The number of nearest neighbors to find.

required

p

int

The order p of the p-norm used for the distance calculation (for example, 2 for Euclidean distance).

required

max_batch_size

int | None

The maximum number of batch elements over which to calculate distances.

None

max_sequence_length

int | None

The maximum number of sequence positions over which to calculate distances.

None

max_num_embeddings

int | None

The maximum number of embeddings over which to calculate distances. The results from each split are recursively merged together.

None
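To make the effect of max_num_embeddings concrete, the sketch below splits the embedding index into blocks, keeps the best k candidates from each block, and merges them with a running top-k. It is an illustration under stated assumptions: the name `knn_over_index_splits` is hypothetical, the query is assumed to be already flattened to (n_queries, embedding_dim), and the merge here is iterative, whereas the description above says the library merges recursively.

```python
import torch


def knn_over_index_splits(embedding_index, query, k, p, max_num_embeddings):
    """Hypothetical sketch: k-NN over index blocks, merged with a running top-k."""
    best_d = best_i = None
    for start in range(0, embedding_index.shape[0], max_num_embeddings):
        block = embedding_index[start:start + max_num_embeddings]
        dists = torch.cdist(query, block, p=float(p))  # (n_queries, block_size)
        d, i = dists.topk(min(k, block.shape[0]), dim=-1, largest=False)
        i = i + start  # shift block-local positions to positions in the full index
        if best_d is None:
            best_d, best_i = d, i
        else:
            # Pool the previous winners with this block's winners, then re-select.
            cand_d = torch.cat([best_d, d], dim=-1)
            cand_i = torch.cat([best_i, i], dim=-1)
            best_d, pos = cand_d.topk(min(k, cand_d.shape[-1]), dim=-1, largest=False)
            best_i = cand_i.gather(-1, pos)
    return best_i
```

Because each block contributes its own k best candidates, the true k nearest neighbors always survive the merge, so the result matches a single pass over the whole index (ignoring exact distance ties).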

Returns:

Type Description
torch.Tensor

A tensor of shape (batch_size, sequence_length, k) or (sequence_length, k) containing the indices of the k-nearest neighbors of each query.
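As a small illustration of working with the returned shape, the sketch below looks up the neighbor vectors from an index tensor. It assumes the returned indices refer to rows of embedding_index, and it uses a random stand-in tensor in place of an actual batched_knn result.

```python
import torch

embedding_index = torch.randn(10, 4)  # (n_embeddings, embedding_dim)

# Stand-in for a batched_knn result with batch_size=2, sequence_length=3, k=5.
indices = torch.randint(0, 10, (2, 3, 5))

# Advanced indexing adds the embedding dimension at the end:
# (batch_size, sequence_length, k, embedding_dim).
neighbors = embedding_index[indices]
```

With a 2-D query, the same lookup would give a tensor of shape (sequence_length, k, embedding_dim).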

Raises:

Type Description
ValueError

If the input tensors are of dtype bfloat16.
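One possible way to avoid this error is to upcast bfloat16 tensors before the call. The helper below is hypothetical and not part of the library; it assumes float32 inputs are accepted, which the documentation above does not state explicitly.

```python
import torch


def as_knn_compatible(t: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: upcast bfloat16 to float32, leave other dtypes unchanged."""
    return t.float() if t.dtype == torch.bfloat16 else t
```

Both embedding_index and query would need to pass through such a conversion, since the error is described for the input tensors in general.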