divergences
Module for f-divergence based loss functions.
Functions:
| Name | Description |
|---|---|
| `jefferys_divergence` | Compute the Jeffreys divergence between two discrete probability distributions using logits. |
| `jensen_shannon_divergence` | Compute the Jensen-Shannon divergence between two discrete probability distributions using logits. |
| `masked_cross_entropy` | Compute cross-entropy loss between input logits and target logits. |
| `masked_jefferys_divergence` | Compute the Jeffreys divergence between input logits and target logits. |
| `masked_kl_divergence` | Compute the KL divergence between input logits and target logits. |
| `masked_unbiased_dcor` | Compute an unbiased approximation to the distance correlation between two tensors representing samples of random variables. |
| `squared_hellinger_distance` | Compute the squared Hellinger distance between two discrete probability distributions using logits. |
| `temperature_scaled_masked_jefferys_divergence` | Compute the symmetric Jeffreys divergence KL(t \|\| s) + KL(s \|\| t) over masked positions. |
| `temperature_scaled_masked_kl_divergence` | Compute a temperature-scaled KL(softmax(teacher / T) \|\| softmax(student / T)) over masked positions. |
| `total_variation` | Compute the total variation between two discrete probability distributions using logits. |
jefferys_divergence
`jefferys_divergence(noisy_logits: Tensor, clean_logits: Tensor, attention_mask: Tensor, support_mask: Tensor | None = None, reduction: Literal['mean', 'none'] = 'mean') -> Tensor`
Compute the Jeffreys divergence between two discrete probability distributions using logits.
Note
See https://en.wikipedia.org/wiki/F-divergence#Common_examples_of_f-divergences for the implementation formula.
Note
If `support_mask` has non-zero elements, the divergence is computed on a masked version of the original distributions, so the result is no longer a divergence between the true distributions.
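For intuition about the formula behind this function, the pure-Python sketch below computes the per-token Jeffreys divergence `J(P, Q) = KL(P || Q) + KL(Q || P) = sum_i (p_i - q_i) * (log p_i - log q_i)` from logits, takes a masked mean over each sequence, and then averages over the batch for `reduction="mean"`. It is a sketch under stated assumptions, not the library's implementation: logits are nested lists laid out as `[batch][seq][vocab]`, `attention_mask` holds 1 for valid tokens, every sequence has at least one valid token, `support_mask` is left out, and `log_softmax`, `token_jeffreys`, and `jeffreys_sketch` are hypothetical helper names.

```python
import math
from typing import List

# Sketch only: hypothetical helpers, not the library's jefferys_divergence.

def log_softmax(logits: List[float]) -> List[float]:
    # Subtract the max before exponentiating so large logits do not overflow.
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]

def token_jeffreys(noisy: List[float], clean: List[float]) -> float:
    # J(P, Q) = KL(P || Q) + KL(Q || P) = sum_i (p_i - q_i) * (log p_i - log q_i)
    log_p, log_q = log_softmax(noisy), log_softmax(clean)
    return sum((math.exp(a) - math.exp(b)) * (a - b) for a, b in zip(log_p, log_q))

def jeffreys_sketch(noisy, clean, attention_mask, reduction="mean"):
    # noisy/clean: [batch][seq][vocab] lists; attention_mask: [batch][seq] of 0/1.
    per_sequence = []
    for n_seq, c_seq, m_seq in zip(noisy, clean, attention_mask):
        values = [token_jeffreys(n, c) for n, c, m in zip(n_seq, c_seq, m_seq) if m]
        # Masked mean over the sequence (assumes at least one valid token).
        per_sequence.append(sum(values) / len(values))
    if reduction == "mean":
        return sum(per_sequence) / len(per_sequence)
    return per_sequence  # reduction == "none": one value per sequence
```

For example, noisy logits `[log 3, 0]` against clean logits `[0, 0]` give p = (0.75, 0.25) and q = (0.5, 0.5), so `token_jeffreys` returns 0.25 * log 3 (about 0.275).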
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `noisy_logits` | `Tensor` | The logits produced from SGT data. | required |
| `clean_logits` | `Tensor` | The logits produced from data without the SGT. | required |
| `attention_mask` | `Tensor` | The attention mask for the batch. | required |
| `support_mask` | `Tensor \| None` | The mask placed on both noisy and clean logits. Useful for comparing top-k logits. | `None` |
| `reduction` | `Literal['mean', 'none']` | Specifies how to reduce across the batch dimension after computing the per-sequence masked mean. | `'mean'` |
Returns:
| Type | Description |
|---|---|
| `Tensor` | The Jeffreys divergence between the noisy and clean logits: a masked mean over each sequence, followed by a mean over the batch when `reduction='mean'` (one value per sequence when `reduction='none'`). |
Added in version v3.37.0.
jensen_shannon_divergence
`jensen_shannon_divergence(noisy_logits: Tensor, clean_logits: Tensor, attention_mask: Tensor, support_mask: Tensor | None = None, reduction: Literal['mean', 'none'] = 'mean') -> Tensor`
Compute the Jensen-Shannon divergence between two discrete probability distributions using logits.
Note
See https://en.wikipedia.org/wiki/F-divergence#Common_examples_of_f-divergences for the implementation formula.
Note
If `support_mask` has non-zero elements, the divergence is computed on a masked version of the original distributions, so the result is no longer a divergence between the true distributions.
Note
This implementation has numerical stability issues due to its use of softmax on the mixed probabilities.
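One common way to avoid this kind of instability is to form the mixture directly in log space, `log m_i = logaddexp(log p_i, log q_i) - log 2`, instead of taking the log of averaged probabilities that may have underflowed. The pure-Python sketch below computes `JSD(P, Q) = 0.5 * KL(P || M) + 0.5 * KL(Q || M)` with `M = (P + Q) / 2` that way for a single token. It is an alternative formulation offered for illustration, not a description of how `jensen_shannon_divergence` is implemented, and the helper names are hypothetical.

```python
import math
from typing import List

# Sketch only: log-space Jensen-Shannon divergence for one token (hypothetical helpers).

def log_softmax(logits: List[float]) -> List[float]:
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]

def logaddexp(a: float, b: float) -> float:
    # log(exp(a) + exp(b)) without overflow; inputs are assumed finite.
    hi, lo = max(a, b), min(a, b)
    return hi + math.log1p(math.exp(lo - hi))

def token_jsd(noisy: List[float], clean: List[float]) -> float:
    log_p, log_q = log_softmax(noisy), log_softmax(clean)
    total = 0.0
    for a, b in zip(log_p, log_q):
        log_m = logaddexp(a, b) - math.log(2.0)  # log of the mixture (p + q) / 2
        total += 0.5 * math.exp(a) * (a - log_m) + 0.5 * math.exp(b) * (b - log_m)
    return total
```

With natural logarithms the result always lies in `[0, log 2]`, reaching `log 2` when the two distributions have disjoint support.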
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `noisy_logits` | `Tensor` | The logits produced from SGT data. | required |
| `clean_logits` | `Tensor` | The logits produced from data without the SGT. | required |
| `attention_mask` | `Tensor` | The attention mask for the batch. | required |
| `support_mask` | `Tensor \| None` | The mask placed on both noisy and clean logits. Useful for comparing top-k logits. | `None` |
| `reduction` | `Literal['mean', 'none']` | Specifies how to reduce across the batch dimension after computing the per-sequence masked mean. | `'mean'` |
Returns:
| Type | Description |
|---|---|
| `Tensor` | The Jensen-Shannon divergence between the noisy and clean logits: a masked mean over each sequence, followed by a mean over the batch when `reduction='mean'` (one value per sequence when `reduction='none'`). |
Added in version v3.37.0.
masked_cross_entropy
`masked_cross_entropy(input_logits: Tensor, target_logits: Tensor, attention_mask: Tensor | None = None, max_loss: float | None = None, scaling_factor: float = 1.0) -> Tensor`
Compute cross-entropy loss between input logits and target logits.
Applies softmax to the target logits.
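Applying softmax to the target turns it into a soft label distribution, so the per-position loss is `H(softmax(target), softmax(input)) = -sum_i softmax(target)_i * log_softmax(input)_i`. The pure-Python sketch below computes that quantity and averages it over the unmasked positions. It is a sketch under stated assumptions: it omits `max_loss` and `scaling_factor` because their exact effect is not specified here, it assumes the average is taken over all valid tokens in the batch, and the helper names are hypothetical.

```python
import math
from typing import List

# Sketch only: soft-target cross-entropy without max_loss / scaling_factor (hypothetical helpers).

def log_softmax(logits: List[float]) -> List[float]:
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]

def token_soft_cross_entropy(input_logits: List[float], target_logits: List[float]) -> float:
    # Soft labels: softmax over the target logits.
    target_probs = [math.exp(v) for v in log_softmax(target_logits)]
    log_q = log_softmax(input_logits)
    return -sum(p * lq for p, lq in zip(target_probs, log_q))

def masked_soft_cross_entropy_sketch(input_logits, target_logits, attention_mask=None):
    # input_logits/target_logits: [batch][seq][vocab]; attention_mask: [batch][seq] or None.
    losses = []
    for b, (i_seq, t_seq) in enumerate(zip(input_logits, target_logits)):
        m_seq = attention_mask[b] if attention_mask is not None else [1] * len(i_seq)
        losses += [token_soft_cross_entropy(i, t) for i, t, m in zip(i_seq, t_seq, m_seq) if m]
    # Mean over every valid (batch, position) pair; assumes at least one is valid.
    return sum(losses) / len(losses)
```

With uniform logits on both sides over a vocabulary of 4, the per-token loss is `log 4`, the entropy of the uniform target.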
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `input_logits` | `Tensor` | Logits from the model, shape `(batch_size, sequence_length, embedding_dim)`. | required |
| `target_logits` | `Tensor` | Logits from the target, shape `(batch_size, sequence_length, embedding_dim)`. | required |
| `attention_mask` | `Tensor \| None` | Optional attention mask, shape `(batch_size, sequence_length)`. | `None` |
| `max_loss` | `float \| None` | The maximum value used to cap the cross-entropy loss. | `None` |
| `scaling_factor` | `float` | A scaling factor for the sigmoid cross-entropy loss; lower values slow down optimization. | `1.0` |
Returns:
| Type | Description |
|---|---|
| `Tensor` | Cross-entropy loss, averaged over the batch and sequence length. |
Example
```python
>>> import torch
>>> input_logits = torch.randn(2, 5, 10)
>>> target_logits = torch.randn(2, 5, 10)
>>> attention_mask = torch.ones(2, 5, dtype=torch.bool)
>>> cross_entropy_loss = masked_cross_entropy(
...     input_logits, target_logits, attention_mask
... )
>>> print(cross_entropy_loss.shape)
torch.Size([])
```
Added in version v1.8.0.
Added in version v2.18.0.
masked_jefferys_divergence
`masked_jefferys_divergence(input_logits: Tensor, target_logits: Tensor, attention_mask: Tensor | None = None, log_target: bool = False) -> Tensor`
Compute the Jeffreys divergence between input logits and target logits.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `input_logits` | `Tensor` | Logits from the model, shape `(batch_size, sequence_length, embedding_dim)`. | required |
| `target_logits` | `Tensor` | Logits from the target, shape `(batch_size, sequence_length, embedding_dim)`. | required |
| `attention_mask` | `Tensor \| None` | Optional attention mask, shape `(batch_size, sequence_length)`. | `None` |
| `log_target` | `bool` | Whether the target logits are already in log space. | `False` |
Returns:
| Type | Description |
|---|---|
| `Tensor` | Jeffreys divergence, averaged over the batch and sequence length. |
Example
```python
>>> import torch
>>> input_logits = torch.randn(2, 5, 10)
>>> target_logits = torch.randn(2, 5, 10)
>>> attention_mask = torch.ones(2, 5, dtype=torch.bool)
>>> jefferys_div = masked_jefferys_divergence(
...     input_logits, target_logits, attention_mask, log_target=False
... )
>>> print(jefferys_div.shape)
torch.Size([])
```
masked_kl_divergence
`masked_kl_divergence(input_logits: Tensor, target_logits: Tensor, attention_mask: Tensor | None = None, log_target: bool = True) -> Tensor`
Compute KL divergence between input logits and target logits.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `input_logits` | `Tensor` | Logits from the model, shape `(batch_size, sequence_length, embedding_dim)`. | required |
| `target_logits` | `Tensor` | Logits from the target, shape `(batch_size, sequence_length, embedding_dim)`. | required |
| `attention_mask` | `Tensor \| None` | Optional attention mask, shape `(batch_size, sequence_length)`. | `None` |
| `log_target` | `bool` | Whether the target logits are already in log space. | `True` |
Returns:
| Type | Description |
|---|---|
| `Tensor` | KL divergence, averaged over the batch and sequence length. |
Example
```python
>>> import torch
>>> input_logits = torch.randn(2, 5, 10)
>>> target_logits = torch.randn(2, 5, 10)
>>> attention_mask = torch.ones(2, 5, dtype=torch.bool)
>>> kl_div = masked_kl_divergence(input_logits, target_logits, attention_mask)
>>> print(kl_div.shape)
torch.Size([])
```
Added in version v1.8.0.
masked_unbiased_dcor
`masked_unbiased_dcor(samples_1: Tensor, samples_2: Tensor, attention_mask: Tensor, safety_factor: float = 100.0) -> Tensor`
Compute an unbiased approximation to the distance correlation between the tensors, representing samples of random variables.
Note
The approximation assumes that the last tensor dimension holds the random vector and that all preceding dimensions index separate observations. For text, this corresponds to a token-level distance correlation.
Note
The tensors must have the same number of rows, representing the number of samples.
See https://arxiv.org/pdf/1701.06054.pdf for more details.
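The sketch below shows one standard bias-corrected estimator of squared distance correlation, built from U-centered distance matrices (the construction of Székely and Rizzo), applied only to the unmasked rows. Each row is one observation; for text that means one token vector after flattening the batch and sequence dimensions. It is pure Python and illustrative only: whether `masked_unbiased_dcor` uses exactly this estimator is an assumption, the `safety_factor` epsilon is omitted, the result is not clamped (this estimator can dip slightly below 0 for nearly independent samples), and the helper names are hypothetical.

```python
import math
from typing import List, Sequence

# Sketch only: bias-corrected (U-centered) squared distance correlation on unmasked rows.
# Helper names are hypothetical; safety_factor and any clamping are omitted.

def _pairwise_distances(rows: List[Sequence[float]]) -> List[List[float]]:
    # Euclidean distance between every pair of observations.
    return [[math.dist(a, b) for b in rows] for a in rows]

def _u_center(d: List[List[float]]) -> List[List[float]]:
    n = len(d)
    row_sums = [sum(r) for r in d]  # d is symmetric, so these are also the column sums
    total = sum(row_sums)
    out = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            if i != j:  # the diagonal stays zero
                out[i][j] = (d[i][j] - row_sums[i] / (n - 2) - row_sums[j] / (n - 2)
                             + total / ((n - 1) * (n - 2)))
    return out

def _u_inner(a: List[List[float]], b: List[List[float]]) -> float:
    # Inner product over off-diagonal entries, normalised by n(n - 3).
    n = len(a)
    s = sum(a[i][j] * b[i][j] for i in range(n) for j in range(n) if i != j)
    return s / (n * (n - 3))

def unbiased_dcor_sketch(samples_1, samples_2, mask) -> float:
    # samples_*: one vector per observation; mask: truthy for valid observations.
    xs = [x for x, m in zip(samples_1, mask) if m]
    ys = [y for y, m in zip(samples_2, mask) if m]
    if len(xs) < 4:
        raise ValueError("the U-centered estimator needs at least 4 valid samples")
    a = _u_center(_pairwise_distances(xs))
    b = _u_center(_pairwise_distances(ys))
    return _u_inner(a, b) / math.sqrt(_u_inner(a, a) * _u_inner(b, b))
```

An exact linear relationship such as `y = 2x + 1` makes the U-centered matrices proportional, so the estimate is 1; rows excluded by the mask have no influence on the result.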
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `samples_1` | `Tensor` | The first tensor of samples. | required |
| `samples_2` | `Tensor` | The second tensor of samples. | required |
| `attention_mask` | `Tensor` | An attention mask indicating valid samples. | required |
| `safety_factor` | `float` | A factor to scale the added epsilon for numerical stability. | `100.0` |
Returns:
| Type | Description |
|---|---|
| `Tensor` | An approximation to the distance correlation, with values between 0 and 1. |
Added in version v3.10.0.
squared_hellinger_distance
`squared_hellinger_distance(noisy_logits: Tensor, clean_logits: Tensor, attention_mask: Tensor, support_mask: Tensor | None = None, reduction: Literal['mean', 'none'] = 'mean') -> Tensor`
Compute the squared Hellinger distance between two discrete probability distributions using logits.
Computes the divergence between embeddings and averages across the sequence.
Note
See https://en.wikipedia.org/wiki/F-divergence#Common_examples_of_f-divergences for the implementation formula.
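For a single token, the squared Hellinger distance can be written as `H^2(P, Q) = 0.5 * sum_i (sqrt(p_i) - sqrt(q_i))^2 = 1 - sum_i sqrt(p_i * q_i)`, which lies in `[0, 1]`. The pure-Python sketch below uses that form. Some references instead use `f(t) = (sqrt(t) - 1)^2`, which drops the 0.5 and doubles the value, so the constant factor used by `squared_hellinger_distance` is an assumption to check against the source; the helper names are hypothetical.

```python
import math
from typing import List

# Sketch only: per-token squared Hellinger distance with the 0.5 convention (hypothetical helpers).

def log_softmax(logits: List[float]) -> List[float]:
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]

def token_squared_hellinger(noisy: List[float], clean: List[float]) -> float:
    # sqrt(p_i) = exp(0.5 * log p_i), computed from log-probabilities.
    log_p, log_q = log_softmax(noisy), log_softmax(clean)
    return 0.5 * sum((math.exp(0.5 * a) - math.exp(0.5 * b)) ** 2 for a, b in zip(log_p, log_q))
```

Summing squared differences of square roots keeps every term non-negative, whereas `1 - sum_i sqrt(p_i * q_i)` can round to a tiny negative number for nearly identical distributions.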
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `noisy_logits` | `Tensor` | The logits produced from SGT data. | required |
| `clean_logits` | `Tensor` | The logits produced from data without the SGT. | required |
| `attention_mask` | `Tensor` | The attention mask for the batch. | required |
| `support_mask` | `Tensor \| None` | The mask placed on both noisy and clean logits. Useful for comparing top-k logits. | `None` |
| `reduction` | `Literal['mean', 'none']` | Specifies how to reduce across the batch dimension after computing the per-sequence masked mean. | `'mean'` |
Returns:
| Type | Description |
|---|---|
| `Tensor` | The squared Hellinger distance between the noisy and clean logits: a masked mean over each sequence, followed by a mean over the batch when `reduction='mean'` (one value per sequence when `reduction='none'`). |
Added in version v3.37.0.
temperature_scaled_masked_jefferys_divergence
`temperature_scaled_masked_jefferys_divergence(teacher_logits: Tensor, student_logits: Tensor, position_mask: Tensor, temperature: float = 1.0) -> Tensor`
Compute the symmetric Jeffreys divergence over masked positions: KL(t||s) + KL(s||t).
It mirrors temperature_scaled_masked_kl_divergence (the mask is applied before the softmax, so at large vocabulary sizes only the active (M, V) rows are softmaxed, and temperature scaling is supported) but adds the reverse-KL term. The divergence therefore catches both mass-covering failures (the student missing the teacher's modes) and mode-seeking failures (the student putting mass where the teacher has none, e.g. mode collapse onto a single response template). The forward KL primarily penalises the former and the reverse KL primarily the latter; the Jeffreys divergence penalises both.
Gradient routing is the caller's responsibility — this function does not detach either side.
If you want the standard distillation semantics (gradients flow only through the student),
pass teacher_logits.detach() at the call site.
For an empty mask returns a zero scalar on the teacher's device/dtype to keep composite losses safe to backpropagate.
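The pure-Python sketch below mirrors the steps described above: it gathers the active positions first, divides both sides by the temperature, sums forward and reverse KL per position, averages over active positions, and returns 0.0 for an empty mask. Gradient routing has no counterpart here because plain floats carry no gradients. It assumes nested lists laid out as `[batch][seq][vocab]` with a `[batch][seq]` mask, and its helper names are hypothetical; it is not the library's implementation.

```python
import math
from typing import List

# Sketch only: pure-Python stand-in for a temperature-scaled masked Jeffreys divergence.

def log_softmax(logits: List[float]) -> List[float]:
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]

def kl(log_p: List[float], log_q: List[float]) -> float:
    # KL(P || Q) = sum_i p_i * (log p_i - log q_i)
    return sum(math.exp(a) * (a - b) for a, b in zip(log_p, log_q))

def ts_masked_jeffreys_sketch(teacher, student, mask, temperature: float = 1.0) -> float:
    # teacher/student: [batch][seq][vocab] lists; mask: [batch][seq], truthy for active positions.
    rows = [
        (t, s)
        for t_seq, s_seq, m_seq in zip(teacher, student, mask)
        for t, s, m in zip(t_seq, s_seq, m_seq)
        if m
    ]  # gather active rows before any softmax
    if not rows:
        return 0.0  # empty mask: return zero rather than dividing by zero
    total = 0.0
    for t, s in rows:
        log_t = log_softmax([x / temperature for x in t])
        log_s = log_softmax([x / temperature for x in s])
        total += kl(log_t, log_s) + kl(log_s, log_t)  # forward + reverse KL
    return total / len(rows)
```

Swapping teacher and student leaves the value unchanged, which is the symmetry that distinguishes it from the one-directional KL; raising the temperature flattens both distributions and typically shrinks the divergence.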
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `teacher_logits` | `Tensor` | Logits from the clean / target model of shape `(B, T, V)`. | required |
| `student_logits` | `Tensor` | Logits from the student / noisy model of shape `(B, T, V)`. | required |
| `position_mask` | `Tensor` | Boolean mask of shape `(B, T)` selecting the active positions. | required |
| `temperature` | `float` | Softmax temperature applied to both teacher and student logits. | `1.0` |
Returns:
| Type | Description |
|---|---|
| `Tensor` | Scalar Jeffreys divergence averaged over active positions. |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If … |
Example
```python
>>> import torch
>>> teacher = torch.randn(2, 4, 16)
>>> student = torch.randn(2, 4, 16)
>>> mask = torch.zeros(2, 4, dtype=torch.bool)
>>> mask[:, -2:] = True
>>> jd = temperature_scaled_masked_jefferys_divergence(teacher, student, mask)
>>> jd.shape
torch.Size([])
```
Added in version v3.43.0. Symmetric (forward + reverse KL) counterpart to `temperature_scaled_masked_kl_divergence` that penalises both mass-covering and mode-seeking failures. Filters by mask before the softmax and supports temperature scaling.
temperature_scaled_masked_kl_divergence
`temperature_scaled_masked_kl_divergence(teacher_logits: Tensor, student_logits: Tensor, position_mask: Tensor, temperature: float = 1.0) -> Tensor`
Compute temperature-scaled KL(softmax(teacher / T) || softmax(student / T)) over masked positions.
Gradient routing is the caller's responsibility: this function does not detach either side. If you want the standard distillation semantics (gradients flow only through the student), pass teacher_logits.detach() at the call site.

Unlike masked_kl_divergence, the position mask is applied before the softmax, so only the active (M, V) rows are softmaxed rather than the full (B, T, V) tensor. For large vocabularies (e.g. V > 100_000), where the full softmax is expensive, this can shrink the softmax working set by orders of magnitude when only a small fraction of positions is active.

For an empty mask (no active positions) the function returns a zero scalar on the teacher's device/dtype rather than NaN, so a composite loss is safe to backpropagate.
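Because softmax normalises each vocabulary row independently, selecting the active rows before the softmax yields exactly the same per-row distributions as softmaxing the full (B, T, V) tensor and selecting afterwards; only the amount of work changes. The pure-Python sketch below relies on that: it gathers the active rows, divides both sides by the temperature, and averages `KL(softmax(teacher / T) || softmax(student / T))` over them, returning 0.0 when nothing is active. The `[batch][seq][vocab]` list layout and the helper names are hypothetical assumptions; it is not the library's code.

```python
import math
from typing import List

# Sketch only: mask-before-softmax, temperature-scaled KL(teacher || student) (hypothetical helpers).

def log_softmax(logits: List[float]) -> List[float]:
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]

def ts_masked_kl_sketch(teacher, student, mask, temperature: float = 1.0) -> float:
    # Gather the active (M, V) rows first; the softmax only ever sees these rows.
    rows = [
        (t, s)
        for t_seq, s_seq, m_seq in zip(teacher, student, mask)
        for t, s, m in zip(t_seq, s_seq, m_seq)
        if m  # 0/1 ints and False/True behave the same here
    ]
    if not rows:
        return 0.0  # empty mask: zero instead of 0 / 0
    total = 0.0
    for t, s in rows:
        log_t = log_softmax([x / temperature for x in t])
        log_s = log_softmax([x / temperature for x in s])
        # KL(teacher || student) for this position
        total += sum(math.exp(a) * (a - b) for a, b in zip(log_t, log_s))
    return total / len(rows)
```

In PyTorch, by contrast, indexing a tensor with an integer (long) tensor performs index-based selection rather than boolean masking, so a 0/1 mask has to be cast to bool before it is used to select rows; that is the failure mode the v3.43.0 change note describes.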
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `teacher_logits` | `Tensor` | Logits from the clean / target model of shape `(B, T, V)`. | required |
| `student_logits` | `Tensor` | Logits from the student / noisy model of shape `(B, T, V)`. | required |
| `position_mask` | `Tensor` | Boolean mask of shape `(B, T)` selecting the active positions. | required |
| `temperature` | `float` | Softmax temperature applied to both teacher and student logits. | `1.0` |
Returns:
| Type | Description |
|---|---|
| `Tensor` | Scalar KL divergence averaged over active positions. |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If … |
Example
```python
>>> import torch
>>> teacher = torch.randn(2, 4, 16)
>>> student = torch.randn(2, 4, 16)
>>> mask = torch.zeros(2, 4, dtype=torch.bool)
>>> mask[:, -2:] = True
>>> kl = temperature_scaled_masked_kl_divergence(teacher, student, mask)
>>> kl.shape
torch.Size([])
```
Added in version v3.41.0. Logits-based KL distillation that filters by mask before the softmax, which cuts the working set dramatically at large vocabulary sizes, and supports temperature scaling. Complements the existing `masked_kl_divergence`.
Changed in version v3.43.0: No longer detaches `teacher_logits` internally — gradient routing is now the caller's responsibility. Pass `teacher_logits.detach()` for the previous distillation semantics. Integer 0/1 masks are now cast to bool before indexing, fixing silently-wrong (often zero) results when a long/int mask was passed instead of a bool mask.
total_variation
`total_variation(noisy_logits: Tensor, clean_logits: Tensor, attention_mask: Tensor, support_mask: Tensor | None = None, reduction: Literal['mean', 'none'] = 'mean') -> Tensor`
Compute the total variation between two discrete probability distributions using logits.
Computes the average tokenwise total variation between embeddings across a sequence.
Note
See https://en.wikipedia.org/wiki/F-divergence#Common_examples_of_f-divergences for the implementation formula.
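The pure-Python sketch below computes the per-token total variation `TV(P, Q) = 0.5 * sum_i |p_i - q_i|` (the f-divergence with `f(t) = |t - 1| / 2`) and averages it over the valid tokens of one sequence. It assumes a single sequence given as `[seq][vocab]` lists with a 0/1 mask containing at least one valid token, leaves out `support_mask` and the batch reduction, and uses hypothetical helper names; it is not the library's implementation.

```python
import math
from typing import List

# Sketch only: tokenwise total variation averaged over one sequence (hypothetical helpers).

def log_softmax(logits: List[float]) -> List[float]:
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]

def token_total_variation(noisy: List[float], clean: List[float]) -> float:
    p = [math.exp(v) for v in log_softmax(noisy)]
    q = [math.exp(v) for v in log_softmax(clean)]
    return 0.5 * sum(abs(a - b) for a, b in zip(p, q))

def sequence_total_variation(noisy_seq, clean_seq, mask_seq) -> float:
    # Average over valid tokens only (assumes at least one valid token).
    values = [token_total_variation(n, c) for n, c, m in zip(noisy_seq, clean_seq, mask_seq) if m]
    return sum(values) / len(values)
```

Total variation is bounded by 1, reached when the two token distributions have disjoint support.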
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `noisy_logits` | `Tensor` | The logits produced from SGT data. | required |
| `clean_logits` | `Tensor` | The logits produced from data without the SGT. | required |
| `attention_mask` | `Tensor` | The attention mask for the batch. | required |
| `support_mask` | `Tensor \| None` | The mask placed on both noisy and clean logits. Useful for comparing top-k logits. | `None` |
| `reduction` | `Literal['mean', 'none']` | Specifies how to reduce across the batch dimension after computing the per-sequence masked mean. | `'mean'` |
Returns:
| Type | Description |
|---|---|
| `Tensor` | The total variation between the noisy and clean logits: a masked mean over each sequence, followed by a mean over the batch when `reduction='mean'` (one value per sequence when `reduction='none'`). |
Added in version v3.37.0.