knn
Module for k-nearest neighbors (KNN) utilities in PyTorch.
Functions:
| Name | Description |
|---|---|
| `batched_knn` | Find the k-nearest neighbors of a query tensor in an embedding index, processing at most a (max_batch_size, max_sequence_length) section of the query at a time. |
batched_knn
batched_knn(
embedding_index: Tensor,
query: Tensor,
k: int,
max_batch_size: int | None = None,
max_sequence_length: int | None = None,
max_num_embeddings: int | None = None,
dist_fn: Callable[
Concatenate[Tensor, Tensor, P], Tensor
] = cdist,
*dist_fn_args: P.args,
**dist_fn_kwargs: P.kwargs
) -> torch.Tensor
Find the k-nearest neighbors of a query tensor in an embedding index, processing at most a (max_batch_size, max_sequence_length)
section of the query at a time.
Smaller values of max_batch_size, max_sequence_length, and max_num_embeddings reduce the memory needed to store the intermediate
distance matrices, at the cost of longer runtimes.
This function supports arbitrary distance functions via `dist_fn`. The default is the Euclidean distance computed by
`torch.cdist`. Any trailing positional or keyword arguments that are not explicit parameters of this function are passed through to `dist_fn`.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `embedding_index` | Tensor | A tensor of shape (n_embeddings, embedding_dim). | required |
| `query` | Tensor | A tensor of shape (batch_size, sequence_length, embedding_dim) or (sequence_length, embedding_dim). | required |
| `k` | int | The number of nearest neighbors to find. | required |
| `max_batch_size` | int \| None | The maximum number of batch elements over which to calculate distances at once. | None |
| `max_sequence_length` | int \| None | The maximum number of sequence positions over which to calculate distances at once. | None |
| `max_num_embeddings` | int \| None | The maximum number of embeddings over which to calculate distances at once. The results from each split are recursively merged. | None |
| `dist_fn` | Callable[Concatenate[Tensor, Tensor, P], Tensor] | A callable that computes pairwise distances between two tensors. It receives the two tensors followed by any extra arguments in `dist_fn_args` and `dist_fn_kwargs`. | cdist |
| `*dist_fn_args` | P.args | Additional positional arguments to pass to `dist_fn`. | () |
| `**dist_fn_kwargs` | P.kwargs | Additional keyword arguments to pass to `dist_fn`. | {} |
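The recursive merge used when the embedding index is split can be sketched as below. This is a hypothetical illustration, not this module's code: `knn_split_embeddings` and its `offset` parameter are assumed names, it halves the index until each piece has at most `max_num_embeddings` rows, and it uses Euclidean `torch.cdist` on a 2-D query. The merge is valid because the overall top-k must lie among the top-k candidates of the two halves.

```python
import torch


def knn_split_embeddings(
    embedding_index: torch.Tensor,
    query: torch.Tensor,
    k: int,
    max_num_embeddings: int,
    offset: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical sketch: (distances, indices) of the k nearest embeddings for a
    2-D query, never computing distances to more than max_num_embeddings at once."""
    if max_num_embeddings < 1:
        raise ValueError("max_num_embeddings must be at least 1")
    n = embedding_index.shape[0]
    if n <= max_num_embeddings:
        top = torch.cdist(query, embedding_index).topk(min(k, n), dim=-1, largest=False)
        # Shift local positions to indices into the full embedding index.
        return top.values, top.indices + offset
    mid = n // 2
    left_d, left_i = knn_split_embeddings(embedding_index[:mid], query, k, max_num_embeddings, offset)
    right_d, right_i = knn_split_embeddings(
        embedding_index[mid:], query, k, max_num_embeddings, offset + mid
    )
    # Merge: keep the k smallest distances among both halves' candidates.
    cand_d = torch.cat([left_d, right_d], dim=-1)
    cand_i = torch.cat([left_i, right_i], dim=-1)
    top = cand_d.topk(min(k, cand_d.shape[-1]), dim=-1, largest=False)
    return top.values, cand_i.gather(-1, top.indices)
```

Each leaf holds at most a (n_queries, max_num_embeddings) distance matrix; the merge steps only touch (n_queries, 2k) candidate tensors.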
Returns:
| Type | Description |
|---|---|
| torch.Tensor | A tensor of shape (batch_size, sequence_length, k) or (sequence_length, k) containing the indices of the k-nearest neighbors of each query. |
Raises:
| Type | Description |
|---|---|
| ValueError | If the input tensors are of dtype bfloat16. |
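Callers holding bfloat16 tensors can upcast them before the search. A hypothetical sketch: `ensure_not_bfloat16` only mirrors the documented condition and is not this module's actual check, and casting to float32 with `.float()` is one assumed way to avoid the error.

```python
import torch


def ensure_not_bfloat16(*tensors: torch.Tensor) -> None:
    """Hypothetical pre-check mirroring the documented ValueError condition."""
    for t in tensors:
        if t.dtype == torch.bfloat16:
            raise ValueError("bfloat16 inputs are not supported; cast to float32 first")


emb = torch.randn(10, 4).to(torch.bfloat16)
query = torch.randn(3, 4).to(torch.bfloat16)
try:
    ensure_not_bfloat16(emb, query)
except ValueError as err:
    print(err)

# Upcasting both tensors satisfies the dtype condition.
emb32, query32 = emb.float(), query.float()
ensure_not_bfloat16(emb32, query32)
```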
Changed in version v2.23.0: Refactored to allow arbitrary distance functions.