
divergences

Module for f-divergence based loss functions.

Functions:

| Name | Description |
| --- | --- |
| `masked_cross_entropy` | Compute cross-entropy loss between input logits and target logits. |
| `masked_jefferys_divergence` | Compute Jeffreys divergence between input logits and target logits. |
| `masked_kl_divergence` | Compute KL divergence between input logits and target logits. |
| `masked_unbiased_dcor` | Compute an unbiased approximation to the distance correlation between tensors representing samples of random variables. |

masked_cross_entropy

masked_cross_entropy(input_logits: Tensor, target_logits: Tensor, attention_mask: Tensor | None = None, max_loss: float | None = None, scaling_factor: float = 1.0) -> Tensor

Compute cross-entropy loss between input logits and target logits.

Applies softmax to the target logits.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `input_logits` | `Tensor` | Logits from the model, shape `(batch_size, sequence_length, embedding_dim)`. | *required* |
| `target_logits` | `Tensor` | Logits from the target, shape `(batch_size, sequence_length, embedding_dim)`. | *required* |
| `attention_mask` | `Tensor \| None` | Optional attention mask, shape `(batch_size, sequence_length)`. | `None` |
| `max_loss` | `float \| None` | Maximum value at which to cap the cross-entropy loss. | `None` |
| `scaling_factor` | `float` | Scaling factor for the cross-entropy loss; lower values slow optimization. | `1.0` |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | Cross-entropy loss, averaged over the batch and sequence length. |

Example

```python
>>> input_logits = torch.randn(2, 5, 10)
>>> target_logits = torch.randn(2, 5, 10)
>>> attention_mask = torch.ones(2, 5, dtype=torch.bool)
>>> cross_entropy_loss = masked_cross_entropy(
...     input_logits, target_logits, attention_mask
... )
>>> print(cross_entropy_loss.shape)
torch.Size([])
```

Added in version v1.8.0.

Added in version v2.18.0.
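As a concrete illustration of the semantics above, here is a minimal sketch of a masked soft-target cross-entropy. The helper name `sketch_masked_cross_entropy` is hypothetical, and the ordering of the cap and the scaling factor is an assumption; this is not the library's actual implementation.

```python
import torch
import torch.nn.functional as F

def sketch_masked_cross_entropy(input_logits, target_logits, attention_mask=None,
                                max_loss=None, scaling_factor=1.0):
    # Soft targets: probabilities obtained by softmaxing the target logits.
    target_probs = F.softmax(target_logits, dim=-1)
    log_probs = F.log_softmax(input_logits, dim=-1)
    # Per-token cross-entropy: -sum_k p_k * log q_k, shape (batch, seq).
    per_token = -(target_probs * log_probs).sum(dim=-1)
    if max_loss is not None:
        # Assumption: the cap is applied per token, before scaling.
        per_token = per_token.clamp(max=max_loss)
    per_token = per_token * scaling_factor
    if attention_mask is not None:
        # Mask-weighted mean over valid positions only.
        mask = attention_mask.to(per_token.dtype)
        return (per_token * mask).sum() / mask.sum()
    return per_token.mean()

input_logits = torch.randn(2, 5, 10)
target_logits = torch.randn(2, 5, 10)
mask = torch.ones(2, 5, dtype=torch.bool)
loss = sketch_masked_cross_entropy(input_logits, target_logits, mask)
print(loss.shape)  # torch.Size([])
```

Because the result is a mask-weighted mean, padded positions contribute nothing to the loss or its gradient.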

masked_jefferys_divergence

masked_jefferys_divergence(input_logits: Tensor, target_logits: Tensor, attention_mask: Tensor | None = None, log_target: bool = False) -> Tensor

Compute Jeffreys divergence between input logits and target logits.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `input_logits` | `Tensor` | Logits from the model, shape `(batch_size, sequence_length, embedding_dim)`. | *required* |
| `target_logits` | `Tensor` | Logits from the target, shape `(batch_size, sequence_length, embedding_dim)`. | *required* |
| `attention_mask` | `Tensor \| None` | Optional attention mask, shape `(batch_size, sequence_length)`. | `None` |
| `log_target` | `bool` | Whether the target logits are already in log space. | `False` |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | Jeffreys divergence, averaged over the batch and sequence length. |

Example

```python
>>> input_logits = torch.randn(2, 5, 10)
>>> target_logits = torch.randn(2, 5, 10)
>>> attention_mask = torch.ones(2, 5, dtype=torch.bool)
>>> jefferys_div = masked_jefferys_divergence(
...     input_logits, target_logits, attention_mask, log_target=False
... )
>>> print(jefferys_div.shape)
torch.Size([])
```
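Jeffreys divergence is the symmetrized KL divergence, J(p, q) = KL(p ‖ q) + KL(q ‖ p). A minimal sketch follows, assuming `log_target=True` means the target logits are already log-probabilities; the helper name is hypothetical and this is not the library's actual code.

```python
import torch
import torch.nn.functional as F

def sketch_masked_jefferys(input_logits, target_logits, attention_mask=None,
                           log_target=False):
    log_p = F.log_softmax(input_logits, dim=-1)
    # Assumption: log_target=True means target_logits are already log-probabilities.
    log_q = target_logits if log_target else F.log_softmax(target_logits, dim=-1)
    p, q = log_p.exp(), log_q.exp()
    # J(p, q) = KL(q || p) + KL(p || q), summed over the vocabulary dimension.
    per_token = (q * (log_q - log_p)).sum(-1) + (p * (log_p - log_q)).sum(-1)
    if attention_mask is not None:
        mask = attention_mask.to(per_token.dtype)
        return (per_token * mask).sum() / mask.sum()
    return per_token.mean()

t = torch.randn(2, 5, 10)
print(sketch_masked_jefferys(t, t))  # identical distributions -> approximately 0
```

Unlike plain KL, the result is symmetric in its two arguments, which is why no `log_input` counterpart is needed conceptually: swapping the inputs leaves the value unchanged.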

masked_kl_divergence

masked_kl_divergence(input_logits: Tensor, target_logits: Tensor, attention_mask: Tensor | None = None, log_target: bool = True) -> Tensor

Compute KL divergence between input logits and target logits.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `input_logits` | `Tensor` | Logits from the model, shape `(batch_size, sequence_length, embedding_dim)`. | *required* |
| `target_logits` | `Tensor` | Logits from the target, shape `(batch_size, sequence_length, embedding_dim)`. | *required* |
| `attention_mask` | `Tensor \| None` | Optional attention mask, shape `(batch_size, sequence_length)`. | `None` |
| `log_target` | `bool` | Whether the target logits are already in log space. | `True` |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | KL divergence, averaged over the batch and sequence length. |

Example

```python
>>> input_logits = torch.randn(2, 5, 10)
>>> target_logits = torch.randn(2, 5, 10)
>>> attention_mask = torch.ones(2, 5, dtype=torch.bool)
>>> kl_div = masked_kl_divergence(input_logits, target_logits, attention_mask)
>>> print(kl_div.shape)
torch.Size([])
```

Added in version v1.8.0.
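A minimal sketch of such a masked KL divergence can be written in terms of `torch.nn.functional.kl_div`, under the same assumption about `log_target` as above; the helper name is hypothetical and this is not the library's actual code.

```python
import torch
import torch.nn.functional as F

def sketch_masked_kl(input_logits, target_logits, attention_mask=None,
                     log_target=True):
    log_p = F.log_softmax(input_logits, dim=-1)
    # Assumption: log_target=True means target_logits are already log-probabilities.
    log_q = target_logits if log_target else F.log_softmax(target_logits, dim=-1)
    # F.kl_div computes KL(target || input) elementwise; summing over the
    # vocabulary dimension gives the per-token KL(q || p).
    per_token = F.kl_div(log_p, log_q, reduction="none", log_target=True).sum(-1)
    if attention_mask is not None:
        mask = attention_mask.to(per_token.dtype)
        return (per_token * mask).sum() / mask.sum()
    return per_token.mean()

logits_a = torch.randn(2, 5, 10)
logits_b = torch.randn(2, 5, 10)
print(sketch_masked_kl(logits_a, logits_b, log_target=False))
```

Note the asymmetry: this measures how the model distribution p diverges from the target q, and is zero only when the two distributions coincide at every unmasked position.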

masked_unbiased_dcor

masked_unbiased_dcor(samples_1: Tensor, samples_2: Tensor, attention_mask: Tensor, safety_factor: float = 100.0) -> Tensor

Compute an unbiased approximation to the distance correlation between the tensors, representing samples of random variables.

Note

The approximation assumes the last tensor dimension is the random vector, and all preceding dimensions represent the different observations. For text, this corresponds to a token-level distance correlation calculation.

Note

The tensors must have the same number of rows, representing the number of samples.

See https://arxiv.org/pdf/1701.06054.pdf for more details.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `samples_1` | `Tensor` | The first tensor of samples. | *required* |
| `samples_2` | `Tensor` | The second tensor of samples. | *required* |
| `attention_mask` | `Tensor` | An attention mask indicating valid samples. | *required* |
| `safety_factor` | `float` | A factor to scale the epsilon added for numerical stability. | `100.0` |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | An approximation to the distance correlation, between 0 and 1. |

Added in version v3.10.0.
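For intuition, here is a sketch of a standard bias-corrected distance correlation estimator based on U-centered pairwise distance matrices. The helper names are hypothetical; mask handling and the `safety_factor` epsilon are omitted for brevity, and the estimator requires at least four observations.

```python
import torch

def u_center(d: torch.Tensor) -> torch.Tensor:
    # U-centering of an (n, n) pairwise distance matrix; requires n >= 4.
    n = d.shape[0]
    row = d.sum(dim=0, keepdim=True)
    col = d.sum(dim=1, keepdim=True)
    u = d - row / (n - 2) - col / (n - 2) + d.sum() / ((n - 1) * (n - 2))
    u.fill_diagonal_(0.0)
    return u

def sketch_unbiased_dcor(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # x, y: (n, d) matrices of n observations of d-dimensional random vectors.
    n = x.shape[0]
    a = u_center(torch.cdist(x, x))
    b = u_center(torch.cdist(y, y))
    scale = n * (n - 3)
    dcov_xy = (a * b).sum() / scale  # unbiased distance covariance
    dcov_xx = (a * a).sum() / scale
    dcov_yy = (b * b).sum() / scale
    # Squared distance correlation; the unbiased estimator can dip slightly
    # below 0 for independent variables on finite samples.
    return dcov_xy / torch.sqrt(dcov_xx * dcov_yy).clamp_min(1e-12)

x = torch.randn(20, 3)
print(sketch_unbiased_dcor(x, x))  # identical samples -> approximately 1.0
```

Distance correlation, unlike Pearson correlation, is zero (in the population) only when the two random vectors are independent, which is what makes it useful as a regularizer between representations.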