
means

Functions:

Name                Description
means_norm_penalty  Penalize the absolute average difference between the norms of the mean-shifted input embeddings and a penalty value.

means_norm_penalty

means_norm_penalty(
    clean_input_embeddings: Tensor,
    means: Tensor,
    noise_mask: Tensor,
    penalty_value: float,
) -> torch.Tensor

Penalize the absolute average difference between the norms of the mean-shifted input embeddings and a penalty value.

For batch elements with noise_mask[i, :] == 0, the mean is set to 0.0, so these elements contribute to the loss only through the denominator.
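The computation can be sketched in plain PyTorch as follows. This is not the library's implementation. The name means_norm_penalty_sketch is hypothetical, and the sketch assumes that (a) "shifting" means adding means to the clean embeddings, (b) the norm is the L2 norm over the last (E) dimension, and (c) the reduction zeroes masked-out positions in the numerator but still divides by all B * S positions, which is one reading of "contribute to the loss only through the denominator".

```python
import torch


def means_norm_penalty_sketch(
    clean_input_embeddings: torch.Tensor,  # (B, S, E)
    means: torch.Tensor,  # (B, S, E)
    noise_mask: torch.Tensor,  # (B, S), bool or 0/1
    penalty_value: float,
) -> torch.Tensor:
    # Assumption: the shift is an addition of the means to the clean embeddings.
    shifted = clean_input_embeddings + means
    # L2 norm of every shifted embedding vector -> (B, S)
    norms = torch.linalg.vector_norm(shifted, dim=-1)
    # Absolute deviation of each norm from the target value -> (B, S)
    abs_diff = (norms - penalty_value).abs()
    # Assumption: masked-out positions add 0 to the numerator but are still
    # counted in the denominator, i.e. a plain mean over all B * S positions.
    mask = noise_mask.to(abs_diff.dtype)
    return (abs_diff * mask).sum() / mask.numel()
```

With noise_mask set to all ones, this reduces to the plain mean of |norm(x + mu) - penalty_value| over every position. If masked positions should not count at all, dividing by mask.sum() instead of mask.numel() would be the alternative design.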

Parameters:

Name                    Type    Description                                                          Default
clean_input_embeddings  Tensor  The clean input embeddings tensor, shape (B, S, E).                  required
means                   Tensor  The means tensor to shift the input embeddings by, shape (B, S, E).  required
noise_mask              Tensor  The mask to apply to the reduction, shape (B, S).                    required
penalty_value           float   The value to penalize the norms against.                             required

Returns:

Type          Description
torch.Tensor  The mean absolute difference between the norms of the mean-shifted input embeddings and the penalty value.