entropy
Functions:

Name | Description |
---|---|
`gaussian_mixture_entropy` | Compute the entropy of a Gaussian mixture on a mini-batch via Monte Carlo integration. |
`mutual_information_sub_batch` | Approximate the mutual information via Monte Carlo with sub-batch sampling. |
`square_mahalanobis_distance` | Calculate the squared Mahalanobis distance between the features of `x` and `y` given the diagonal tensor of standard deviations. |
gaussian_mixture_entropy ¶

```python
gaussian_mixture_entropy(
    clean_embeddings: Tensor,
    transformed_embeddings: Tensor,
    means: Tensor,
    stds: Tensor,
    mask: Tensor,
    losses: ComponentLossesDict | None = None,
) -> torch.Tensor
```
Compute the entropy of a Gaussian mixture on a mini-batch via Monte Carlo integration.

If `losses` is not None, the following losses are added:

- `gaussian_mixture_mahalanobis_distance_loss`
- `gaussian_mixture_log_constants_and_determinant_loss`

Note: The transformed embeddings should come from a different batch than the clean embeddings.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
clean_embeddings | Tensor | The untransformed embeddings. | required |
transformed_embeddings | Tensor | The Stained Glass transformed embeddings. | required |
means | Tensor | The means tensor. | required |
stds | Tensor | The standard deviations tensor. | required |
mask | Tensor | The mask to apply to the reduction. | required |
losses | ComponentLossesDict \| None | The dictionary of component losses. If provided, the loss will be added to the losses dictionary. | None |
Returns:

Type | Description |
---|---|
torch.Tensor | The computed Gaussian mixture entropy. |
Warning
This API is experimental and subject to change: The entropy component of the mutual information loss is still under research.
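A minimal usage sketch, assuming the import path matches this module and a boolean (B, S) mask; the shapes, values, and the plain-dict stand-in for `ComponentLossesDict` are placeholders, not the library's requirements:

```python
import torch

from entropy import gaussian_mixture_entropy  # import path assumed from this page

B, S, E = 8, 16, 32  # assumed batch, sequence, and embedding dimensions

clean_embeddings = torch.randn(B, S, E)
# Per the note above, the transformed embeddings should come from a
# different batch than the clean embeddings.
transformed_embeddings = torch.randn(B, S, E)
means = torch.zeros(B, S, E)
stds = torch.ones(B, S, E)
mask = torch.ones(B, S, dtype=torch.bool)  # mask shape/dtype is an assumption

losses = {}  # plain dict standing in for a ComponentLossesDict
entropy_value = gaussian_mixture_entropy(
    clean_embeddings, transformed_embeddings, means, stds, mask, losses=losses
)
```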
mutual_information_sub_batch ¶

```python
mutual_information_sub_batch(
    means_tensor: Tensor,
    stds_tensor: Tensor,
    noise_mask: Tensor,
    clean_input_embeddings: Tensor,
    noisy_input_embeddings: Tensor,
    combine_noise_masks_for_mutual_information_sub_batches: bool = True,
    losses: ComponentLossesDict | None = None,
) -> torch.Tensor
```
Approximate the mutual information via Monte Carlo with sub-batch sampling.

Assumes that the `noisy_input_embeddings` are generated by adding a Gaussian to the `clean_input_embeddings`. The loss is computed as follows:

$$ L(X_b, Y_b; X_b') = \frac{1}{B}\sum_{\ell', i} \log(|\Sigma_i^{-1}\Sigma_{\ell'}|) + \| y_i - x_{\ell'} - \mu_{\ell'} \|_{\Sigma_{\ell'}^{-1}}^2 $$

where $X_b$ and $X_b'$ are sub-batches of `clean_input_embeddings` and $Y_b$ is the sub-batch of `noisy_input_embeddings` corresponding to $X_b$. The $\mu$ and $\Sigma$ are the corresponding sub-batches from `means_tensor` and `stds_tensor`, respectively.

If `losses` is not None, the following losses are added:

- `gaussian_mixture_mahalanobis_distance_loss`
- `gaussian_mixture_log_constants_and_determinant_loss`
- `gaussian_mixture_entropy_loss`
- `mutual_information_loss`
- `std_log_loss`
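Since the standard deviations are supplied as a tensor (diagonal covariances), both terms of the summand have simple closed forms. A minimal illustrative sketch for a single pair $(i, \ell')$, with hypothetical names and shapes (this is not the library's implementation):

```python
import torch

def per_pair_loss_terms(
    y_i: torch.Tensor,     # noisy embedding from sub-batch Y_b, shape (E,)
    x_lp: torch.Tensor,    # clean embedding from sub-batch X_b', shape (E,)
    mu_lp: torch.Tensor,   # corresponding mean from means_tensor, shape (E,)
    std_i: torch.Tensor,   # standard deviations for component i, shape (E,)
    std_lp: torch.Tensor,  # standard deviations for component l', shape (E,)
) -> torch.Tensor:
    # For diagonal covariances, log|Sigma_i^{-1} Sigma_{l'}| reduces to a
    # sum of log variance ratios.
    log_det_ratio = 2.0 * (std_lp.log() - std_i.log()).sum()
    # Squared Mahalanobis norm of (y_i - x_{l'} - mu_{l'}) under Sigma_{l'}^{-1}.
    mahalanobis = (((y_i - x_lp - mu_lp) / std_lp) ** 2).sum()
    return log_det_ratio + mahalanobis
```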
Parameters:

Name | Type | Description | Default |
---|---|---|---|
means_tensor | Tensor | The means tensor for the Gaussian mixture. (B, S, E) | required |
stds_tensor | Tensor | The standard deviations tensor for the Gaussian mixture. (B, S, E) | required |
noise_mask | Tensor | The mask to apply to the reduction. (B, S) | required |
clean_input_embeddings | Tensor | The untransformed embeddings. (B, S, E) | required |
noisy_input_embeddings | Tensor | The Stained Glass transformed embeddings. (B, S, E) | required |
combine_noise_masks_for_mutual_information_sub_batches | bool | Whether to use the intersection of the noise masks from both sub-batches. | True |
losses | ComponentLossesDict \| None | The dictionary of component losses. If provided, the loss will be added to the losses dictionary. | None |
Returns:

Type | Description |
---|---|
torch.Tensor | The computed mutual information loss. |
Raises:

Type | Description |
---|---|
ValueError | If the batch size is less than 2, if the batch size is not even, or if the shapes of the input embeddings do not match. |
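A minimal usage sketch under the same caveats (import path, mask dtype, and placeholder shapes are assumptions). The batch size must be even and at least 2, per the Raises table above:

```python
import torch

from entropy import mutual_information_sub_batch  # import path assumed

B, S, E = 8, 16, 32  # B must be even and >= 2

means_tensor = torch.zeros(B, S, E)
stds_tensor = torch.ones(B, S, E)
noise_mask = torch.ones(B, S, dtype=torch.bool)  # dtype is an assumption

clean_input_embeddings = torch.randn(B, S, E)
# Per the docstring, the noisy embeddings are the clean embeddings plus
# Gaussian noise with the given means and standard deviations.
noisy_input_embeddings = (
    clean_input_embeddings
    + means_tensor
    + stds_tensor * torch.randn(B, S, E)
)

mi_loss = mutual_information_sub_batch(
    means_tensor,
    stds_tensor,
    noise_mask,
    clean_input_embeddings,
    noisy_input_embeddings,
)
```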
square_mahalanobis_distance ¶

Calculate the squared Mahalanobis distance between the features of `x` and `y` given the diagonal tensor of standard deviations.
Notes

Adds a normalization factor of 0.5 to the squared Mahalanobis distance calculation so that the variance term aligns with the Gaussian distribution.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
x | Tensor | The first tensor. | required |
y | Tensor | The second tensor. | required |
stds | Tensor | The tensor of standard deviations. | required |
mask | Tensor | The mask to apply to the reduction. | required |
Returns:

Type | Description |
---|---|
torch.Tensor | The computed squared Mahalanobis distance between x and y. |
Warning
This API is experimental and subject to change: The entropy component of the mutual information loss is still under research.
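For reference, a minimal sketch of the computation described above: a masked, diagonal-covariance squared Mahalanobis distance with the 0.5 normalization factor from the Notes. The parameter names and the reduction (mean over unmasked positions) are assumptions, not the library's implementation:

```python
import torch

def square_mahalanobis_distance_sketch(
    x: torch.Tensor,     # (B, S, E)
    y: torch.Tensor,     # (B, S, E)
    stds: torch.Tensor,  # (B, S, E) diagonal standard deviations
    mask: torch.Tensor,  # (B, S) boolean mask for the reduction
) -> torch.Tensor:
    # 0.5 * sum_e ((x - y)^2 / stds^2); the 0.5 factor aligns the distance
    # with the exponent of the Gaussian density, per the Notes above.
    sq = 0.5 * (((x - y) / stds) ** 2).sum(dim=-1)  # (B, S)
    # Assumed reduction: average over unmasked positions.
    return (sq * mask).sum() / mask.sum()
```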