divergences
Functions:
Name | Description |
---|---|
masked_cross_entropy | Compute cross-entropy loss between input logits and target logits. |
masked_jefferys_divergence | Compute Jeffreys divergence between input logits and target logits. |
masked_kl_divergence | Compute KL divergence between input logits and target logits. |
masked_cross_entropy
masked_cross_entropy(input_logits: Tensor, target_logits: Tensor, attention_mask: Tensor | None = None) -> torch.Tensor
Compute cross-entropy loss between input logits and target logits.
Applies softmax to the target logits.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
input_logits | Tensor | Logits from the model, shape (batch_size, sequence_length, embedding_dim) | required |
target_logits | Tensor | Logits from the target, shape (batch_size, sequence_length, embedding_dim) | required |
attention_mask | Tensor \| None | Optional attention mask, shape (batch_size, sequence_length) | None |
Returns:
Type | Description |
---|---|
torch.Tensor | Cross-entropy loss, averaged over the batch and sequence length. |
Example
>>> input_logits = torch.randn(2, 5, 10)
>>> target_logits = torch.randn(2, 5, 10)
>>> attention_mask = torch.ones(2, 5, dtype=torch.bool)
>>> cross_entropy_loss = masked_cross_entropy(
...     input_logits, target_logits, attention_mask
... )
>>> print(cross_entropy_loss.shape)
torch.Size([])
Added in version v1.8.0.
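For readers who want to see what the masked average looks like in code, the following is a minimal, illustrative sketch of a soft-target cross-entropy with masking. It assumes masked positions are simply excluded from the average; the helper name masked_cross_entropy_sketch is made up and this is not the library's implementation.

```python
import torch
import torch.nn.functional as F

def masked_cross_entropy_sketch(
    input_logits: torch.Tensor,
    target_logits: torch.Tensor,
    attention_mask: torch.Tensor | None = None,
) -> torch.Tensor:
    # Soft targets: probabilities obtained by softmaxing the target logits.
    target_probs = F.softmax(target_logits, dim=-1)
    # Per-position cross-entropy: -sum_v p_target(v) * log p_model(v).
    log_probs = F.log_softmax(input_logits, dim=-1)
    per_position = -(target_probs * log_probs).sum(dim=-1)  # (batch, seq)
    if attention_mask is None:
        return per_position.mean()
    # Assumed masking behaviour: average only over unmasked positions.
    per_position = per_position * attention_mask
    return per_position.sum() / attention_mask.sum()
```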
masked_jefferys_divergence
masked_jefferys_divergence(input_logits: Tensor, target_logits: Tensor, attention_mask: Tensor | None = None, log_target: bool = False) -> torch.Tensor
Compute Jeffreys divergence between input logits and target logits.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
input_logits | Tensor | Logits from the model, shape (batch_size, sequence_length, embedding_dim) | required |
target_logits | Tensor | Logits from the target, shape (batch_size, sequence_length, embedding_dim) | required |
attention_mask | Tensor \| None | Optional attention mask, shape (batch_size, sequence_length) | None |
log_target | bool | Whether the target logits are already in log space. | False |
Returns:
Type | Description |
---|---|
torch.Tensor | Jeffreys divergence, averaged over the batch and sequence length. |
Example
>>> input_logits = torch.randn(2, 5, 10)
>>> target_logits = torch.randn(2, 5, 10)
>>> attention_mask = torch.ones(2, 5, dtype=torch.bool)
>>> jefferys_div = masked_jefferys_divergence(
...     input_logits, target_logits, attention_mask, log_target=False
... )
>>> print(jefferys_div.shape)
torch.Size([])
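For reference, the Jeffreys divergence is the symmetrized KL divergence, J(P, Q) = KL(P || Q) + KL(Q || P) (some texts use the average of the two terms rather than the sum). The sketch below, with the made-up helper jeffreys_per_position, shows that per-position quantity computed from raw logits (the log_target=False case); it is illustrative only, not the library's implementation.

```python
import torch
import torch.nn.functional as F

def jeffreys_per_position(
    input_logits: torch.Tensor,
    target_logits: torch.Tensor,
) -> torch.Tensor:
    # J(P, Q) = KL(P || Q) + KL(Q || P), one value per (batch, position).
    log_p = F.log_softmax(input_logits, dim=-1)
    log_q = F.log_softmax(target_logits, dim=-1)
    p, q = log_p.exp(), log_q.exp()
    kl_pq = (p * (log_p - log_q)).sum(dim=-1)  # KL(P || Q)
    kl_qp = (q * (log_q - log_p)).sum(dim=-1)  # KL(Q || P)
    return kl_pq + kl_qp
```

Masking and averaging over unmasked positions would then follow the same pattern as in the cross-entropy sketch above.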
masked_kl_divergence
masked_kl_divergence(input_logits: Tensor, target_logits: Tensor, attention_mask: Tensor | None = None, log_target: bool = True) -> torch.Tensor
Compute KL divergence between input logits and target logits.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
input_logits | Tensor | Logits from the model, shape (batch_size, sequence_length, embedding_dim) | required |
target_logits | Tensor | Logits from the target, shape (batch_size, sequence_length, embedding_dim) | required |
attention_mask | Tensor \| None | Optional attention mask, shape (batch_size, sequence_length) | None |
log_target | bool | Whether the target logits are already in log space. | True |
Returns:
Type | Description |
---|---|
torch.Tensor | KL divergence, averaged over the batch and sequence length. |
Example
>>> input_logits = torch.randn(2, 5, 10)
>>> target_logits = torch.randn(2, 5, 10)
>>> attention_mask = torch.ones(2, 5, dtype=torch.bool)
>>> kl_div = masked_kl_divergence(input_logits, target_logits, attention_mask)
>>> print(kl_div.shape)
torch.Size([])
Added in version v1.8.0.
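To make the log_target flag concrete, here is an assumption-laden sketch of the per-position KL term built on torch.nn.functional.kl_div. The helper name kl_per_position is made up, and the exact way masked_kl_divergence normalizes its inputs and applies the mask may differ from this.

```python
import torch
import torch.nn.functional as F

def kl_per_position(
    input_logits: torch.Tensor,
    target_logits: torch.Tensor,
    log_target: bool = True,
) -> torch.Tensor:
    # F.kl_div expects its first argument as log-probabilities; log_target
    # declares whether the target is supplied in log space as well.
    log_input = F.log_softmax(input_logits, dim=-1)
    # Assumption: with log_target=True the target logits are used directly as
    # log-probabilities; otherwise they are normalized with softmax first.
    target = target_logits if log_target else F.softmax(target_logits, dim=-1)
    return F.kl_div(
        log_input, target, reduction="none", log_target=log_target
    ).sum(dim=-1)  # (batch, seq)
```

Masking and averaging over unmasked positions would then mirror the cross-entropy sketch earlier on this page.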