divergences
Module for f-divergence-based loss functions.
Functions:
| Name | Description |
|---|---|
| `jefferys_divergence` | Compute the Jeffreys divergence between two discrete probability distributions using logits. |
| `jensen_shannon_divergence` | Compute the Jensen-Shannon divergence between two discrete probability distributions using logits. |
| `masked_cross_entropy` | Compute cross-entropy loss between input logits and target logits. |
| `masked_jefferys_divergence` | Compute the Jeffreys divergence between input logits and target logits. |
| `masked_kl_divergence` | Compute the KL divergence between input logits and target logits. |
| `masked_unbiased_dcor` | Compute an unbiased approximation to the distance correlation between two tensors of samples of random variables. |
| `squared_hellinger_distance` | Compute the squared Hellinger distance between two discrete probability distributions using logits. |
| `temperature_scaled_masked_kl_divergence` | Compute the temperature-scaled KL divergence between teacher and student logits over masked positions. |
| `total_variation` | Compute the total variation between two discrete probability distributions using logits. |
jefferys_divergence
`jefferys_divergence(noisy_logits: Tensor, clean_logits: Tensor, attention_mask: Tensor, support_mask: Tensor | None = None, reduction: Literal['mean', 'none'] = 'mean') -> Tensor`
Compute the Jeffreys divergence between two discrete probability distributions using logits.
Note
See https://en.wikipedia.org/wiki/F-divergence#Common_examples_of_f-divergences for the implementation formula.
Note
If `support_mask` has non-zero elements, the divergence is computed over a masked version of the original distributions, so the result is no longer a divergence between true probability distributions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `noisy_logits` | `Tensor` | The logits produced from SGT data. | required |
| `clean_logits` | `Tensor` | The logits produced from data without the SGT. | required |
| `attention_mask` | `Tensor` | The attention mask for the batch. | required |
| `support_mask` | `Tensor \| None` | The mask placed on both noisy and clean logits. Useful for comparing top-k logits. | `None` |
| `reduction` | `Literal['mean', 'none']` | Specifies how to reduce across the batch dimension after computing the per-sequence masked mean. | `'mean'` |
Returns:
| Type | Description |
|---|---|
| `Tensor` | The Jeffreys divergence between the noisy and clean logits, averaged over the sequence length of each batch element and then, when `reduction='mean'`, over the batch. |
Added in version v3.37.0.
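The following is a minimal sketch of the per-token formula `sum_v (p_v - q_v) * (log p_v - log q_v)`, followed by the per-sequence masked mean and the batch reduction described in the `reduction` parameter. It is not the library's implementation. The helper name `jeffreys_sketch` is hypothetical. The sketch assumes logits of shape (batch, seq, vocab) and an `attention_mask` of shape (batch, seq), and it ignores `support_mask`.

```python
import torch
import torch.nn.functional as F


def jeffreys_sketch(noisy_logits, clean_logits, attention_mask, reduction="mean"):
    """Hypothetical sketch: sum_v (p - q) * (log p - log q), masked mean per sequence."""
    log_p = F.log_softmax(noisy_logits, dim=-1)
    log_q = F.log_softmax(clean_logits, dim=-1)
    # Per-token Jeffreys divergence, shape (batch, seq).
    per_token = ((log_p.exp() - log_q.exp()) * (log_p - log_q)).sum(dim=-1)
    mask = attention_mask.to(per_token.dtype)
    # Masked mean over the sequence for each batch element, shape (batch,).
    per_sequence = (per_token * mask).sum(dim=-1) / mask.sum(dim=-1).clamp(min=1.0)
    return per_sequence.mean() if reduction == "mean" else per_sequence


noisy = torch.randn(2, 5, 10)
clean = torch.randn(2, 5, 10)
mask = torch.ones(2, 5, dtype=torch.bool)
print(jeffreys_sketch(noisy, clean, mask).shape)  # torch.Size([])
print(jeffreys_sketch(noisy, clean, mask, reduction="none").shape)  # torch.Size([2])
```

Because the Jeffreys divergence is symmetric, swapping `noisy_logits` and `clean_logits` leaves the result unchanged.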
jensen_shannon_divergence
`jensen_shannon_divergence(noisy_logits: Tensor, clean_logits: Tensor, attention_mask: Tensor, support_mask: Tensor | None = None, reduction: Literal['mean', 'none'] = 'mean') -> Tensor`
Compute the Jensen-Shannon divergence between two discrete probability distributions using logits.
Note
See https://en.wikipedia.org/wiki/F-divergence#Common_examples_of_f-divergences for the implementation formula.
Note
If `support_mask` has non-zero elements, the divergence is computed over a masked version of the original distributions, so the result is no longer a divergence between true probability distributions.
Note
This implementation has numerical stability issues due to applying softmax to the mixed probabilities.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `noisy_logits` | `Tensor` | The logits produced from SGT data. | required |
| `clean_logits` | `Tensor` | The logits produced from data without the SGT. | required |
| `attention_mask` | `Tensor` | The attention mask for the batch. | required |
| `support_mask` | `Tensor \| None` | The mask placed on both noisy and clean logits. Useful for comparing top-k logits. | `None` |
| `reduction` | `Literal['mean', 'none']` | Specifies how to reduce across the batch dimension after computing the per-sequence masked mean. | `'mean'` |
Returns:
| Type | Description |
|---|---|
| `Tensor` | The Jensen-Shannon divergence between the noisy and clean logits, averaged over the sequence length of each batch element and then, when `reduction='mean'`, over the batch. |
Added in version v3.37.0.
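The stability note above concerns forming the mixture M = (P + Q) / 2. One common way to avoid leaving log space is sketched below, using `log M = logsumexp(log P, log Q) - log 2`. This is not necessarily how the library computes the divergence. `jensen_shannon_sketch` is a hypothetical name, and logits are assumed to have shape (batch, seq, vocab). The natural logarithm is used, so per-token values lie in [0, log 2], and `support_mask` is ignored.

```python
import math

import torch
import torch.nn.functional as F


def jensen_shannon_sketch(noisy_logits, clean_logits, attention_mask, reduction="mean"):
    """Hypothetical sketch: 0.5 * KL(P || M) + 0.5 * KL(Q || M), with M = (P + Q) / 2."""
    log_p = F.log_softmax(noisy_logits, dim=-1)
    log_q = F.log_softmax(clean_logits, dim=-1)
    # log((p + q) / 2) computed in log space, without exponentiating and re-normalizing.
    log_m = torch.logsumexp(torch.stack([log_p, log_q]), dim=0) - math.log(2.0)
    kl_pm = (log_p.exp() * (log_p - log_m)).sum(dim=-1)
    kl_qm = (log_q.exp() * (log_q - log_m)).sum(dim=-1)
    per_token = 0.5 * (kl_pm + kl_qm)  # shape (batch, seq)
    mask = attention_mask.to(per_token.dtype)
    # Masked mean over the sequence, then the batch reduction.
    per_sequence = (per_token * mask).sum(dim=-1) / mask.sum(dim=-1).clamp(min=1.0)
    return per_sequence.mean() if reduction == "mean" else per_sequence


noisy = torch.randn(2, 5, 10)
clean = torch.randn(2, 5, 10)
mask = torch.ones(2, 5, dtype=torch.bool)
print(jensen_shannon_sketch(noisy, clean, mask).shape)  # torch.Size([])
```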
masked_cross_entropy
`masked_cross_entropy(input_logits: Tensor, target_logits: Tensor, attention_mask: Tensor | None = None, max_loss: float | None = None, scaling_factor: float = 1.0) -> Tensor`
Compute cross-entropy loss between input logits and target logits.
Applies softmax to the target logits.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `input_logits` | `Tensor` | Logits from the model, shape (batch_size, sequence_length, embedding_dim). | required |
| `target_logits` | `Tensor` | Logits from the target, shape (batch_size, sequence_length, embedding_dim). | required |
| `attention_mask` | `Tensor \| None` | Optional attention mask, shape (batch_size, sequence_length). | `None` |
| `max_loss` | `float \| None` | The maximum value at which to cap the cross-entropy loss. | `None` |
| `scaling_factor` | `float` | A scaling factor for the sigmoid cross-entropy loss; lower values slow down optimization. Defaults to 1.0. | `1.0` |
Returns:
| Type | Description |
|---|---|
| `Tensor` | Cross-entropy loss, averaged over the batch and sequence length. |
Example
```python
>>> import torch
>>> input_logits = torch.randn(2, 5, 10)
>>> target_logits = torch.randn(2, 5, 10)
>>> attention_mask = torch.ones(2, 5, dtype=torch.bool)
>>> cross_entropy_loss = masked_cross_entropy(
...     input_logits, target_logits, attention_mask
... )
>>> print(cross_entropy_loss.shape)
torch.Size([])
```
Added in version v1.8.0.
Added in version v2.18.0.
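The core computation behind this loss can be sketched as below. It takes the cross-entropy of the input against softmax-normalized target logits and averages over unmasked tokens across the batch and sequence. `soft_cross_entropy_sketch` is a hypothetical name, not the library's implementation. `max_loss` and `scaling_factor` are deliberately left out, because this page does not specify exactly how the library applies them.

```python
import torch
import torch.nn.functional as F


def soft_cross_entropy_sketch(input_logits, target_logits, attention_mask=None):
    """Hypothetical sketch: -sum_v softmax(target) * log_softmax(input), averaged over tokens."""
    target_probs = F.softmax(target_logits, dim=-1)
    # Per-token cross-entropy against soft targets, shape (batch, seq).
    per_token = -(target_probs * F.log_softmax(input_logits, dim=-1)).sum(dim=-1)
    if attention_mask is None:
        return per_token.mean()
    mask = attention_mask.to(per_token.dtype)
    # Mean over valid tokens across both the batch and sequence dimensions.
    return (per_token * mask).sum() / mask.sum().clamp(min=1.0)


input_logits = torch.randn(2, 5, 10)
target_logits = torch.randn(2, 5, 10)
attention_mask = torch.ones(2, 5, dtype=torch.bool)
attention_mask[1, 3:] = False  # ignore the last two tokens of the second sequence
print(soft_cross_entropy_sketch(input_logits, target_logits, attention_mask).shape)  # torch.Size([])
```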
masked_jefferys_divergence
`masked_jefferys_divergence(input_logits: Tensor, target_logits: Tensor, attention_mask: Tensor | None = None, log_target: bool = False) -> Tensor`
Compute the Jeffreys divergence between input logits and target logits.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `input_logits` | `Tensor` | Logits from the model, shape (batch_size, sequence_length, embedding_dim). | required |
| `target_logits` | `Tensor` | Logits from the target, shape (batch_size, sequence_length, embedding_dim). | required |
| `attention_mask` | `Tensor \| None` | Optional attention mask, shape (batch_size, sequence_length). | `None` |
| `log_target` | `bool` | Whether the target logits are already in log space. | `False` |
Returns:
| Type | Description |
|---|---|
| `Tensor` | Jeffreys divergence, averaged over the batch and sequence length. |
Example
```python
>>> import torch
>>> input_logits = torch.randn(2, 5, 10)
>>> target_logits = torch.randn(2, 5, 10)
>>> attention_mask = torch.ones(2, 5, dtype=torch.bool)
>>> jefferys_div = masked_jefferys_divergence(
...     input_logits, target_logits, attention_mask, log_target=False
... )
>>> print(jefferys_div.shape)
torch.Size([])
```
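The Jeffreys divergence is the symmetrized KL divergence, J(P, Q) = KL(P || Q) + KL(Q || P). The sketch below computes both directions with `torch.nn.functional.kl_div`. With `log_target=True`, that function evaluates `exp(target) * (target - input)` elementwise. The sketch then averages over unmasked tokens. It assumes both inputs are raw logits (the `log_target=False` case). `masked_jeffreys_sketch` is a hypothetical name, not the library's implementation.

```python
import torch
import torch.nn.functional as F


def masked_jeffreys_sketch(input_logits, target_logits, attention_mask=None):
    """Hypothetical sketch: KL(P || Q) + KL(Q || P), averaged over unmasked tokens."""
    log_p = F.log_softmax(input_logits, dim=-1)
    log_q = F.log_softmax(target_logits, dim=-1)
    # F.kl_div(a, b, log_target=True) is exp(b) * (b - a) elementwise;
    # summing over the vocab gives KL(B || A).
    kl_q_p = F.kl_div(log_p, log_q, reduction="none", log_target=True).sum(dim=-1)
    kl_p_q = F.kl_div(log_q, log_p, reduction="none", log_target=True).sum(dim=-1)
    per_token = kl_p_q + kl_q_p  # shape (batch, seq)
    if attention_mask is None:
        return per_token.mean()
    mask = attention_mask.to(per_token.dtype)
    return (per_token * mask).sum() / mask.sum().clamp(min=1.0)


x = torch.randn(2, 5, 10)
y = torch.randn(2, 5, 10)
print(masked_jeffreys_sketch(x, y, torch.ones(2, 5, dtype=torch.bool)).shape)  # torch.Size([])
```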
masked_kl_divergence
`masked_kl_divergence(input_logits: Tensor, target_logits: Tensor, attention_mask: Tensor | None = None, log_target: bool = True) -> Tensor`
Compute the KL divergence between input logits and target logits.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `input_logits` | `Tensor` | Logits from the model, shape (batch_size, sequence_length, embedding_dim). | required |
| `target_logits` | `Tensor` | Logits from the target, shape (batch_size, sequence_length, embedding_dim). | required |
| `attention_mask` | `Tensor \| None` | Optional attention mask, shape (batch_size, sequence_length). | `None` |
| `log_target` | `bool` | Whether the target logits are already in log space. | `True` |
Returns:
| Type | Description |
|---|---|
| `Tensor` | KL divergence, averaged over the batch and sequence length. |
Example
```python
>>> import torch
>>> input_logits = torch.randn(2, 5, 10)
>>> target_logits = torch.randn(2, 5, 10)
>>> attention_mask = torch.ones(2, 5, dtype=torch.bool)
>>> kl_div = masked_kl_divergence(input_logits, target_logits, attention_mask)
>>> print(kl_div.shape)
torch.Size([])
```
Added in version v1.8.0.
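The `log_target` flag shares its name with the argument of `torch.nn.functional.kl_div`. There, it declares whether the target is given as log-probabilities or as probabilities. The sketch below assumes, without confirmation from this page, that the function forwards the flag in that sense. With `log_target=True`, the target logits are normalized with `log_softmax`; otherwise they are normalized with `softmax`. Both paths give the same KL(target || input) under PyTorch's direction convention. `masked_kl_sketch` is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def masked_kl_sketch(input_logits, target_logits, attention_mask=None, log_target=True):
    """Hypothetical sketch of a masked KL(target || input) using torch's kl_div conventions."""
    input_log_probs = F.log_softmax(input_logits, dim=-1)
    # kl_div expects log-probabilities as target when log_target=True, probabilities otherwise.
    if log_target:
        target = F.log_softmax(target_logits, dim=-1)
    else:
        target = F.softmax(target_logits, dim=-1)
    per_token = F.kl_div(input_log_probs, target, reduction="none", log_target=log_target).sum(dim=-1)
    if attention_mask is None:
        return per_token.mean()
    mask = attention_mask.to(per_token.dtype)
    # Mean over valid tokens across the batch and sequence.
    return (per_token * mask).sum() / mask.sum().clamp(min=1.0)


x = torch.randn(2, 5, 10)
y = torch.randn(2, 5, 10)
mask = torch.ones(2, 5, dtype=torch.bool)
print(masked_kl_sketch(x, y, mask).shape)  # torch.Size([])
```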
masked_unbiased_dcor
`masked_unbiased_dcor(samples_1: Tensor, samples_2: Tensor, attention_mask: Tensor, safety_factor: float = 100.0) -> Tensor`
Compute an unbiased approximation to the distance correlation between two tensors of samples of random variables.
Note
The approximation assumes that the last tensor dimension holds the random vector and that all preceding dimensions index separate observations. For text, this corresponds to a token-level distance correlation.
Note
The tensors must have the same number of rows, representing the number of samples.
See https://arxiv.org/pdf/1701.06054.pdf for more details.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `samples_1` | `Tensor` | The first tensor of samples. | required |
| `samples_2` | `Tensor` | The second tensor of samples. | required |
| `attention_mask` | `Tensor` | An attention mask indicating valid samples. | required |
| `safety_factor` | `float` | A factor to scale the added epsilon for numerical stability. | `100.0` |
Returns:
| Type | Description |
|---|---|
| `Tensor` | An approximation to the distance correlation, with values between 0 and 1. |
Added in version v3.10.0.
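Below is a straightforward O(n^2) sketch of a bias-corrected squared distance correlation. It uses U-centered pairwise distance matrices, following Székely and Rizzo's unbiased estimator. It is neither the library's implementation nor necessarily the referenced paper's algorithm. The sketch selects valid tokens with `attention_mask` and treats the last dimension as the random vector. It uses a fixed `eps` in place of `safety_factor`, whose exact role is not specified here. It also returns the squared statistic, which can be slightly negative for near-independent samples instead of being confined to [0, 1]. `unbiased_dcor_sketch` and `_u_center` are hypothetical names.

```python
import torch


def _u_center(dist: torch.Tensor) -> torch.Tensor:
    """U-center an (n, n) pairwise distance matrix; requires n > 3."""
    n = dist.shape[0]
    row_term = dist.sum(dim=1, keepdim=True) / (n - 2)
    col_term = dist.sum(dim=0, keepdim=True) / (n - 2)
    grand_term = dist.sum() / ((n - 1) * (n - 2))
    centered = dist - row_term - col_term + grand_term
    centered.fill_diagonal_(0.0)  # U-centered matrices have a zero diagonal
    return centered


def unbiased_dcor_sketch(samples_1, samples_2, attention_mask, eps=1e-12):
    """Hypothetical sketch: bias-corrected squared distance correlation over valid tokens."""
    valid = attention_mask.bool()
    # Flatten the observation dimensions: (batch, seq, dim) -> (n_valid, dim).
    x = samples_1[valid].double()
    y = samples_2[valid].double()
    n = x.shape[0]
    if n < 4:
        raise ValueError("U-centering needs at least 4 valid samples")
    # Exact Euclidean pairwise distances (zero on the diagonal).
    a = _u_center((x[:, None, :] - x[None, :, :]).norm(dim=-1))
    b = _u_center((y[:, None, :] - y[None, :, :]).norm(dim=-1))

    def inner(u, v):
        # Unbiased inner product of two U-centered matrices.
        return (u * v).sum() / (n * (n - 3))

    return inner(a, b) / (torch.sqrt(inner(a, a) * inner(b, b)) + eps)


x = torch.randn(2, 6, 3)
mask = torch.ones(2, 6, dtype=torch.bool)
print(unbiased_dcor_sketch(x, 2.0 * x + 1.0, mask))  # close to 1 for an affine relation
```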
squared_hellinger_distance
`squared_hellinger_distance(noisy_logits: Tensor, clean_logits: Tensor, attention_mask: Tensor, support_mask: Tensor | None = None, reduction: Literal['mean', 'none'] = 'mean') -> Tensor`
Compute the squared Hellinger distance between two discrete probability distributions using logits.
Computes the divergence between embeddings and averages across the sequence.
Note
See https://en.wikipedia.org/wiki/F-divergence#Common_examples_of_f-divergences for the implementation formula.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `noisy_logits` | `Tensor` | The logits produced from SGT data. | required |
| `clean_logits` | `Tensor` | The logits produced from data without the SGT. | required |
| `attention_mask` | `Tensor` | The attention mask for the batch. | required |
| `support_mask` | `Tensor \| None` | The mask placed on both noisy and clean logits. Useful for comparing top-k logits. | `None` |
| `reduction` | `Literal['mean', 'none']` | Specifies how to reduce across the batch dimension after computing the per-sequence masked mean. | `'mean'` |
Returns:
| Type | Description |
|---|---|
| `Tensor` | The squared Hellinger distance between the noisy and clean logits, averaged over the sequence length of each batch element and then, when `reduction='mean'`, over the batch. |
Added in version v3.37.0.
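The following minimal sketch uses the normalization `0.5 * sum_v (sqrt(p_v) - sqrt(q_v))**2`, equivalently `1 - sum_v sqrt(p_v * q_v)`, with per-token values in [0, 1]. Some references drop the factor of 1/2, and this page does not say which normalization the library uses. Per-token values are averaged over the unmasked positions of each sequence and then reduced over the batch. `squared_hellinger_sketch` is a hypothetical name. Logits are assumed to have shape (batch, seq, vocab), and `support_mask` is ignored.

```python
import torch
import torch.nn.functional as F


def squared_hellinger_sketch(noisy_logits, clean_logits, attention_mask, reduction="mean"):
    """Hypothetical sketch: 0.5 * sum_v (sqrt(p) - sqrt(q))**2 per token, masked mean per sequence."""
    p = F.softmax(noisy_logits, dim=-1)
    q = F.softmax(clean_logits, dim=-1)
    # Equivalent to 1 - sum_v sqrt(p * q) (one minus the Bhattacharyya coefficient).
    per_token = 0.5 * ((p.sqrt() - q.sqrt()) ** 2).sum(dim=-1)  # shape (batch, seq)
    mask = attention_mask.to(per_token.dtype)
    per_sequence = (per_token * mask).sum(dim=-1) / mask.sum(dim=-1).clamp(min=1.0)
    return per_sequence.mean() if reduction == "mean" else per_sequence


noisy = torch.randn(2, 5, 10)
clean = torch.randn(2, 5, 10)
mask = torch.ones(2, 5, dtype=torch.bool)
print(squared_hellinger_sketch(noisy, clean, mask).shape)  # torch.Size([])
```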
temperature_scaled_masked_kl_divergence
`temperature_scaled_masked_kl_divergence(teacher_logits: Tensor, student_logits: Tensor, position_mask: Tensor, temperature: float = 1.0) -> Tensor`
Compute the temperature-scaled `KL(softmax(teacher / T) || softmax(student / T))` over masked positions.
The teacher tensor is detached, so gradients flow only through the student. Unlike `masked_kl_divergence`, the position mask is applied before the softmax. Only the active (M, V) rows are softmaxed, where M is the number of active positions, rather than the full (B, T, V) tensor. For large vocabularies (e.g. V > 100_000) this can shrink the softmax working set by orders of magnitude.
For an empty mask (no active positions), the function returns a zero scalar on the teacher's device and dtype rather than NaN, so a composite loss remains safe to backpropagate.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `teacher_logits` | `Tensor` | Logits from the clean / target model of shape (B, T, V). | required |
| `student_logits` | `Tensor` | Logits from the student / noisy model of shape (B, T, V). | required |
| `position_mask` | `Tensor` | Boolean mask of shape (B, T) marking the active positions. | required |
| `temperature` | `float` | Softmax temperature applied to both teacher and student logits. Defaults to `1.0`. | `1.0` |
Returns:
| Type | Description |
|---|---|
| `Tensor` | Scalar KL divergence averaged over active positions. |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If … |
Example
```python
>>> import torch
>>> teacher = torch.randn(2, 4, 16)
>>> student = torch.randn(2, 4, 16)
>>> mask = torch.zeros(2, 4, dtype=torch.bool)
>>> mask[:, -2:] = True
>>> kl = temperature_scaled_masked_kl_divergence(teacher, student, mask)
>>> kl.shape
torch.Size([])
```
Added in version v3.41.0. Logits-based KL distillation that filters by the mask before the softmax, which greatly reduces the working set at large vocabulary sizes, and supports temperature scaling. Complements the existing `masked_kl_divergence`.
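The mask-before-softmax idea can be sketched as follows. Boolean indexing gathers only the M active rows, the teacher rows are detached, and the KL divergence is averaged over those rows, with an early zero return for an empty mask. `temperature_kl_sketch` is a hypothetical name, not the library's code. It does not reproduce the function's `ValueError` checks. It also does not multiply by T^2, a scaling sometimes used in distillation, because this page does not say whether the library does.

```python
import torch
import torch.nn.functional as F


def temperature_kl_sketch(teacher_logits, student_logits, position_mask, temperature=1.0):
    """Hypothetical sketch of KL(softmax(teacher / T) || softmax(student / T)) over masked rows."""
    mask = position_mask.bool()
    if not mask.any():
        # Empty mask: a zero scalar on the teacher's device/dtype instead of a NaN mean.
        return teacher_logits.new_zeros(())
    # Boolean indexing gathers only the active rows: (B, T, V) -> (M, V).
    teacher = teacher_logits[mask].detach() / temperature
    student = student_logits[mask] / temperature
    teacher_log_probs = F.log_softmax(teacher, dim=-1)
    student_log_probs = F.log_softmax(student, dim=-1)
    # kl_div(s, t, log_target=True) is exp(t) * (t - s); summing over V gives KL(teacher || student).
    per_row = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True).sum(dim=-1)
    return per_row.mean()


teacher = torch.randn(2, 4, 16)
student = torch.randn(2, 4, 16)
mask = torch.zeros(2, 4, dtype=torch.bool)
mask[:, -2:] = True
print(temperature_kl_sketch(teacher, student, mask, temperature=2.0).shape)  # torch.Size([])
```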
total_variation
`total_variation(noisy_logits: Tensor, clean_logits: Tensor, attention_mask: Tensor, support_mask: Tensor | None = None, reduction: Literal['mean', 'none'] = 'mean') -> Tensor`
Compute the total variation between two discrete probability distributions using logits.
Computes the average token-wise total variation between embeddings across a sequence.
Note
See https://en.wikipedia.org/wiki/F-divergence#Common_examples_of_f-divergences for the implementation formula.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `noisy_logits` | `Tensor` | The logits produced from SGT data. | required |
| `clean_logits` | `Tensor` | The logits produced from data without the SGT. | required |
| `attention_mask` | `Tensor` | The attention mask for the batch. | required |
| `support_mask` | `Tensor \| None` | The mask placed on both noisy and clean logits. Useful for comparing top-k logits. | `None` |
| `reduction` | `Literal['mean', 'none']` | Specifies how to reduce across the batch dimension after computing the per-sequence masked mean. | `'mean'` |
Returns:
| Type | Description |
|---|---|
| `Tensor` | The total variation between the noisy and clean logits, averaged over the sequence length of each batch element and then, when `reduction='mean'`, over the batch. |
Added in version v3.37.0.
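The following is a minimal sketch of the per-token total variation distance, `0.5 * sum_v |p_v - q_v|`. It is followed by a masked mean over each sequence and the batch reduction described in the `reduction` parameter. `total_variation_sketch` is a hypothetical name, not the library's implementation. Logits are assumed to have shape (batch, seq, vocab), and `support_mask` is ignored.

```python
import torch
import torch.nn.functional as F


def total_variation_sketch(noisy_logits, clean_logits, attention_mask, reduction="mean"):
    """Hypothetical sketch: 0.5 * sum_v |p - q| per token, masked mean per sequence."""
    p = F.softmax(noisy_logits, dim=-1)
    q = F.softmax(clean_logits, dim=-1)
    per_token = 0.5 * (p - q).abs().sum(dim=-1)  # shape (batch, seq), values in [0, 1]
    mask = attention_mask.to(per_token.dtype)
    per_sequence = (per_token * mask).sum(dim=-1) / mask.sum(dim=-1).clamp(min=1.0)
    return per_sequence.mean() if reduction == "mean" else per_sequence


noisy = torch.randn(2, 5, 10)
clean = torch.randn(2, 5, 10)
mask = torch.ones(2, 5, dtype=torch.bool)
print(total_variation_sketch(noisy, clean, mask, reduction="none").shape)  # torch.Size([2])
```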