knn
Functions:
| Name | Description | 
|---|---|
| batched_knn | Find the k-nearest neighbors of a query tensor in an embedding index, processing at most a (max_batch_size, max_sequence_length) section of the query at a time. |
batched_knn(
    embedding_index: Tensor,
    query: Tensor,
    k: int,
    p: int,
    max_batch_size: int | None = None,
    max_sequence_length: int | None = None,
    max_num_embeddings: int | None = None,
) -> torch.Tensor
Find the k-nearest neighbors of a query tensor in an embedding index, processing at most a (max_batch_size, max_sequence_length)
section of the query at a time.
Smaller values of max_batch_size, max_sequence_length, and max_num_embeddings reduce the memory needed to hold the intermediate
distance computations, at the cost of longer runtimes.
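The memory/runtime trade-off comes from computing distances for only part of the query at a time. The sketch below is minimal and framework-free. It assumes plain Python lists instead of `torch.Tensor`. The names `chunked_knn`, `p_distance`, and `max_chunk` are hypothetical and are not part of this module. `max_chunk` plays roughly the role of `max_batch_size`/`max_sequence_length`: it bounds how many query vectors have their distance rows materialized at once. This is an illustration of the idea, not the library's implementation.

```python
import heapq
from typing import List, Sequence

Vector = Sequence[float]


def p_distance(a: Vector, b: Vector, p: int) -> float:
    """Minkowski (p-norm) distance between two equal-length vectors."""
    return sum(abs(x - y) ** p for x, y in zip(a, b)) ** (1.0 / p)


def chunked_knn(
    index: Sequence[Vector],
    queries: Sequence[Vector],
    k: int,
    p: int,
    max_chunk: int,
) -> List[List[int]]:
    """Indices of the k nearest index vectors for each query (hypothetical helper).

    Only a (max_chunk, len(index)) block of distances exists at any time.
    """
    results: List[List[int]] = []
    for start in range(0, len(queries), max_chunk):
        chunk = queries[start:start + max_chunk]
        # Distance block for this chunk of queries only.
        block = [[p_distance(q, e, p) for e in index] for q in chunk]
        for row in block:
            # Indices of the k smallest distances, nearest first.
            results.append(heapq.nsmallest(k, range(len(row)), key=row.__getitem__))
    return results


index = [[0.0, 0.0], [1.0, 0.0], [5.0, 5.0], [0.0, 2.0]]
queries = [[0.1, 0.0], [4.0, 4.0]]
print(chunked_knn(index, queries, k=2, p=2, max_chunk=1))  # [[0, 1], [2, 3]]
```

Changing `max_chunk` changes peak memory and the number of loop iterations, but not the result.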
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
| `embedding_index` | Tensor | A tensor of shape (n_embeddings, embedding_dim). | required |
| `query` | Tensor | A tensor of shape (batch_size, sequence_length, embedding_dim) or (sequence_length, embedding_dim). | required |
| `k` | int | The number of nearest neighbors to find. | required |
| `p` | int | The p-norm to use for the distance calculation. | required |
| `max_batch_size` | int \| None | The maximum number of batch elements over which to calculate distances. | None |
| `max_sequence_length` | int \| None | The maximum number of sequence positions over which to calculate distances. | None |
| `max_num_embeddings` | int \| None | The maximum number of embeddings over which to calculate distances. The results from each split are recursively merged together. | None |
Returns:
| Type | Description | 
|---|---|
| torch.Tensor | A tensor of shape (batch_size, sequence_length, k) or (sequence_length, k) containing the indices of the k-nearest neighbors of each query. |
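The return shape follows a simple rule. The trailing `embedding_dim` axis of `query` is replaced by `k`, and the batch axis, if present, is kept. The hypothetical helper `knn_output_shape` below spells this rule out as a quick sanity check. It does not call the library.

```python
from typing import Tuple


def knn_output_shape(query_shape: Tuple[int, ...], k: int) -> Tuple[int, ...]:
    """Expected shape of the neighbor-index tensor (hypothetical helper).

    Replaces the trailing embedding_dim axis with k.
    """
    if len(query_shape) not in (2, 3):
        raise ValueError("query must be (seq, dim) or (batch, seq, dim)")
    return tuple(query_shape[:-1]) + (k,)


print(knn_output_shape((8, 128, 768), k=5))  # (8, 128, 5)
print(knn_output_shape((128, 768), k=5))     # (128, 5)
```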
Raises:
| Type | Description | 
|---|---|
| ValueError | If the input tensors are of dtype bfloat16. |