
divergences

Functions:

masked_cross_entropy: Compute cross-entropy loss between input logits and target logits.
masked_jefferys_divergence: Compute Jefferys divergence between input logits and target logits.
masked_kl_divergence: Compute KL divergence between input logits and target logits.

masked_cross_entropy

masked_cross_entropy(input_logits: Tensor, target_logits: Tensor, attention_mask: Tensor | None = None) -> Tensor

Compute cross-entropy loss between input logits and target logits.

Applies softmax to the target logits.

Parameters:

input_logits (Tensor, required): Logits from the model, shape (batch_size, sequence_length, embedding_dim).
target_logits (Tensor, required): Logits from the target, shape (batch_size, sequence_length, embedding_dim).
attention_mask (Tensor | None, default None): Optional attention mask, shape (batch_size, sequence_length).

Returns:

Tensor: Cross-entropy loss, averaged over the batch and sequence length.

Example

>>> input_logits = torch.randn(2, 5, 10)
>>> target_logits = torch.randn(2, 5, 10)
>>> attention_mask = torch.ones(2, 5, dtype=torch.bool)
>>> cross_entropy_loss = masked_cross_entropy(
...     input_logits, target_logits, attention_mask
... )
>>> print(cross_entropy_loss.shape)
torch.Size([])

Added in version v1.8.0.
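
For intuition, the following is a minimal sketch of a masked soft-target cross-entropy of this form. It assumes the loss is the per-token cross-entropy between the softmax of the target logits and the log-softmax of the input logits, averaged over positions where the attention mask is True; the helper name masked_cross_entropy_sketch is illustrative and is not the library implementation.

import torch
import torch.nn.functional as F

def masked_cross_entropy_sketch(
    input_logits: torch.Tensor,
    target_logits: torch.Tensor,
    attention_mask: torch.Tensor | None = None,
) -> torch.Tensor:
    # Soft targets: probabilities from the target logits.
    target_probs = F.softmax(target_logits, dim=-1)
    # Log-probabilities from the model being trained.
    log_probs = F.log_softmax(input_logits, dim=-1)
    # Per-token cross-entropy: -sum_v p_target(v) * log p_model(v).
    per_token = -(target_probs * log_probs).sum(dim=-1)  # (batch, seq)
    if attention_mask is None:
        return per_token.mean()
    mask = attention_mask.to(per_token.dtype)
    # Average only over unmasked positions (assumed reduction).
    return (per_token * mask).sum() / mask.sum().clamp_min(1.0)

With inputs like those in the doctest above, this returns a scalar tensor, consistent with torch.Size([]).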

masked_jefferys_divergence

masked_jefferys_divergence(input_logits: Tensor, target_logits: Tensor, attention_mask: Tensor | None = None, log_target: bool = False) -> Tensor

Compute Jefferys divergence between input logits and target logits.

Parameters:

input_logits (Tensor, required): Logits from the model, shape (batch_size, sequence_length, embedding_dim).
target_logits (Tensor, required): Logits from the target, shape (batch_size, sequence_length, embedding_dim).
attention_mask (Tensor | None, default None): Optional attention mask, shape (batch_size, sequence_length).
log_target (bool, default False): Whether the target logits are already in log space.

Returns:

Tensor: Jefferys divergence, averaged over the batch and sequence length.

Example

>>> input_logits = torch.randn(2, 5, 10)
>>> target_logits = torch.randn(2, 5, 10)
>>> attention_mask = torch.ones(2, 5, dtype=torch.bool)
>>> jefferys_div = masked_jefferys_divergence(
...     input_logits, target_logits, attention_mask, log_target=False
... )
>>> print(jefferys_div.shape)
torch.Size([])
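
Jefferys divergence is the symmetrised KL divergence. The sketch below assumes it is computed per token as KL(target || input) + KL(input || target), then averaged over unmasked positions; the exact normalisation (for example, whether the sum is halved) and the masked_jefferys_divergence_sketch name are assumptions, not the library implementation.

import torch
import torch.nn.functional as F

def masked_jefferys_divergence_sketch(
    input_logits: torch.Tensor,
    target_logits: torch.Tensor,
    attention_mask: torch.Tensor | None = None,
    log_target: bool = False,
) -> torch.Tensor:
    log_p = F.log_softmax(input_logits, dim=-1)
    # With log_target=True, the target logits are treated as log-probabilities already.
    log_q = target_logits if log_target else F.log_softmax(target_logits, dim=-1)
    p, q = log_p.exp(), log_q.exp()
    # Symmetrised KL per token: KL(q || p) + KL(p || q), summed over the last axis.
    per_token = (q * (log_q - log_p)).sum(dim=-1) + (p * (log_p - log_q)).sum(dim=-1)
    if attention_mask is None:
        return per_token.mean()
    mask = attention_mask.to(per_token.dtype)
    # Average only over unmasked positions (assumed reduction).
    return (per_token * mask).sum() / mask.sum().clamp_min(1.0)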

masked_kl_divergence

masked_kl_divergence(input_logits: Tensor, target_logits: Tensor, attention_mask: Tensor | None = None, log_target: bool = True) -> Tensor

Compute KL divergence between input logits and target logits.

Parameters:

input_logits (Tensor, required): Logits from the model, shape (batch_size, sequence_length, embedding_dim).
target_logits (Tensor, required): Logits from the target, shape (batch_size, sequence_length, embedding_dim).
attention_mask (Tensor | None, default None): Optional attention mask, shape (batch_size, sequence_length).
log_target (bool, default True): Whether the target logits are already in log space.

Returns:

Tensor: KL divergence, averaged over the batch and sequence length.

Example

>>> input_logits = torch.randn(2, 5, 10)
>>> target_logits = torch.randn(2, 5, 10)
>>> attention_mask = torch.ones(2, 5, dtype=torch.bool)
>>> kl_div = masked_kl_divergence(input_logits, target_logits, attention_mask)
>>> print(kl_div.shape)
torch.Size([])

Added in version v1.8.0.
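
A minimal sketch of a masked KL divergence of this kind is given below. It assumes the divergence is taken as KL(target || input), following the convention of torch.nn.functional.kl_div, summed over the last dimension per token and averaged over unmasked positions; the divergence direction and the masked_kl_divergence_sketch name are assumptions, not the library implementation.

import torch
import torch.nn.functional as F

def masked_kl_divergence_sketch(
    input_logits: torch.Tensor,
    target_logits: torch.Tensor,
    attention_mask: torch.Tensor | None = None,
    log_target: bool = True,
) -> torch.Tensor:
    log_p = F.log_softmax(input_logits, dim=-1)
    # With log_target=True, the target logits are taken to be log-probabilities already.
    log_q = target_logits if log_target else F.log_softmax(target_logits, dim=-1)
    # KL(target || input) per token, summed over the distribution axis.
    per_token = (log_q.exp() * (log_q - log_p)).sum(dim=-1)  # (batch, seq)
    if attention_mask is None:
        return per_token.mean()
    mask = attention_mask.to(per_token.dtype)
    # Average only over positions where the attention mask is set (assumed reduction).
    return (per_token * mask).sum() / mask.sum().clamp_min(1.0)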