entropy
Module for entropy-related loss functions, particularly for mutual information estimation.
Functions:
| Name | Description |
|---|---|
| gaussian_entropy_log_loss | Compute the per-element negative differential entropy of a Gaussian, averaged over positions. |
| gaussian_mixture_entropy | Compute the entropy of a Gaussian mixture on a mini batch via Monte Carlo integration. |
| mutual_information_sub_batch | Approximate the mutual information via Monte Carlo with sub-batch sampling. |
| square_mahalanobis_distance | Calculate the square Mahalanobis Distance between the features of x and y given the diagonal tensor of standard deviations. |
gaussian_entropy_log_loss
gaussian_entropy_log_loss(
stds: Tensor, noise_mask: Tensor | None = None
) -> torch.Tensor
Compute the per-element negative differential entropy of a Gaussian, averaged over positions.
Computes mean(-log(stds + eps)) - 0.5 * log(2 * pi * e), measured in nats. This is the negated differential entropy of a Gaussian,
averaged over active positions, so the loss is zero when stds equals 1 / sqrt(2 * pi * e) at every position (zero differential entropy).
Values above zero mean the averaged Gaussian entropy is negative; values below zero mean it is positive. Minimizing this loss pushes
stds upward at every active position simultaneously, raising the noise distribution's per-element entropy.
This is intentionally distinct from
masked_negative_log_mean, which computes -log(mean(stds)): the log lives
outside the mean there. By Jensen's inequality, -log(E[stds]) <= E[-log(stds)], so gaussian_entropy_log_loss is the stricter
constraint — a single position with large stds cannot mask many other positions whose stds are low. Use
masked_negative_log_mean when only the mean noise budget matters; use gaussian_entropy_log_loss when every active position
needs an entropy floor.
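The Jensen gap described above can be checked numerically. The following is a minimal sketch, not the library's implementation: `entropy_log_loss_sketch` and `negative_log_mean_sketch` are hypothetical stand-ins written from the formulas in this section, masking is omitted, and the `eps` value is an assumption.

```python
import math

import torch

# Constant offset 0.5 * log(2 * pi * e) from the formula above.
OFFSET = 0.5 * math.log(2 * math.pi * math.e)


def entropy_log_loss_sketch(stds: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """mean(-log(stds + eps)) - 0.5 * log(2 * pi * e); eps is an assumed value."""
    return (-torch.log(stds + eps)).mean() - OFFSET


def negative_log_mean_sketch(stds: torch.Tensor) -> torch.Tensor:
    """-log(mean(stds)): the log sits outside the mean."""
    return -torch.log(stds.mean())


# One position with a large standard deviation next to three small ones.
stds = torch.tensor([10.0, 0.01, 0.01, 0.01])

# Remove the constant offset so both quantities are directly comparable.
strict = entropy_log_loss_sketch(stds) + OFFSET  # E[-log(stds)]
budget = negative_log_mean_sketch(stds)          # -log(E[stds])

print(f"E[-log(stds)] = {strict.item():.3f}")
print(f"-log(E[stds]) = {budget.item():.3f}")
```

In this toy input the mean-based quantity is negative, because the single large value dominates the mean, while the per-position quantity stays large and positive because three of the four positions have small stds.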
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| stds | Tensor | Standard deviation tensor as produced by a noise layer's | required |
| noise_mask | Tensor \| None | Optional boolean mask selecting active positions. | None |
Returns:
| Type | Description |
|---|---|
| torch.Tensor | Scalar loss tensor on the same device and dtype as |
Example
>>> stds = torch.full((4,), 1.0 / (2.0 * torch.pi * torch.e) ** 0.5)
>>> torch.allclose(gaussian_entropy_log_loss(stds), torch.zeros(()), atol=1e-6)
True
Added in version v3.41.0. Per-element negative differential entropy of a Gaussian noise distribution. Useful as a per-position entropy floor on cloak standard deviations.
gaussian_mixture_entropy
gaussian_mixture_entropy(
clean_embeddings: Tensor,
transformed_embeddings: Tensor,
means: Tensor,
stds: Tensor,
mask: Tensor,
losses: ComponentLossesDict | None = None,
safety_factor: float = 100.0,
) -> torch.Tensor
Compute the entropy of a Gaussian mixture on a mini batch via Monte Carlo integration.
If losses is not None, the following losses are added:
- gaussian_mixture_mahalanobis_distance_loss
- gaussian_mixture_log_constants_and_determinant_loss
Note: The transformed embeddings should be a different batch than the clean embeddings.
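To illustrate what a Monte Carlo entropy estimate of a Gaussian mixture looks like, here is a minimal sketch of the generic estimator H ≈ -mean(log p(y_i)) for an equal-weight mixture with diagonal covariances. It is not the library's implementation: `mixture_entropy_mc` and its arguments are hypothetical, masking and the safety factor are omitted, and the equal component weights are an assumption. One plausible reading is that the components are centered at the clean embeddings plus the means, and the samples are transformed embeddings from a different batch, as the note above requires.

```python
import math

import torch


def mixture_entropy_mc(
    samples: torch.Tensor,  # (N, E) points assumed to be drawn from the mixture
    centers: torch.Tensor,  # (K, E) component means
    stds: torch.Tensor,     # (K, E) diagonal standard deviations per component
) -> torch.Tensor:
    """Monte Carlo estimate of -E[log p(y)] for an equal-weight diagonal mixture."""
    diff = samples[:, None, :] - centers[None, :, :]  # (N, K, E)
    # Log density of every sample under every component.
    log_prob = (
        -0.5 * (diff / stds[None, :, :]) ** 2
        - torch.log(stds[None, :, :])
        - 0.5 * math.log(2 * math.pi)
    ).sum(dim=-1)  # (N, K)
    # Log of the mixture density, computed stably with logsumexp.
    log_mix = torch.logsumexp(log_prob, dim=1) - math.log(centers.shape[0])
    return -log_mix.mean()


# Sanity check on a case with a closed form: one unit-variance component in 2-D.
torch.manual_seed(0)
centers = torch.zeros(1, 2)
stds = torch.ones(1, 2)
samples = torch.randn(20_000, 2)

estimate = mixture_entropy_mc(samples, centers, stds)
exact = 2 * 0.5 * math.log(2 * math.pi * math.e)  # E/2 * log(2*pi*e) with E = 2
print(f"estimate = {estimate.item():.3f}, exact = {exact:.3f}")
```

With a single component the estimate should land close to the closed-form Gaussian entropy, which gives a quick way to validate the estimator before trusting it on a real mixture.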
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| clean_embeddings | Tensor | The untransformed embeddings. | required |
| transformed_embeddings | Tensor | The stained glass transformed embeddings. | required |
| means | Tensor | The means tensor. | required |
| stds | Tensor | The standard deviations tensor. | required |
| mask | Tensor | The mask to apply to the reduction. | required |
| losses | ComponentLossesDict \| None | The dictionary of component losses. If provided, the loss will be added to the losses dictionary. | None |
| safety_factor | float | The safety factor for the division in the Mahalanobis distance calculation. This is calculated exactly as | 100.0 |
Returns:
| Type | Description |
|---|---|
| torch.Tensor | The computed Gaussian mixture entropy. |
Warning
This API is experimental and subject to change: The entropy component of the mutual information loss is still under research.
mutual_information_sub_batch
mutual_information_sub_batch(
means_tensor: Tensor,
stds_tensor: Tensor,
noise_mask: Tensor,
clean_input_embeddings: Tensor,
noisy_input_embeddings: Tensor,
combine_noise_masks_for_mutual_information_sub_batches: bool = True,
losses: ComponentLossesDict | None = None,
safety_factor: float = 100.0,
) -> torch.Tensor
Approximate the mutual information via Monte Carlo with sub-batch sampling.
Assumes that the noisy_input_embeddings are generated by adding Gaussian noise
to the clean_input_embeddings. The loss is computed as follows:
$$ L(X_b, Y_b; X_b') = \frac{1}{B}\sum_{\ell', i} \log\left(\left|\Sigma_i^{-1}\Sigma_{\ell'}\right|\right) + \left\| y_i - x_{\ell'} - \mu_{\ell'} \right\|_{\Sigma_{\ell'}^{-1}}^2 $$
where X_b and X_b' are sub-batches of clean_input_embeddings and Y_b is the sub-batch
of noisy_input_embeddings corresponding to X_b. The \mu and \Sigma are the corresponding
sub-batches from means_tensor and stds_tensor, respectively.
If losses is not None, the following losses are added:
- gaussian_mixture_mahalanobis_distance_loss
- gaussian_mixture_log_constants_and_determinant_loss
- gaussian_mixture_entropy_loss
- mutual_information_loss
- std_log_loss
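The loss formula above can be transcribed literally for diagonal covariances. The sketch below is an illustration under stated assumptions, not the library's code: `pairwise_loss_sketch` and its argument names are hypothetical, the sum is assumed to cover both terms, \Sigma is taken to be diag(std ** 2), inputs are flattened to shape (B, E), and masking, the safety factor, and the component-loss bookkeeping are omitted.

```python
import torch


def pairwise_loss_sketch(
    x_prime: torch.Tensor,    # (B, E) clean sub-batch X_b'
    y: torch.Tensor,          # (B, E) noisy sub-batch Y_b
    mu_prime: torch.Tensor,   # (B, E) means paired with X_b'
    std_prime: torch.Tensor,  # (B, E) standard deviations paired with X_b'
    std: torch.Tensor,        # (B, E) standard deviations paired with X_b / Y_b
) -> torch.Tensor:
    batch = y.shape[0]
    # log|Sigma_i^{-1} Sigma_l'| = 2 * (sum log std_l' - sum log std_i) for diagonal Sigma.
    log_det = 2.0 * (
        torch.log(std_prime).sum(dim=-1)[None, :] - torch.log(std).sum(dim=-1)[:, None]
    )  # (B over i, B over l')
    # Squared Mahalanobis norm of y_i - x_l' - mu_l' under Sigma_l'^{-1}.
    residual = y[:, None, :] - x_prime[None, :, :] - mu_prime[None, :, :]
    mahalanobis = ((residual / std_prime[None, :, :]) ** 2).sum(dim=-1)  # (B, B)
    # Sum over all (i, l') pairs, normalized by the sub-batch size B.
    return (log_det + mahalanobis).sum() / batch


# Toy input: unit standard deviations, zero means, clean points at the origin.
y = torch.ones(2, 3)
zeros = torch.zeros(2, 3)
ones = torch.ones(2, 3)
loss = pairwise_loss_sketch(x_prime=zeros, y=y, mu_prime=zeros, std_prime=ones, std=ones)
print(loss.item())
```

In the toy input every log-determinant term vanishes and each of the four (i, l') pairs contributes a squared distance of 3, so the result is 12 / 2.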
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| means_tensor | Tensor | The means tensor for the Gaussian mixture. (B, S, E) | required |
| stds_tensor | Tensor | The standard deviations tensor for the Gaussian mixture. (B, S, E) | required |
| noise_mask | Tensor | The mask to apply to the reduction. (B, S) | required |
| clean_input_embeddings | Tensor | The untransformed embeddings. (B, S, E) | required |
| noisy_input_embeddings | Tensor | The stained glass transformed embeddings. (B, S, E) | required |
| combine_noise_masks_for_mutual_information_sub_batches | bool | Whether to use the intersection of the noise masks from both sub-batches. | True |
| losses | ComponentLossesDict \| None | The dictionary of component losses. If provided, the loss will be added to the losses dictionary. | None |
| safety_factor | float | The safety factor for the division in the Mahalanobis distance calculation. This is calculated exactly as | 100.0 |
Returns:
| Type | Description |
|---|---|
| torch.Tensor | The computed mutual information loss. |
Raises:
| Type | Description |
|---|---|
| ValueError | If the batch size is less than 2, if the batch size is not even, or if the shapes of the input embeddings do not match. |
square_mahalanobis_distance
square_mahalanobis_distance(
x: Tensor,
y: Tensor,
stds: Tensor,
mask: Tensor,
safety_factor: float = 100.0,
) -> torch.Tensor
Calculate the square Mahalanobis Distance between the features of x and y given the diagonal tensor of standard deviations.
Notes
Adds a normalization factor of 0.5 to the square Mahalanobis distance calculation for the variance to align with the Gaussian distribution. Adds a small epsilon (in the relative sense) to the covariance matrix to avoid division by zero. The epsilon is 10 times machine epsilon times the smallest non-zero element of the covariance matrix.
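The two adjustments in the notes above, the 0.5 factor and the relative epsilon, can be sketched as follows. This is an illustration only: `half_square_mahalanobis_sketch` and its `factor` argument are hypothetical, the value 10 is taken from the note (how it relates to safety_factor is not stated here), the covariance is assumed to be diag(stds ** 2), and the mask and the final reduction are omitted.

```python
import torch


def half_square_mahalanobis_sketch(
    x: torch.Tensor,
    y: torch.Tensor,
    stds: torch.Tensor,
    factor: float = 10.0,  # assumed multiplier on machine epsilon, per the note above
) -> torch.Tensor:
    variance = stds ** 2
    nonzero = variance[variance > 0]
    # Fall back to a scale of 1 if every variance is zero, so eps stays positive.
    scale = nonzero.min() if nonzero.numel() > 0 else variance.new_tensor(1.0)
    # Relative epsilon: factor * machine epsilon * smallest non-zero variance.
    eps = factor * torch.finfo(variance.dtype).eps * scale
    # 0.5 * sum over features of (x - y)^2 / (variance + eps).
    return 0.5 * ((x - y) ** 2 / (variance + eps)).sum(dim=-1)


x = torch.tensor([[1.0, 2.0]])
y = torch.zeros(1, 2)

unit = half_square_mahalanobis_sketch(x, y, torch.ones(1, 2))
degenerate = half_square_mahalanobis_sketch(x, y, torch.tensor([[0.0, 1.0]]))
print(unit, torch.isfinite(degenerate).all())
```

Scaling the epsilon by the smallest non-zero variance keeps the guard proportionate to the data: it prevents division by zero for degenerate entries without noticeably perturbing well-conditioned ones.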
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | The first tensor. | required |
| y | Tensor | The second tensor. | required |
| stds | Tensor | The tensor of standard deviations. | required |
| mask | Tensor | The mask to apply to the reduction. | required |
| safety_factor | float | The safety factor for the division. This is calculated exactly as | 100.0 |
Returns:
| Type | Description |
|---|---|
| torch.Tensor | The computed square Mahalanobis distance between x and y. |
Warning
This API is experimental and subject to change: The entropy component of the mutual information loss is still under research.