means
Functions:
Name | Description |
---|---|
`means_norm_penalty` | Penalize the absolute average difference between the norms of the mean-shifted input embeddings and a penalty value. |
means_norm_penalty
means_norm_penalty(
clean_input_embeddings: Tensor,
means: Tensor,
noise_mask: Tensor,
penalty_value: float,
) -> torch.Tensor
Penalize the absolute average difference between the norms of the mean-shifted input embeddings and a penalty value.
For batch elements with `noise_mask[i, :] == 0`, the mean is set to 0.0, so these elements contribute to the loss only through the denominator of the average.
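Read literally, the description above corresponds to a masked average of per-position deviations. The formula below is a sketch that rests on assumptions the docstring does not state: the shift is additive ($x + \mu$), norms are Euclidean over the embedding axis $E$, and the average runs over all $B \cdot S$ positions, so masked-out positions enlarge only the denominator. Here $x$ is `clean_input_embeddings`, $\mu$ is `means`, $m$ is `noise_mask`, and $p$ is `penalty_value`.

$$
\mathcal{L} = \frac{1}{B\,S} \sum_{b=1}^{B} \sum_{s=1}^{S} m_{b,s}\,\Bigl|\,\lVert x_{b,s} + \mu_{b,s} \rVert_2 - p\,\Bigr|
$$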
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`clean_input_embeddings` | `Tensor` | The clean input embeddings tensor, of shape (B, S, E). | required |
`means` | `Tensor` | The means tensor to shift the input embeddings by, of shape (B, S, E). | required |
`noise_mask` | `Tensor` | The mask to apply in the reduction, of shape (B, S). | required |
`penalty_value` | `float` | The value to penalize the norms against. | required |
Returns:

Type | Description |
---|---|
`torch.Tensor` | The mean absolute difference between the norms of the mean-shifted input embeddings and the penalty value. |
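The following is a minimal PyTorch sketch of the documented behaviour, offered as an illustration rather than the library's implementation. It assumes an additive shift (`clean_input_embeddings + means`), an L2 norm over the last axis, and a denominator of B·S; the function name `means_norm_penalty_sketch` is hypothetical. Masked-out positions are handled by zeroing their per-position deviation, which gives them the "denominator only" role described above.

```python
import torch


def means_norm_penalty_sketch(
    clean_input_embeddings: torch.Tensor,  # (B, S, E)
    means: torch.Tensor,                   # (B, S, E)
    noise_mask: torch.Tensor,              # (B, S), 1 = keep, 0 = masked out
    penalty_value: float,
) -> torch.Tensor:
    """Hypothetical sketch; the shift direction and denominator are assumptions."""
    mask = noise_mask.to(clean_input_embeddings.dtype)
    # Assumption: "mean-shifted" means adding the means to the embeddings.
    shifted = clean_input_embeddings + means
    # Euclidean norm over the embedding axis -> shape (B, S).
    norms = torch.linalg.vector_norm(shifted, dim=-1)
    # Absolute deviation from the target norm, zeroed where the mask is 0.
    deviation = (norms - penalty_value).abs() * mask
    # Assumption: divide by all B * S positions, so masked-out positions
    # add nothing to the sum but still count in the denominator.
    return deviation.sum() / deviation.numel()


clean = torch.tensor([[[3.0, 0.0], [0.0, 0.0]]])  # B=1, S=2, E=2
means = torch.tensor([[[0.0, 4.0], [1.0, 0.0]]])
noise_mask = torch.tensor([[1, 0]])  # second position is masked out
loss = means_norm_penalty_sketch(clean, means, noise_mask, penalty_value=4.0)
print(loss.item())  # 0.5: |5 - 4| at the first position, divided by B*S = 2
```

If the real reduction instead divides by the number of unmasked positions, the same example would return 1.0, so the choice of denominator matters when comparing against the actual function.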