knn
batched_knn
batched_knn(embedding_index: Tensor, query: Tensor, k: int, p: int, max_batch_size: int | None = None, max_sequence_length: int | None = None, max_num_embeddings: int | None = None) -> torch.Tensor
Find the k-nearest neighbors of a query tensor in an embedding index, processing at most a (`max_batch_size`, `max_sequence_length`) section of the query at a time.
Smaller values of `max_batch_size`, `max_sequence_length`, and `max_num_embeddings` require less memory to store the intermediate distance calculations but increase runtime.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`embedding_index` | `Tensor` | A tensor of shape (n_embeddings, embedding_dim). | required |
`query` | `Tensor` | A tensor of shape (batch_size, sequence_length, embedding_dim) or (sequence_length, embedding_dim). | required |
`k` | `int` | The number of nearest neighbors to find. | required |
`p` | `int` | The p-norm to use for the distance calculation. | required |
`max_batch_size` | `int \| None` | The maximum number of batch elements over which to calculate distances. | `None` |
`max_sequence_length` | `int \| None` | The maximum number of sequence positions over which to calculate distances. | `None` |
`max_num_embeddings` | `int \| None` | The maximum number of embeddings over which to calculate distances. The results from each split are recursively merged together. | `None` |
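The recursive merge described for `max_num_embeddings` could look roughly like the sketch below. It is an illustration under assumptions rather than the library's code: `knn_split_merge_sketch` is a hypothetical name, the index is assumed to be halved until each piece fits the limit, and the query is assumed to be already flattened to (num_queries, embedding_dim). Each half returns its own top-k candidates, and the merge keeps the k smallest distances among the combined candidates.

```python
import torch


def knn_split_merge_sketch(embedding_index, query, k, p=2, max_num_embeddings=None):
    """Hypothetical sketch: returns (distances, indices) of the k nearest embeddings.

    Assumes query has shape (num_queries, embedding_dim) and
    max_num_embeddings is None or >= 1.
    """
    n = embedding_index.shape[0]
    if max_num_embeddings is None or n <= max_num_embeddings:
        # Small enough: compute distances to this slice of the index directly.
        dists = torch.cdist(query, embedding_index, p=float(p))
        top = dists.topk(min(k, n), dim=-1, largest=False)
        return top.values, top.indices
    mid = n // 2
    left_d, left_i = knn_split_merge_sketch(
        embedding_index[:mid], query, k, p, max_num_embeddings)
    right_d, right_i = knn_split_merge_sketch(
        embedding_index[mid:], query, k, p, max_num_embeddings)
    # Candidates from both halves; right-half indices are shifted back
    # into the coordinate system of the unsplit index.
    cand_d = torch.cat([left_d, right_d], dim=-1)
    cand_i = torch.cat([left_i, right_i + mid], dim=-1)
    best_d, pos = cand_d.topk(min(k, cand_d.shape[-1]), dim=-1, largest=False)
    return best_d, cand_i.gather(-1, pos)
```

Because every split keeps only k candidates per query, the merge step works on at most 2k distances per query, however large the index is.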
Returns:
Type | Description |
---|---|
`torch.Tensor` | A tensor of shape (batch_size, sequence_length, k) or (sequence_length, k) containing the indices of the k-nearest neighbors of each query. |
Raises:
Type | Description |
---|---|
`ValueError` | If the input tensors are of dtype bfloat16. |
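Callers holding bfloat16 tensors may want to upcast them before the call to avoid the `ValueError`. The helper below is a hypothetical sketch, not part of the library, and choosing float32 as the target dtype is an assumption; other dtypes pass through unchanged.

```python
import torch


def upcast_if_bfloat16(t: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: convert bfloat16 tensors to float32, leave others as-is."""
    return t.to(torch.float32) if t.dtype == torch.bfloat16 else t
```

Applying it to both `embedding_index` and `query` keeps the two inputs in a matching, supported dtype, at the cost of extra memory for the float32 copies.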