cosine

Module for cosine similarity and distance loss functions.

Functions:

batched_normalized_cosine_dist: Compute the pairwise normalized cosine distance between query and embedding_index.
normalized_cosine_distance: Calculate the cosine distance (negative cosine similarity) between two tensors, scaled and shifted into the range [0, 1].
normalized_cosine_similarity: Calculate the cosine similarity between two tensors, scaled and shifted into the range [0, 1].
vision_cosine_distillation_loss: Compute the mean cosine-distance loss between teacher (clean) and student (noisy) features.

absolute_cosine_similarity

absolute_cosine_similarity(
    x0: Tensor, x1: Tensor, noise_mask: Tensor | None = None
) -> torch.Tensor

Calculate the absolute cosine similarity between two tensors, masked by a noise mask.

When used as a loss, it encourages the two tensors to be orthogonal, because the absolute cosine similarity reaches its minimum of 0 only for orthogonal inputs.

Parameters:

x0 (Tensor, required): The first tensor.
x1 (Tensor, required): The second tensor.
noise_mask (Tensor | None, default None): A boolean mask indicating which elements to include in the calculation.

Returns:

torch.Tensor: The mean absolute cosine similarity between the two tensors, masked by the noise mask.
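A minimal plain-Python sketch of the quantity described above, assuming cosine similarity is taken per row (along the last axis) and that noise_mask selects whole rows; the real function works on torch tensors, and its exact masking and reduction rules are not spelled out in this page. `absolute_cosine_similarity_sketch` and `cosine` are hypothetical names used only for illustration.

```python
import math
from typing import Optional, Sequence

Vector = Sequence[float]


def cosine(a: Vector, b: Vector, eps: float = 1e-8) -> float:
    """Cosine similarity of two equal-length vectors, guarded against zero norms."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / max(norm_a * norm_b, eps)


def absolute_cosine_similarity_sketch(
    x0: Sequence[Vector],
    x1: Sequence[Vector],
    noise_mask: Optional[Sequence[bool]] = None,
) -> float:
    """Mean |cos(x0[i], x1[i])| over the rows kept by noise_mask (assumed: all rows if None)."""
    if noise_mask is None:
        noise_mask = [True] * len(x0)
    values = [abs(cosine(a, b)) for a, b, keep in zip(x0, x1, noise_mask) if keep]
    if not values:
        # This sketch refuses to average over an empty selection.
        raise ValueError("noise_mask selects no rows")
    return sum(values) / len(values)


x0 = [[1.0, 0.0], [3.0, 4.0]]
x1 = [[0.0, 2.0], [-6.0, -8.0]]
print(absolute_cosine_similarity_sketch(x0, x1))                 # 0.5: row 0 orthogonal, row 1 anti-parallel
print(absolute_cosine_similarity_sketch(x0, x1, [True, False]))  # 0.0: only the orthogonal row is kept
```

Taking the absolute value means anti-parallel vectors are penalized as much as parallel ones, so only orthogonal pairs drive the loss to zero.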

batched_normalized_cosine_dist

batched_normalized_cosine_dist(
    query: Tensor, embedding_index: Tensor, p: int = 2
) -> torch.Tensor

Compute the pairwise normalized cosine distance between each vector in query and each row of embedding_index.

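Note: We take a square root in the implementation so that the result is a valid distance metric. For unit vectors a and b, ||a - b||^2 = 2 - 2 cos(a, b), so sqrt((1 - cos(a, b)) / 2) = ||a - b|| / 2, a scaled Euclidean distance that satisfies the triangle inequality; the un-rooted (1 - cos) / 2 does not. This matches the example, where orthogonal vectors are sqrt(0.5) ≈ 0.7071 apart.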
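The metric argument can be checked numerically. The plain-Python sketch below assumes the function unit-normalizes both inputs with the p-norm and returns sqrt((1 - cos) / 2), which is consistent with the example values in this section; it accepts only 2-D inputs rather than arbitrary leading dimensions, and `batched_normalized_cosine_dist_sketch` and `normalize` are hypothetical names, not the library's API.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def normalize(v: Vector, p: int = 2) -> List[float]:
    """Scale v to unit p-norm (assumes v is not all zeros)."""
    norm = sum(abs(x) ** p for x in v) ** (1.0 / p)
    return [x / norm for x in v]


def batched_normalized_cosine_dist_sketch(
    query: Sequence[Vector], embedding_index: Sequence[Vector], p: int = 2
) -> List[List[float]]:
    """Pairwise sqrt((1 - cos) / 2) between every query row and every index row."""
    q = [normalize(v, p) for v in query]
    e = [normalize(v, p) for v in embedding_index]
    out = []
    for qv in q:
        row = []
        for ev in e:
            cos = sum(x * y for x, y in zip(qv, ev))
            # Clamp at 0 so rounding error never feeds a tiny negative number to sqrt.
            row.append(math.sqrt(max((1.0 - cos) / 2.0, 0.0)))
        out.append(row)
    return out


# Same inputs as the docstring example: roughly [[0.0, 0.7071], [0.7071, 0.0]].
print(batched_normalized_cosine_dist_sketch([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]]))

# Triangle inequality on vectors at 0, 45 and 90 degrees.
s = 1.0 / math.sqrt(2.0)
a, b, c = [1.0, 0.0], [s, s], [0.0, 1.0]
d = batched_normalized_cosine_dist_sketch([a, b], [b, c])
# d[0][0] = dist(a, b), d[0][1] = dist(a, c), d[1][1] = dist(b, c)
print(d[0][1] <= d[0][0] + d[1][1])  # True: 0.7071 <= 0.3827 + 0.3827
# Without the square root the same triple gives 0.5 > 0.1464 + 0.1464, which breaks the inequality.
```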

Parameters:

query (Tensor, required): An n-dimensional tensor of shape (*, embedding_dim).
embedding_index (Tensor, required): A tensor of shape (n_embeddings, embedding_dim).
p (int, default 2): The p-norm to use for normalization. Defaults to 2 for standard Euclidean normalization.

Returns:

torch.Tensor: A tensor of shape (*, n_embeddings) containing the normalized cosine distance between each query vector and each row of embedding_index.

Examples:

>>> query = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
>>> embedding_index = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
>>> batched_normalized_cosine_dist(query, embedding_index)
tensor([[0.0000, 0.7071],
        [0.7071, 0.0000]])

Added in version v2.23.0. Added batched normalized cosine distance function.

normalized_cosine_distance

normalized_cosine_distance(
    x1: Tensor, x2: Tensor, dim: int = 1, eps: float = 1e-08
) -> torch.Tensor

Calculate the cosine distance (negative cosine similarity) between two tensors, scaled and shifted into the range [0, 1].

Parameters:

x1 (Tensor, required): The first tensor.
x2 (Tensor, required): The second tensor.
dim (int, default 1): The dimension along which cosine distance is computed.
eps (float, default 1e-08): A small value to prevent division by zero.

Returns:

torch.Tensor: The cosine distance of the tensors, scaled and shifted to lie between 0 and 1.
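"Negative cosine similarity, scaled and shifted into [0, 1]" pins down a linear map: -cos in [-1, 1] becomes (1 - cos) / 2. The plain-Python sketch below assumes exactly that formula for a single pair of vectors (the real function works on torch tensors along `dim`); `normalized_cosine_distance_sketch` and `cosine_similarity` are hypothetical names.

```python
import math
from typing import Sequence


def cosine_similarity(x1: Sequence[float], x2: Sequence[float], eps: float = 1e-8) -> float:
    """Plain cosine similarity; eps keeps the denominator away from zero."""
    dot = sum(a * b for a, b in zip(x1, x2))
    denom = math.sqrt(sum(a * a for a in x1)) * math.sqrt(sum(b * b for b in x2))
    return dot / max(denom, eps)


def normalized_cosine_distance_sketch(
    x1: Sequence[float], x2: Sequence[float], eps: float = 1e-8
) -> float:
    """Map -cos from [-1, 1] linearly onto [0, 1]: (1 - cos) / 2."""
    return (1.0 - cosine_similarity(x1, x2, eps)) / 2.0


print(normalized_cosine_distance_sketch([1.0, 0.0], [1.0, 0.0]))   # 0.0 (same direction)
print(normalized_cosine_distance_sketch([1.0, 0.0], [0.0, 3.0]))   # 0.5 (orthogonal)
print(normalized_cosine_distance_sketch([1.0, 0.0], [-2.0, 0.0]))  # 1.0 (opposite direction)
```

Vector magnitude does not matter, only direction: [0.0, 3.0] and [-2.0, 0.0] are handled the same as their unit-length counterparts.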

normalized_cosine_similarity

normalized_cosine_similarity(
    x1: Tensor, x2: Tensor, dim: int = 1, eps: float = 1e-08
) -> torch.Tensor

Calculate the cosine similarity between two tensors, scaled and shifted into the range [0, 1].

Parameters:

x1 (Tensor, required): The first tensor.
x2 (Tensor, required): The second tensor.
dim (int, default 1): The dimension along which cosine similarity is computed.
eps (float, default 1e-08): A small value to prevent division by zero.

Returns:

torch.Tensor: The cosine similarity of the tensors, scaled and shifted to lie between 0 and 1.
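Scaling and shifting cos from [-1, 1] into [0, 1] gives (1 + cos) / 2. The plain-Python sketch below assumes that formula for a single pair of vectors rather than torch tensors along `dim`; `normalized_cosine_similarity_sketch` and `cosine_similarity` are hypothetical names.

```python
import math
from typing import Sequence


def cosine_similarity(x1: Sequence[float], x2: Sequence[float], eps: float = 1e-8) -> float:
    """Plain cosine similarity; eps keeps the denominator away from zero."""
    dot = sum(a * b for a, b in zip(x1, x2))
    denom = math.sqrt(sum(a * a for a in x1)) * math.sqrt(sum(b * b for b in x2))
    return dot / max(denom, eps)


def normalized_cosine_similarity_sketch(
    x1: Sequence[float], x2: Sequence[float], eps: float = 1e-8
) -> float:
    """Map cos from [-1, 1] linearly onto [0, 1]: (1 + cos) / 2."""
    return (1.0 + cosine_similarity(x1, x2, eps)) / 2.0


print(normalized_cosine_similarity_sketch([1.0, 0.0], [1.0, 0.0]))   # 1.0 (same direction)
print(normalized_cosine_similarity_sketch([1.0, 0.0], [0.0, 3.0]))   # 0.5 (orthogonal)
print(normalized_cosine_similarity_sketch([1.0, 0.0], [-2.0, 0.0]))  # 0.0 (opposite direction)
```

Under these assumed formulas the normalized similarity and the normalized distance are complements, since (1 + cos) / 2 + (1 - cos) / 2 = 1.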

vision_cosine_distillation_loss

vision_cosine_distillation_loss(
    teacher: Tensor, student: Tensor
) -> torch.Tensor

Compute the mean cosine-distance loss between teacher (clean) and student (noisy) features.

Both tensors must have the same shape, (..., feature_dim). Cosine similarity is taken along the last (feature) axis, and the resulting distances are mean-reduced over all preceding axes. The teacher tensor is detached so gradients flow only through the student path; both tensors are promoted to float32 to keep the cosine numerically stable when the upstream model runs in bfloat16.
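The reduction can be sketched in plain Python, assuming the loss is the unweighted mean of 1 - cos over every feature vector in the (..., feature_dim) layout. Detaching and the float32 promotion have no plain-Python analogue (in PyTorch they correspond to `Tensor.detach()` and a cast such as `Tensor.float()`), so the hypothetical `distillation_loss_sketch` below illustrates only the arithmetic, not the gradient or dtype handling.

```python
import math
from typing import Any, Iterator, List, Sequence


def iter_feature_vectors(t: Any) -> Iterator[List[float]]:
    """Yield the innermost lists (the feature vectors) of an arbitrarily nested list."""
    if t and isinstance(t[0], (int, float)):
        yield [float(x) for x in t]
    else:
        for sub in t:
            yield from iter_feature_vectors(sub)


def cosine(a: Sequence[float], b: Sequence[float], eps: float = 1e-8) -> float:
    """Cosine similarity of two equal-length vectors, guarded against zero norms."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / max(na * nb, eps)


def distillation_loss_sketch(teacher: Any, student: Any) -> float:
    t_vecs = list(iter_feature_vectors(teacher))
    s_vecs = list(iter_feature_vectors(student))
    if len(t_vecs) != len(s_vecs):
        raise ValueError("teacher and student must have the same shape")
    # 1 - cos per feature vector, then a plain mean over every leading position.
    losses = [1.0 - cosine(t, s) for t, s in zip(t_vecs, s_vecs)]
    return sum(losses) / len(losses)


# Shape (2, 1, 2): one aligned token (loss 0) and one anti-parallel token (loss 2).
teacher = [[[1.0, 0.0]], [[0.0, 1.0]]]
student = [[[1.0, 0.0]], [[0.0, -1.0]]]
print(distillation_loss_sketch(teacher, student))  # 1.0
```

The upper bound of 2 is reached only when every student vector points exactly opposite its teacher vector.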

Typical use case: when training a vision-tower-aware Stained Glass Transform, the post-merger output of the vision encoder is the per-image-token feature stream spliced into the LLM's input embeddings. Pulling the noisy branch's stream toward the clean branch's adds a self-distillation signal that complements the language-head cross-entropy/KL loss, while privacy bounds (MS-SSIM, std-log) keep the cloak from collapsing to the identity.

Privacy note: only student carries gradients into the Stained Glass Transform parameters; teacher is detached and serves only as a target.

Parameters:

teacher (Tensor, required): Clean-branch features of shape (..., feature_dim). Detached internally.
student (Tensor, required): Noisy-branch features of the same shape as teacher.

Returns:

torch.Tensor: A scalar, the mean cosine distance 1 - cos(teacher, student), in [0, 2].

Example:

>>> teacher = torch.randn(2, 4, 8)
>>> # Identical features yield a distance at the float32 precision floor.
>>> vision_cosine_distillation_loss(teacher, teacher.clone()).item() < 1e-6
True

Added in version v3.41.0. Self-distillation loss between a detached teacher and a student feature stream. Computes in float32 for bf16 stability and averages over all dimensions preceding the feature axis.