
distillation

Functions:

Name Description
distillation_loss_factory

Create a loss function to perform distillation training on a Stained Glass Transform causal language model using a teacher causal language model.

mutual_information_distillation_loss_factory

Create a loss function to perform distillation training on a Stained Glass Transform causal language model using a teacher causal language model.

cross_entropy_distillation_factory

cross_entropy_distillation_factory(
    noisy_model: NoiseMaskedNoisyTransformerModel[
        CausalModelT, ..., TransformerCloak[Any]
    ],
    distillation_cross_entropy_loss_weight: float,
    absolute_cosine_similarity_penalty_weight: float,
    median_norm_penalty_weight: float,
) -> tuple[
    Callable[[torch.Tensor], torch.Tensor],
    Callable[[], ComponentLossesDict],
    Callable[[], HyperparametersDict],
]

Create a loss function to perform distillation training on a Stained Glass Transform causal language model using a teacher causal language model.

Parameters:

Name Type Description Default

noisy_model

NoiseMaskedNoisyTransformerModel[CausalModelT, ..., TransformerCloak[Any]]

The model containing both the causal language base model and the Stained Glass Transform.

required

distillation_cross_entropy_loss_weight

float

The weight of the loss component which aims to minimize the difference between the clean and transformed logits.

required

absolute_cosine_similarity_penalty_weight

float

The weight of the loss component which aims to minimize the absolute value of the similarity of the clean and transformed input embeddings.

required

median_norm_penalty_weight

float

The weight of the loss component which aims to regularize the mean-shifted input embeddings to have norms near median_norm.

required
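The following stdlib-only sketch shows one way the three weights above might combine. It assumes the total loss is a plain weighted sum of the three components. It also assumes the median-norm penalty is the mean squared gap between each mean-shifted embedding's L2 norm and median_norm. The helper names (absolute_cosine_similarity, median_norm_penalty, combine_losses) are hypothetical. This is not the library's implementation.

```python
import math
from statistics import mean


def absolute_cosine_similarity(a: list[float], b: list[float]) -> float:
    """|cos(a, b)| for two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return abs(dot / (norm_a * norm_b))


def median_norm_penalty(embeddings: list[list[float]], median_norm: float) -> float:
    """Assumed form: mean squared gap between mean-shifted norms and median_norm."""
    dim = len(embeddings[0])
    # Mean-shift: measure each embedding's distance from the batch centroid.
    centroid = [mean(e[i] for e in embeddings) for i in range(dim)]
    norms = [math.dist(e, centroid) for e in embeddings]
    return mean((n - median_norm) ** 2 for n in norms)


def combine_losses(
    distillation_cross_entropy: float,
    abs_cosine_similarity: float,
    norm_penalty: float,
    distillation_cross_entropy_loss_weight: float,
    absolute_cosine_similarity_penalty_weight: float,
    median_norm_penalty_weight: float,
) -> float:
    """Assumed combination rule: weighted sum of the three components."""
    return (
        distillation_cross_entropy_loss_weight * distillation_cross_entropy
        + absolute_cosine_similarity_penalty_weight * abs_cosine_similarity
        + median_norm_penalty_weight * norm_penalty
    )


# Orthogonal clean/transformed embeddings incur no similarity penalty.
print(absolute_cosine_similarity([1.0, 0.0], [0.0, 1.0]))
print(combine_losses(2.0, 0.5, 0.25, 1.0, 0.1, 0.2))
```

An absolute value is used so that anti-aligned embeddings (cosine -1) are penalized as strongly as aligned ones.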

Returns:

Type Description
tuple[Callable[[torch.Tensor], torch.Tensor], Callable[[], ComponentLossesDict], Callable[[], HyperparametersDict]]

A tuple of 3 functions: the loss function, a function to retrieve the loss components, and a function to retrieve the hyperparameters. These functions may be called at most once each after a forward pass through both models.
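The returned triple implies a per-step calling pattern: compute the loss, then retrieve the components and hyperparameters at most once each. This pattern can be illustrated with a hypothetical stand-in factory, toy_loss_factory. It uses floats instead of tensors and plain dicts instead of ComponentLossesDict/HyperparametersDict. Raising RuntimeError on a repeated call is an assumption made for illustration. The real factories' behavior in that case is not specified here.

```python
from typing import Callable


def _take_once(cache: dict[str, dict[str, float]], key: str) -> dict[str, float]:
    """Return and remove a cached entry, failing if it was already taken."""
    if key not in cache:
        raise RuntimeError(f"{key} already retrieved, or the loss function has not run yet")
    return cache.pop(key)


def toy_loss_factory(
    weight: float,
) -> tuple[
    Callable[[float], float],
    Callable[[], dict[str, float]],
    Callable[[], dict[str, float]],
]:
    """Hypothetical stand-in returning a (loss, components, hyperparameters) triple."""
    cache: dict[str, dict[str, float]] = {}

    def loss_fn(raw_loss: float) -> float:
        loss = weight * raw_loss
        # Refresh the values that each getter may return once for this step.
        cache["components"] = {"distillation_loss": loss}
        cache["hyperparameters"] = {"weight": weight}
        return loss

    def get_components() -> dict[str, float]:
        return _take_once(cache, "components")

    def get_hyperparameters() -> dict[str, float]:
        return _take_once(cache, "hyperparameters")

    return loss_fn, get_components, get_hyperparameters


# Usage: call the loss, then each getter at most once per training step.
loss_fn, get_components, get_hyperparameters = toy_loss_factory(weight=0.5)
for raw_loss in (2.0, 4.0):
    loss = loss_fn(raw_loss)  # stands in for the loss after both forward passes
    components = get_components()
    hyperparameters = get_hyperparameters()
    print(loss, components, hyperparameters)
```

Keeping the three callables as closures over shared state lets the loss function record its intermediate components without changing its single-tensor signature.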

Warning

This API is experimental: the distillation loss factory API is under development and subject to change.

distillation_loss_factory

Create a loss function to perform distillation training on a Stained Glass Transform causal language model using a teacher causal language model.

Parameters:

Name Type Description Default

noisy_model

NoiseMaskedNoisyTransformerModel[CausalModelT, ..., TransformerCloak[Any]]

The model containing both the causal language base model and the Stained Glass Transform.

required

distillation_layer_index

int | None

The index of the decoder layer in the base model at which to perform distillation. When None, distillation is performed at the logits layer.

required

alpha

float

The interpolation factor between the distillation loss (maximizing model similarity) and the Stained Glass Transform loss (maximizing transformation strength). Should be in the range [0, 1], where 0 corresponds to higher model similarity and 1 corresponds to higher transformation strength.

required

std_log_ratio_loss_weight

float

The weight of the loss component which aims to maximize the Stained Glass Transform's standard deviations (transformation strength).

required

input_embedding_similarity_loss_weight

float

The weight of the loss component which aims to minimize the similarity of the input embeddings.

required

distillation_layer_cosine_distance_loss_weight

float

The weight of the loss component which aims to maximize the similarity of the distillation layer embeddings.

required

distillation_layer_l2_distance_loss_weight

float

The weight of the L2-distance subcomponent of the loss component which aims to maximize the similarity of the distillation layer embeddings.

0.0

inverse_l2_power_law_catalyst_loss_weight

float

The weight of the inverse L2 power-law catalyst loss. This argument is experimental.

0.0

inverse_l2_power_law_characteristic_length

float

The length scale used to normalize the distances between embeddings when computing the inverse L2 power-law catalyst loss. Because these distances are data-dependent, this lets the loss be O(1) when the distances are approximately equal to this scale. For embeddings, the recommended value is the 5th percentile of the Voronoi L2 distances between embeddings; see scripts/compute_percentile_voronoi_l2_distance.py for an example of how to compute it. This argument is experimental.

1.0

inverse_l2_power_law_exponent_of_decay

float

The decay exponent of the inverse L2 power-law catalyst loss: the computed distance is raised to the power of the negative of this value. This value must be positive. This argument is experimental.

1.0
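The parameter descriptions above suggest that each catalyst term is the normalized distance raised to a negative power. Under that reading, the term is about 1 when a distance is near the characteristic length and grows as distances shrink. The sketch below encodes that assumption. The function name, the averaging over distances, and the eps guard against zero distances are all hypothetical choices. It also raises ValueError for a non-positive exponent, matching the documented requirement.

```python
def inverse_l2_power_law_catalyst(
    distances: list[float],
    characteristic_length: float = 1.0,
    exponent_of_decay: float = 1.0,
    weight: float = 0.0,
    eps: float = 1e-12,
) -> float:
    """Assumed form: weight * mean((d / characteristic_length) ** -exponent_of_decay)."""
    if exponent_of_decay <= 0:
        raise ValueError(f"exponent_of_decay must be positive, got {exponent_of_decay}")
    # Normalizing by the characteristic length makes each term about 1 when d is near it;
    # eps avoids dividing by zero for coincident embeddings.
    terms = [
        (max(d, eps) / characteristic_length) ** -exponent_of_decay for d in distances
    ]
    return weight * sum(terms) / len(terms)


# Distances equal to the characteristic length give an O(1) loss.
print(inverse_l2_power_law_catalyst([2.0, 2.0], characteristic_length=2.0, weight=1.0))
```

With the documented default weight of 0.0, the term contributes nothing, so the catalyst is effectively opt-in.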

Returns:

Type Description
tuple[Callable[[torch.Tensor], torch.Tensor], Callable[[], ComponentLossesDict], Callable[[], HyperparametersDict]]

A tuple of 3 functions: the loss function, a function to retrieve the loss components, and a function to retrieve the hyperparameters. These functions may be called at most once each after a forward pass through both models.

Raises:

Type Description
ValueError

If inverse_l2_power_law_exponent_of_decay is not positive.
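The alpha description reads like a convex interpolation: 0 keeps only the model-similarity side and 1 keeps only the transformation-strength side. The sketch below assumes a linear interpolation. It also assumes a grouping of the weighted components. The distillation-layer cosine and L2 terms go on the distillation side. The std log-ratio and input-embedding-similarity terms go on the transform side. Both choices are assumptions, as are the helper names interpolate_losses and sketch_total_loss. Rejecting alpha outside [0, 1] is likewise an illustrative choice; the documentation only says alpha should lie in that range.

```python
def interpolate_losses(distillation_loss: float, transform_loss: float, alpha: float) -> float:
    """Convex combination: alpha=0 keeps only distillation, alpha=1 only the transform term."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f"alpha must be in [0, 1], got {alpha}")
    return (1.0 - alpha) * distillation_loss + alpha * transform_loss


def sketch_total_loss(
    cosine_distance: float,
    l2_distance: float,
    std_log_ratio_loss: float,
    input_embedding_similarity: float,
    alpha: float,
    distillation_layer_cosine_distance_loss_weight: float,
    distillation_layer_l2_distance_loss_weight: float,
    std_log_ratio_loss_weight: float,
    input_embedding_similarity_loss_weight: float,
) -> float:
    # Assumed grouping: distillation-layer distances measure model similarity...
    distillation_loss = (
        distillation_layer_cosine_distance_loss_weight * cosine_distance
        + distillation_layer_l2_distance_loss_weight * l2_distance
    )
    # ...while the std log-ratio and input-similarity terms measure transformation strength.
    transform_loss = (
        std_log_ratio_loss_weight * std_log_ratio_loss
        + input_embedding_similarity_loss_weight * input_embedding_similarity
    )
    return interpolate_losses(distillation_loss, transform_loss, alpha)


print(interpolate_losses(2.0, 4.0, alpha=0.25))
```

A convex combination keeps the overall loss scale stable as alpha moves, so sweeping alpha trades one objective for the other without rescaling the weights.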

mutual_information_cross_entropy_distillation_factory

mutual_information_cross_entropy_distillation_factory(
    noisy_model: NoiseMaskedNoisyTransformerModel[
        CausalModelT, ..., TransformerCloak[Any]
    ],
    mutual_information_loss_weight: float,
    distillation_cross_entropy_loss_weight: float,
    absolute_cosine_similarity_penalty_weight: float,
    median_norm_penalty_weight: float,
    combine_noise_masks_for_mutual_information_sub_batches: bool = True,
) -> tuple[
    Callable[[torch.Tensor], torch.Tensor],
    Callable[[], ComponentLossesDict],
    Callable[[], HyperparametersDict],
]

Create a loss function to perform distillation training on a Stained Glass Transform causal language model using a teacher causal language model.

Note

Using this loss requires the micro-batch size to be even.

Parameters:

Name Type Description Default

noisy_model

NoiseMaskedNoisyTransformerModel[CausalModelT, ..., TransformerCloak[Any]]

The model containing both the causal language base model and the Stained Glass Transform.

required

mutual_information_loss_weight

float

The weight of the mutual information loss. Increasing this makes the optimization prefer higher transformation strength.

required

distillation_cross_entropy_loss_weight

float

The weight of the loss component which aims to minimize the difference between the clean and transformed logits.

required

absolute_cosine_similarity_penalty_weight

float

The weight of the loss component which aims to minimize the absolute value of the similarity of the clean and transformed input embeddings.

required

median_norm_penalty_weight

float

The weight of the loss component which aims to regularize the mean-shifted input embeddings to have norms near median_norm.

required

combine_noise_masks_for_mutual_information_sub_batches

bool

Whether to use only the features that are unmasked in the noise masks of both sub-batches when computing the mutual information loss.

True
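The even micro-batch requirement and the combine_noise_masks_for_mutual_information_sub_batches flag both concern splitting the batch into two sub-batches. The sketch below illustrates that split. It assumes the first half of the micro-batch is paired with the second half. It also assumes a True mask entry marks a feature that participates, so combining the masks is an elementwise AND. The pairing, the mask semantics, and the helper names are hypothetical.

```python
from typing import TypeVar

T = TypeVar("T")


def split_sub_batches(batch: list[T]) -> tuple[list[T], list[T]]:
    """Split a micro-batch into two equal halves (assumed pairing: first vs second half)."""
    if len(batch) % 2 != 0:
        raise ValueError(f"micro-batch size must be even, got {len(batch)}")
    half = len(batch) // 2
    return batch[:half], batch[half:]


def combine_noise_masks(mask_a: list[list[bool]], mask_b: list[list[bool]]) -> list[list[bool]]:
    """Keep a feature only where both sub-batch masks are True."""
    return [[a and b for a, b in zip(row_a, row_b)] for row_a, row_b in zip(mask_a, mask_b)]


def sub_batch_masks(
    noise_mask: list[list[bool]], combine: bool = True
) -> tuple[list[list[bool]], list[list[bool]]]:
    """Masks for the two sub-batches, optionally restricted to their shared features."""
    first, second = split_sub_batches(noise_mask)
    if combine:
        shared = combine_noise_masks(first, second)
        return shared, shared
    return first, second


masks = [[True, True, False], [True, False, True]]
print(sub_batch_masks(masks))  # both sub-batches keep only the first feature
print(sub_batch_masks(masks, combine=False))
```

Restricting both sub-batches to a shared mask means the two sides of the mutual information estimate are computed over the same features.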

Returns:

Type Description
tuple[Callable[[torch.Tensor], torch.Tensor], Callable[[], ComponentLossesDict], Callable[[], HyperparametersDict]]

A tuple of 3 functions: the loss function, a function to retrieve the loss components, and a function to retrieve the hyperparameters. These functions may be called at most once each after a forward pass through both models.

Warning

This API is experimental: the distillation loss factory API is under development and subject to change.

mutual_information_distillation_loss_factory

Create a loss function to perform distillation training on a Stained Glass Transform causal language model using a teacher causal language model.

Note

The minimum batch size for the mutual information distillation loss is 2.

Parameters:

Name Type Description Default

noisy_model

NoiseMaskedNoisyTransformerModel[CausalModelT, ..., TransformerCloak[Any]]

The model containing both the causal language base model and the Stained Glass Transform.

required

distillation_layer_index

int | None

The index of the decoder layer in the base model at which to perform distillation. When None, distillation is performed at the logits layer.

required

alpha

float

The interpolation factor between the distillation loss (maximizing model similarity) and the Stained Glass Transform loss (maximizing transformation strength). Should be in the range [0, 1], where 0 corresponds to higher model similarity and 1 corresponds to higher transformation strength.

required

distillation_layer_l2_distance_loss_weight

float

The weight of the L2-distance subcomponent of the loss component which aims to maximize the similarity of the distillation layer embeddings.

required

mutual_information_loss_weight

float

The weight of the mutual information loss. Increasing this makes the optimization prefer higher transformation strength.

required

inverse_l2_power_law_catalyst_loss_weight

float

The weight of the inverse L2 power-law catalyst loss. This argument is experimental.

0.0

inverse_l2_power_law_characteristic_length

float

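The length scale used to normalize distances between embeddings in the inverse L2 power-law catalyst loss. This argument is experimental.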

1.0

inverse_l2_power_law_exponent_of_decay

float

The decay exponent of the inverse L2 power-law catalyst loss; it must be positive. This argument is experimental.

1.0

input_embedding_similarity_loss_weight

float

The weight of the loss component which aims to minimize the similarity of the input embeddings.

0.0

combine_noise_masks_for_mutual_information_sub_batches

bool

Whether to use only the features that are unmasked in the noise masks of both sub-batches when computing the mutual information loss.

True
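The parameter table above mixes required values, documented defaults, and a few stated constraints. These can be gathered into a hypothetical frozen dataclass, MutualInformationDistillationConfig. It omits noisy_model and validates alpha in [0, 1], a positive decay exponent, and the minimum batch size of 2. The class, its method names, and the choice to raise ValueError for alpha are illustrative assumptions. This is not part of the documented API.

```python
from dataclasses import asdict, dataclass
from typing import Optional


@dataclass(frozen=True)
class MutualInformationDistillationConfig:
    """Hypothetical container mirroring the documented parameters (noisy_model omitted)."""

    distillation_layer_index: Optional[int]
    alpha: float
    distillation_layer_l2_distance_loss_weight: float
    mutual_information_loss_weight: float
    inverse_l2_power_law_catalyst_loss_weight: float = 0.0
    inverse_l2_power_law_characteristic_length: float = 1.0
    inverse_l2_power_law_exponent_of_decay: float = 1.0
    input_embedding_similarity_loss_weight: float = 0.0
    combine_noise_masks_for_mutual_information_sub_batches: bool = True

    def __post_init__(self) -> None:
        # Constraints taken from the parameter descriptions above.
        if not 0.0 <= self.alpha <= 1.0:
            raise ValueError(f"alpha must be in [0, 1], got {self.alpha}")
        if self.inverse_l2_power_law_exponent_of_decay <= 0:
            raise ValueError("inverse_l2_power_law_exponent_of_decay must be positive")

    def check_batch_size(self, batch_size: int) -> None:
        """The mutual information distillation loss needs at least 2 examples."""
        if batch_size < 2:
            raise ValueError(f"batch size must be at least 2, got {batch_size}")

    def as_hyperparameters(self) -> dict:
        """Flat dict suitable for logging alongside the training run."""
        return asdict(self)


config = MutualInformationDistillationConfig(
    distillation_layer_index=None,  # None: distill at the logits layer
    alpha=0.3,
    distillation_layer_l2_distance_loss_weight=0.0,
    mutual_information_loss_weight=1.0,
)
config.check_batch_size(2)
print(config.as_hyperparameters())
```

Validating in __post_init__ surfaces bad settings at construction time rather than partway through training.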

Returns:

Type Description
tuple[Callable[[torch.Tensor], torch.Tensor], Callable[[], ComponentLossesDict], Callable[[], HyperparametersDict]]

A tuple of 3 functions: the loss function, a function to retrieve the loss components, and a function to retrieve the hyperparameters. These functions may be called at most once each after a forward pass through both models.

Raises:

Type Description
ValueError

If inverse_l2_power_law_exponent_of_decay is not positive.

Added in version v0.135.0 to support distillation training with mutual information loss.

Warning

This API is experimental: the distillation loss factory API is under development and subject to change.