divergences

Module for f-divergence based loss functions.

Functions:

Name Description
jefferys_divergence

Compute the Jeffreys divergence between two discrete probabilities using logits.

jensen_shannon_divergence

Compute the Jensen Shannon divergence between two discrete probabilities using logits.

masked_cross_entropy

Compute cross-entropy loss between input logits and target logits.

masked_jefferys_divergence

Compute Jeffreys divergence between input logits and target logits.

masked_kl_divergence

Compute KL divergence between input logits and target logits.

masked_unbiased_dcor

Compute an unbiased approximation to the distance correlation between the tensors, representing samples of random variables.

squared_hellinger_distance

Compute the Squared Hellinger distance between two discrete probabilities using logits.

total_variation

Compute the total variation between two discrete probabilities using logits.

jefferys_divergence

jefferys_divergence(noisy_logits: Tensor, clean_logits: Tensor, attention_mask: Tensor, support_mask: Tensor | None = None, reduction: Literal['mean', 'none'] = 'mean') -> Tensor

Compute the Jeffreys divergence between two discrete probabilities using logits.

Note

See https://en.wikipedia.org/wiki/F-divergence#Common_examples_of_f-divergences for the implementation formula.

Note

If support_mask has non-zero elements, the divergence is computed on masked versions of the original distributions, so it is no longer a divergence between true probability distributions.
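To illustrate why a masked comparison is no longer a divergence of true distributions, the sketch below restricts a softmax distribution to its top-k entries and checks the remaining probability mass. It is plain Python for illustration only. The `keep` encoding (1 means the entry stays), the `top_k_keep` helper, and the choice not to renormalize are all assumptions, because the source does not state how `support_mask` is encoded or applied.

```python
import math


def softmax(logits):
    """Numerically stable softmax over one vocabulary vector."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_keep(logits, k):
    """Hypothetical helper: 1 for the k largest logits, 0 elsewhere."""
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    kept = set(order[:k])
    return [1 if i in kept else 0 for i in range(len(logits))]


logits = [2.0, 1.0, 0.5, -1.0]
probs = softmax(logits)
keep = top_k_keep(logits, k=2)

# Assumption: entries outside the support are dropped without renormalizing.
masked_probs = [p * m for p, m in zip(probs, keep)]
masked_mass = sum(masked_probs)

print(sum(probs))   # full distribution sums to 1 (up to rounding)
print(masked_mass)  # restricted vector sums to less than 1
```

Under these assumptions the restricted vectors are sub-probability vectors, so quantities computed from them can behave differently from the corresponding f-divergence on the full vocabulary.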

Parameters:

Name Type Description Default

noisy_logits

Tensor

The logits produced from SGT data.

required

clean_logits

Tensor

The logits produced from data without the SGT.

required

attention_mask

Tensor

The attention mask for the batch.

required

support_mask

Tensor | None

The mask placed on both noisy and clean logits. Useful for comparing top-k logits.

None

reduction

Literal['mean', 'none']

Specifies how to reduce across the batch dimension after computing the per-sequence masked mean.

'mean'

Returns:

Type Description
Tensor

The Jeffreys divergence between the noisy and clean logits, independently averaged over the batch and sequence lengths.

Added in version v3.37.0.
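As a rough reference for the formula linked above, the Jeffreys divergence corresponds to the generator f(t) = (t - 1) ln t, which gives the sum over the vocabulary of (p_i - q_i)(log p_i - log q_i). The sketch below evaluates that for a single token in plain Python and checks that it matches KL(P || Q) + KL(Q || P). The helper names are hypothetical, and masking and batching are omitted; this is not the library's implementation.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one vocabulary vector."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def kl_from_log_probs(log_p, log_q):
    """KL(P || Q) computed from log-probabilities."""
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


def jeffreys_from_logits(noisy_logits, clean_logits):
    """Jeffreys divergence: sum_i (p_i - q_i) * (log p_i - log q_i)."""
    log_p = log_softmax(noisy_logits)
    log_q = log_softmax(clean_logits)
    return sum(
        (math.exp(lp) - math.exp(lq)) * (lp - lq) for lp, lq in zip(log_p, log_q)
    )


noisy = [2.0, 0.5, -1.0]
clean = [1.0, 1.0, 0.0]

jeffreys = jeffreys_from_logits(noisy, clean)

# The same quantity written as the sum of the two directed KL divergences.
log_p, log_q = log_softmax(noisy), log_softmax(clean)
symmetric_kl = kl_from_log_probs(log_p, log_q) + kl_from_log_probs(log_q, log_p)

print(jeffreys, symmetric_kl)
```

Working from log-probabilities avoids taking the logarithm of a probability that has underflowed to zero.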

jensen_shannon_divergence

jensen_shannon_divergence(noisy_logits: Tensor, clean_logits: Tensor, attention_mask: Tensor, support_mask: Tensor | None = None, reduction: Literal['mean', 'none'] = 'mean') -> Tensor

Compute the Jensen Shannon divergence between two discrete probabilities using logits.

Note

See https://en.wikipedia.org/wiki/F-divergence#Common_examples_of_f-divergences for the implementation formula.

Note

If support_mask has non-zero elements, the divergence is computed on masked versions of the original distributions, so it is no longer a divergence between true probability distributions.

Note

This implementation has numerical stability issues due to the use of softmax when forming the mixed probabilities.

Parameters:

Name Type Description Default

noisy_logits

Tensor

The logits produced from SGT data.

required

clean_logits

Tensor

The logits produced from data without the SGT.

required

attention_mask

Tensor

The attention mask for the batch.

required

support_mask

Tensor | None

The mask placed on both noisy and clean logits. Useful for comparing top-k logits.

None

reduction

Literal['mean', 'none']

Specifies how to reduce across the batch dimension after computing the per-sequence masked mean.

'mean'

Returns:

Type Description
Tensor

The Jensen Shannon divergence between the noisy and clean logits, independently averaged over the batch and sequence lengths.

Added in version v3.37.0.
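For reference, the Jensen Shannon divergence can be written as 0.5 KL(P || M) + 0.5 KL(Q || M) with M = (P + Q) / 2. The sketch below is one common way to keep that computation finite for extreme logits: it forms log M directly in log space with a two-term log-sum-exp. It is a single-token, plain-Python illustration with hypothetical helper names, and it is not the implementation described in the note above.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one vocabulary vector."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def log_mixture(log_p, log_q):
    """log((p + q) / 2) computed elementwise without leaving log space."""
    out = []
    for lp, lq in zip(log_p, log_q):
        hi = max(lp, lq)
        out.append(hi + math.log(math.exp(lp - hi) + math.exp(lq - hi)) - math.log(2.0))
    return out


def js_from_logits(noisy_logits, clean_logits):
    """Jensen Shannon divergence in nats for one token."""
    log_p = log_softmax(noisy_logits)
    log_q = log_softmax(clean_logits)
    log_m = log_mixture(log_p, log_q)
    kl_pm = sum(math.exp(lp) * (lp - lm) for lp, lm in zip(log_p, log_m))
    kl_qm = sum(math.exp(lq) * (lq - lm) for lq, lm in zip(log_q, log_m))
    return 0.5 * kl_pm + 0.5 * kl_qm


js_same = js_from_logits([1.0, 2.0, 3.0], [1.0, 2.0, 3.0])
js_mild = js_from_logits([2.0, 0.5, -1.0], [1.0, 1.0, 0.0])
# Nearly disjoint distributions: the value approaches the upper bound ln 2.
js_extreme = js_from_logits([50.0, -50.0], [-50.0, 50.0])

print(js_same, js_mild, js_extreme)
```

With this formulation the result stays within [0, ln 2] even when one distribution assigns almost no mass to an entry.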

masked_cross_entropy

masked_cross_entropy(input_logits: Tensor, target_logits: Tensor, attention_mask: Tensor | None = None, max_loss: float | None = None, scaling_factor: float = 1.0) -> Tensor

Compute cross-entropy loss between input logits and target logits.

Applies softmax to the target logits.

Parameters:

Name Type Description Default

input_logits

Tensor

Logits from the model, shape (batch_size, sequence_length, embedding_dim)

required

target_logits

Tensor

Logits from the target, shape (batch_size, sequence_length, embedding_dim)

required

attention_mask

Tensor | None

Optional attention mask, shape (batch_size, sequence_length)

None

max_loss

float | None

The maximum value for the loss to cap the cross entropy loss.

None

scaling_factor

float

A scaling factor for the sigmoid cross-entropy loss. Lower values slow optimization. Defaults to 1.0.

1.0

Returns:

Type Description
Tensor

Cross-entropy loss, averaged over the batch and sequence length.

Example

>>> import torch
>>> input_logits = torch.randn(2, 5, 10)
>>> target_logits = torch.randn(2, 5, 10)
>>> attention_mask = torch.ones(2, 5, dtype=torch.bool)
>>> cross_entropy_loss = masked_cross_entropy(
...     input_logits, target_logits, attention_mask
... )
>>> print(cross_entropy_loss.shape)
torch.Size([])

Added in version v1.8.0.

Added in version v2.18.0.
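Because the target logits are passed through softmax, the loss is a cross-entropy against soft targets. The sketch below shows that computation for one short sequence in plain Python, averaging only over tokens the attention mask keeps. The helper names are hypothetical, and `max_loss` and `scaling_factor` are left out because the source does not say how they are applied.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one vocabulary vector."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def soft_cross_entropy(input_logits, target_logits):
    """-sum_i softmax(target)_i * log_softmax(input)_i for one token."""
    log_p = log_softmax(input_logits)
    target_probs = [math.exp(v) for v in log_softmax(target_logits)]
    return -sum(t * lp for t, lp in zip(target_probs, log_p))


def masked_mean(values, mask):
    """Average only the entries whose mask value is truthy."""
    kept = [v for v, m in zip(values, mask) if m]
    return sum(kept) / len(kept)


# One sequence of three tokens over a two-entry vocabulary; the last token is padding.
input_logits = [[0.0, 0.0], [2.0, -1.0], [9.0, 9.0]]
target_logits = [[0.0, 0.0], [1.0, 0.0], [-3.0, 4.0]]
attention_mask = [1, 1, 0]

per_token = [soft_cross_entropy(i, t) for i, t in zip(input_logits, target_logits)]
loss = masked_mean(per_token, attention_mask)

print(per_token[0])  # uniform input and uniform target give ln 2
print(loss)
```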

masked_jefferys_divergence

masked_jefferys_divergence(input_logits: Tensor, target_logits: Tensor, attention_mask: Tensor | None = None, log_target: bool = False) -> Tensor

Compute Jeffreys divergence between input logits and target logits.

Parameters:

Name Type Description Default

input_logits

Tensor

Logits from the model, shape (batch_size, sequence_length, embedding_dim)

required

target_logits

Tensor

Logits from the target, shape (batch_size, sequence_length, embedding_dim)

required

attention_mask

Tensor | None

Optional attention mask, shape (batch_size, sequence_length)

None

log_target

bool

Whether the target logits are already in log space.

False

Returns:

Type Description
Tensor

Jeffreys divergence, averaged over the batch and sequence length.

Example

>>> import torch
>>> input_logits = torch.randn(2, 5, 10)
>>> target_logits = torch.randn(2, 5, 10)
>>> attention_mask = torch.ones(2, 5, dtype=torch.bool)
>>> jefferys_div = masked_jefferys_divergence(
...     input_logits, target_logits, attention_mask, log_target=False
... )
>>> print(jefferys_div.shape)
torch.Size([])

masked_kl_divergence

masked_kl_divergence(input_logits: Tensor, target_logits: Tensor, attention_mask: Tensor | None = None, log_target: bool = True) -> Tensor

Compute KL divergence between input logits and target logits.

Parameters:

Name Type Description Default

input_logits

Tensor

Logits from the model, shape (batch_size, sequence_length, embedding_dim)

required

target_logits

Tensor

Logits from the target, shape (batch_size, sequence_length, embedding_dim)

required

attention_mask

Tensor | None

Optional attention mask, shape (batch_size, sequence_length)

None

log_target

bool

Whether the target logits are already in log space.

True

Returns:

Type Description
Tensor

KL divergence, averaged over the batch and sequence length.

Example

>>> import torch
>>> input_logits = torch.randn(2, 5, 10)
>>> target_logits = torch.randn(2, 5, 10)
>>> attention_mask = torch.ones(2, 5, dtype=torch.bool)
>>> kl_div = masked_kl_divergence(input_logits, target_logits, attention_mask)
>>> print(kl_div.shape)
torch.Size([])

Added in version v1.8.0.
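The `log_target` flag shares its name with the flag on `torch.nn.functional.kl_div`, where it states whether the target is supplied as log-probabilities or as probabilities. The plain-Python sketch below shows the two pointwise forms for one token and that they agree. Whether masked_kl_divergence forwards to that PyTorch function, and the direction KL(target || input) used here, are assumptions; the helper name is hypothetical.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one vocabulary vector."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def kl_target_vs_input(input_logits, target, log_target):
    """KL(target || softmax(input)) with the target in log or probability space."""
    log_q = log_softmax(input_logits)
    if log_target:
        # target holds log-probabilities: exp(t) * (t - log q)
        return sum(math.exp(t) * (t - lq) for t, lq in zip(target, log_q))
    # target holds probabilities: t * (log t - log q), skipping zero entries
    return sum(t * (math.log(t) - lq) for t, lq in zip(target, log_q) if t > 0.0)


input_logits = [2.0, 0.5, -1.0]
target_logits = [1.0, 1.0, 0.0]

target_log_probs = log_softmax(target_logits)
target_probs = [math.exp(v) for v in target_log_probs]

kl_log_space = kl_target_vs_input(input_logits, target_log_probs, log_target=True)
kl_prob_space = kl_target_vs_input(input_logits, target_probs, log_target=False)

print(kl_log_space, kl_prob_space)
```

Passing the target in log space avoids an extra exponentiate-then-log round trip, which is typically the reason to prefer it when log-probabilities are already available.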

masked_unbiased_dcor

masked_unbiased_dcor(samples_1: Tensor, samples_2: Tensor, attention_mask: Tensor, safety_factor: float = 100.0) -> Tensor

Compute an unbiased approximation to the distance correlation between the tensors, representing samples of random variables.

Note

The approximation assumes the last tensor dimension is the random vector and all preceding dimensions index the observations. For text, this corresponds to a token-level distance correlation calculation.

Note

The tensors must have the same number of rows, representing the number of samples.

See https://arxiv.org/pdf/1701.06054.pdf for more details.

Parameters:

Name Type Description Default

samples_1

Tensor

The first tensor of samples.

required

samples_2

Tensor

The second tensor of samples.

required

attention_mask

Tensor

An attention mask indicating valid samples.

required

safety_factor

float

A factor to scale the added epsilon for numerical stability.

100.0

Returns:

Type Description
Tensor

An approximation to the distance correlation, which takes values between 0 and 1.

Added in version v3.10.0.
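As background, one standard unbiased estimator of squared distance covariance uses U-centered pairwise distance matrices, and a bias-corrected squared distance correlation follows by normalizing with the two self-covariances. The plain-Python sketch below implements that estimator for small samples (it needs at least four observations). It is an illustration under the assumption that this estimator is the relevant one; the source does not confirm the library's formula, and the `safety_factor` epsilon and the attention mask are omitted. All helper names are hypothetical.

```python
import math


def pairwise_distances(samples):
    """Euclidean distance matrix between all pairs of observations."""
    return [[math.dist(a, b) for b in samples] for a in samples]


def u_center(d):
    """U-center a symmetric distance matrix; the diagonal is set to zero."""
    n = len(d)
    row = [sum(r) for r in d]
    total = sum(row)
    out = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            if i != j:
                out[i][j] = (
                    d[i][j]
                    - row[i] / (n - 2)
                    - row[j] / (n - 2)
                    + total / ((n - 1) * (n - 2))
                )
    return out


def u_product(a, b):
    """Unbiased inner product of two U-centered matrices."""
    n = len(a)
    off_diag = sum(a[i][j] * b[i][j] for i in range(n) for j in range(n) if i != j)
    return off_diag / (n * (n - 3))


def bias_corrected_dcor_squared(x, y):
    """Bias-corrected squared distance correlation; can be slightly negative."""
    a = u_center(pairwise_distances(x))
    b = u_center(pairwise_distances(y))
    return u_product(a, b) / math.sqrt(u_product(a, a) * u_product(b, b))


x = [(0.0, 1.0), (1.0, 0.5), (2.0, -1.0), (3.5, 0.0), (5.0, 2.0)]
y = [(2.0 * u, 2.0 * v) for u, v in x]  # a rescaled copy of x

dcor_self = bias_corrected_dcor_squared(x, x)
dcor_scaled = bias_corrected_dcor_squared(x, y)

print(dcor_self, dcor_scaled)
```

In a masked setting, rows whose attention mask entry is zero would be dropped before building the distance matrices, so that padding does not count as an observation.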

squared_hellinger_distance

squared_hellinger_distance(noisy_logits: Tensor, clean_logits: Tensor, attention_mask: Tensor, support_mask: Tensor | None = None, reduction: Literal['mean', 'none'] = 'mean') -> Tensor

Compute the Squared Hellinger distance between two discrete probabilities using logits.

Computes the divergence between embeddings and averages across the sequence.

Note

See https://en.wikipedia.org/wiki/F-divergence#Common_examples_of_f-divergences for the implementation formula.

Parameters:

Name Type Description Default

noisy_logits

Tensor

The logits produced from SGT data.

required

clean_logits

Tensor

The logits produced from data without the SGT.

required

attention_mask

Tensor

The attention mask for the batch.

required

support_mask

Tensor | None

The mask placed on both noisy and clean logits. Useful for comparing top-k logits.

None

reduction

Literal['mean', 'none']

Specifies how to reduce across the batch dimension after computing the per-sequence masked mean.

'mean'

Returns:

Type Description
Tensor

The squared Hellinger distance between the noisy and clean logits, independently averaged over the batch and sequence lengths.

Added in version v3.37.0.
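For reference, the f-divergence form of the squared Hellinger distance uses the generator f(t) = (sqrt(t) - 1)^2, which gives the sum over the vocabulary of (sqrt(p_i) - sqrt(q_i))^2, equivalently 2(1 - sum_i sqrt(p_i q_i)). The plain-Python sketch below evaluates both forms for one token. Normalization conventions differ (some definitions include a factor of 1/2), so the scaling here is an assumption rather than a statement about the library; the helper names are hypothetical.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one vocabulary vector."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def squared_hellinger_from_logits(noisy_logits, clean_logits):
    """sum_i (sqrt(p_i) - sqrt(q_i))^2, with square roots taken in log space."""
    log_p = log_softmax(noisy_logits)
    log_q = log_softmax(clean_logits)
    return sum(
        (math.exp(0.5 * lp) - math.exp(0.5 * lq)) ** 2 for lp, lq in zip(log_p, log_q)
    )


def bhattacharyya_coefficient(noisy_logits, clean_logits):
    """sum_i sqrt(p_i * q_i), used for the equivalent closed form."""
    log_p = log_softmax(noisy_logits)
    log_q = log_softmax(clean_logits)
    return sum(math.exp(0.5 * (lp + lq)) for lp, lq in zip(log_p, log_q))


noisy = [2.0, 0.5, -1.0]
clean = [1.0, 1.0, 0.0]

hellinger_sq = squared_hellinger_from_logits(noisy, clean)
closed_form = 2.0 * (1.0 - bhattacharyya_coefficient(noisy, clean))

print(hellinger_sq, closed_form)
```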

total_variation

total_variation(noisy_logits: Tensor, clean_logits: Tensor, attention_mask: Tensor, support_mask: Tensor | None = None, reduction: Literal['mean', 'none'] = 'mean') -> Tensor

Compute the total variation between two discrete probabilities using logits.

Computes the average tokenwise total variation between embeddings across a sequence.

Note

See https://en.wikipedia.org/wiki/F-divergence#Common_examples_of_f-divergences for the implementation formula.

Parameters:

Name Type Description Default

noisy_logits

Tensor

The logits produced from SGT data.

required

clean_logits

Tensor

The logits produced from data without the SGT.

required

attention_mask

Tensor

The attention mask for the batch.

required

support_mask

Tensor | None

The mask placed on both noisy and clean logits. Useful for comparing top-k logits.

None

reduction

Literal['mean', 'none']

Specifies how to reduce across the batch dimension after computing the per-sequence masked mean.

'mean'

Returns:

Type Description
Tensor

The total variation between the noisy and clean logits, independently averaged over the batch and sequence lengths.

Added in version v3.37.0.
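The tokenwise total variation is half the L1 distance between the two softmax distributions. The plain-Python sketch below computes it per token, takes the masked mean within each sequence, and then applies a batch reduction in the spirit of the `reduction` parameter. The helper names and the exact order of operations are assumptions made for illustration, not the library's code.

```python
import math


def softmax(logits):
    """Numerically stable softmax over one vocabulary vector."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def total_variation_token(noisy_logits, clean_logits):
    """0.5 * sum_i |p_i - q_i| for one token."""
    p, q = softmax(noisy_logits), softmax(clean_logits)
    return 0.5 * sum(abs(a - b) for a, b in zip(p, q))


def reduce_batch(per_token, attention_mask, reduction="mean"):
    """Masked mean per sequence, then 'mean' or 'none' across the batch."""
    per_sequence = []
    for values, mask in zip(per_token, attention_mask):
        kept = [v for v, m in zip(values, mask) if m]
        per_sequence.append(sum(kept) / len(kept))
    if reduction == "none":
        return per_sequence
    return sum(per_sequence) / len(per_sequence)


# Batch of two sequences, two tokens each, vocabulary of size two.
noisy = [[[0.0, 0.0], [5.0, 0.0]], [[1.0, 1.0], [0.0, 3.0]]]
clean = [[[0.0, 0.0], [0.0, 0.0]], [[1.0, 1.0], [9.0, -9.0]]]
attention_mask = [[1, 1], [1, 0]]  # last token of the second sequence is padding

per_token = [
    [total_variation_token(n, c) for n, c in zip(noisy_seq, clean_seq)]
    for noisy_seq, clean_seq in zip(noisy, clean)
]

per_sequence = reduce_batch(per_token, attention_mask, reduction="none")
batch_mean = reduce_batch(per_token, attention_mask, reduction="mean")

print(per_sequence, batch_mean)
```

In this sketch `reduction="none"` yields one value per sequence, and the padded token in the second sequence does not affect the result even though its logits differ.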