
knn

Functions:

Name Description
batched_knn

Find the k-nearest neighbors of a query tensor in an embedding index, processing at most a (max_batch_size, max_sequence_length) section of the query at a time.

batched_knn

batched_knn(
    embedding_index: Tensor,
    query: Tensor,
    k: int,
    max_batch_size: int | None = None,
    max_sequence_length: int | None = None,
    max_num_embeddings: int | None = None,
    dist_fn: Callable[
        Concatenate[Tensor, Tensor, P], Tensor
    ] = cdist,
    *dist_fn_args: args,
    **dist_fn_kwargs: kwargs,
) -> torch.Tensor

Find the k-nearest neighbors of a query tensor in an embedding index, processing at most a (max_batch_size, max_sequence_length) section of the query at a time.

Smaller values of max_batch_size, max_sequence_length, and max_num_embeddings require less memory to store the intermediate distance calculations but have longer runtimes.
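The memory and runtime trade-off described above can be illustrated with a minimal sketch. This is not the library's implementation: chunked_knn_sketch and its max_queries argument are hypothetical names, the query is assumed to be 2-D, and the per-split results are combined with a simple running merge rather than the recursive merge the library uses. Only one (query chunk, embedding chunk) distance matrix exists at any time.

```python
import torch


def chunked_knn_sketch(embedding_index, query, k, max_queries=None, max_num_embeddings=None):
    """Hypothetical illustration of chunked k-NN; not the library's code.

    embedding_index: (n_embeddings, embedding_dim)
    query:           (n_queries, embedding_dim)
    """
    n_queries = query.shape[0]
    n_embeddings = embedding_index.shape[0]
    q_step = max_queries or n_queries
    e_step = max_num_embeddings or n_embeddings

    results = []
    for q_start in range(0, n_queries, q_step):
        q_chunk = query[q_start:q_start + q_step]
        best_dist, best_idx = None, None
        for e_start in range(0, n_embeddings, e_step):
            e_chunk = embedding_index[e_start:e_start + e_step]
            # Only this (q_chunk, e_chunk) distance matrix is held in memory.
            dist = torch.cdist(q_chunk, e_chunk)
            d, i = dist.topk(min(k, e_chunk.shape[0]), dim=-1, largest=False)
            i = i + e_start  # convert chunk-local indices to global indices
            if best_dist is None:
                best_dist, best_idx = d, i
            else:
                # Merge the running best candidates with this split's candidates.
                cat_d = torch.cat([best_dist, d], dim=-1)
                cat_i = torch.cat([best_idx, i], dim=-1)
                best_dist, pos = cat_d.topk(min(k, cat_d.shape[-1]), dim=-1, largest=False)
                best_idx = cat_i.gather(-1, pos)
        results.append(best_idx)
    return torch.cat(results, dim=0)
```

Smaller chunk sizes shrink the largest intermediate tensor but increase the number of loop iterations and merge steps, which is where the longer runtime comes from.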

This function supports the use of arbitrary distance functions via "dist_fn". The default is to use the Euclidean distance via torch.cdist. Any trailing positional or keyword arguments which are not explicit arguments to this function are passed to dist_fn.
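As a sketch of what a custom distance function might look like, the example below defines a cosine distance with the shape of callable that the dist_fn type describes: two tensors first, then extra arguments. The name cosine_distance, its eps parameter, and the assumption that the query is passed before the embedding index are illustrative and not taken from the library. Because the import path of batched_knn is not shown on this page, the calls to it appear only in comments.

```python
import torch
import torch.nn.functional as F


def cosine_distance(query, embedding_index, eps=1e-8):
    """Hypothetical dist_fn: cosine distance between query rows and index rows.

    Assumes the query is the first argument and the index the second.
    Returns a tensor of shape (..., n_embeddings).
    """
    q = F.normalize(query, dim=-1, eps=eps)
    e = F.normalize(embedding_index, dim=-1, eps=eps)
    return 1.0 - q @ e.transpose(-1, -2)


embedding_index = torch.tensor([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])
query = torch.tensor([[2.0, 0.1], [0.1, 3.0]])

# Stand-in for the neighbor search: smallest distance per query row.
nearest = cosine_distance(query, embedding_index).topk(1, dim=-1, largest=False).indices

# With the real function, extra arguments would be forwarded to dist_fn, e.g.:
#   batched_knn(embedding_index, query, k=1, dist_fn=cosine_distance, eps=1e-6)
# With the default torch.cdist, a keyword such as p selects the norm, e.g.:
#   batched_knn(embedding_index, query, k=1, p=1.0)
```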

Parameters:

Name Type Description Default

embedding_index

Tensor

A tensor of shape (n_embeddings, embedding_dim).

required

query

Tensor

A tensor of shape (batch_size, sequence_length, embedding_dim) or (sequence_length, embedding_dim).

required

k

int

The number of nearest neighbors to find.

required

max_batch_size

int | None

The maximum number of batch elements over which to calculate distances.

None

max_sequence_length

int | None

The maximum number of sequence positions over which to calculate distances.

None

max_num_embeddings

int | None

The maximum number of embeddings over which to calculate distances. The results from each split are recursively merged together.

None

dist_fn

Callable[Concatenate[Tensor, Tensor, P], Tensor]

A callable that computes the distance between query and embedding_index.

cdist

dist_fn_args

args

Additional positional arguments to pass to dist_fn.

required

dist_fn_kwargs

kwargs

Additional keyword arguments to pass to dist_fn.

required

Returns:

Type Description
torch.Tensor

A tensor of shape (batch_size, sequence_length, k) or (sequence_length, k) containing the indices of the k-nearest neighbors of each query.
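To make the return shape concrete, the sketch below builds an index tensor of the documented shape and uses it to look up the neighbor embeddings. The indices are produced by a plain torch.cdist and topk stand-in rather than by batched_knn itself, since the import path is not shown here. The tensor sizes are arbitrary.

```python
import torch

torch.manual_seed(0)
embedding_index = torch.randn(10, 4)  # (n_embeddings, embedding_dim)
query = torch.randn(2, 3, 4)          # (batch_size, sequence_length, embedding_dim)
k = 5

# Stand-in for the function's result: flatten the query, rank by Euclidean
# distance, then restore the leading (batch_size, sequence_length) dimensions.
flat_query = query.reshape(-1, query.shape[-1])
indices = torch.cdist(flat_query, embedding_index).topk(k, dim=-1, largest=False).indices
indices = indices.reshape(*query.shape[:-1], k)  # (batch_size, sequence_length, k)

# Indexing the embedding index with the result gathers the neighbor vectors.
neighbors = embedding_index[indices]  # (batch_size, sequence_length, k, embedding_dim)
```

For a 2-D query of shape (sequence_length, embedding_dim), the same indexing yields neighbors of shape (sequence_length, k, embedding_dim).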

Raises:

Type Description
ValueError

If the input tensors are of dtype bfloat16.

Changed in version v2.23.0: Refactored to allow arbitrary distance functions.