distillation

distillation_loss_factory

distillation_loss_factory(noisy_model: NoiseMaskedNoisyTransformerModel[CausalModelT, ..., TransformerCloak[Any]], distillation_layer_index: int, alpha: float, std_log_ratio_loss_weight: float, input_embedding_similarity_loss_weight: float, distillation_layer_distance_loss_weight: float) -> tuple[Callable[[torch.Tensor], torch.Tensor], Callable[[], dict[str, torch.Tensor]], Callable[[], dict[str, Any]]]

Create a loss function to perform distillation training on a Stained Glass Transform causal language model using a teacher causal language model.

Parameters:

Name Type Description Default
noisy_model NoiseMaskedNoisyTransformerModel[CausalModelT, ..., TransformerCloak[Any]]

The model containing both the causal language base model and the Stained Glass Transform.

required
distillation_layer_index int

The index of the decoder layer in the base model at which to perform distillation.

required
alpha float

The interpolation factor between the distillation loss (maximizing model similarity) and the Stained Glass Transform loss (maximizing transformation strength). Should be in the range [0, 1], where 0 corresponds to higher model similarity and 1 corresponds to higher transformation strength.

required
std_log_ratio_loss_weight float

The weight of the loss component which aims to maximize the Stained Glass Transform's standard deviations (transformation strength).

required
input_embedding_similarity_loss_weight float

The weight of the loss component which aims to minimize the similarity of the input embeddings.

required
distillation_layer_distance_loss_weight float

The weight of the loss component which aims to maximize the similarity of the distillation layer embeddings.

required
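
As a rough illustration of how alpha and the three weights could interact, the sketch below linearly interpolates between a weighted distillation term and a weighted transformation term. The function name combine_losses, the grouping of the terms, and the linear formula are all assumptions made for illustration; this page does not state the library's actual formula, and plain floats stand in for tensors.

```python
def combine_losses(
    std_log_ratio_loss: float,
    input_embedding_similarity_loss: float,
    distillation_layer_distance_loss: float,
    *,
    alpha: float,
    std_log_ratio_loss_weight: float,
    input_embedding_similarity_loss_weight: float,
    distillation_layer_distance_loss_weight: float,
) -> float:
    """Hypothetical combination of the loss terms; NOT the library's formula."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be in the range [0, 1]")

    # Assumed grouping: both of these terms push toward a stronger transform.
    transform_loss = (
        std_log_ratio_loss_weight * std_log_ratio_loss
        + input_embedding_similarity_loss_weight * input_embedding_similarity_loss
    )

    # Assumed grouping: this term pulls the distillation-layer embeddings
    # of the transformed model toward those of the teacher.
    distillation_loss = (
        distillation_layer_distance_loss_weight * distillation_layer_distance_loss
    )

    # alpha = 0 keeps only the similarity term; alpha = 1 keeps only the
    # transformation-strength terms, matching the parameter description.
    return (1.0 - alpha) * distillation_loss + alpha * transform_loss
```

Under this assumed formula, alpha=0 returns only the weighted distillation term and alpha=1 returns only the weighted transformation terms, which is the behavior the alpha description implies at its two endpoints.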

Returns:

Type Description
tuple[Callable[[torch.Tensor], torch.Tensor], Callable[[], dict[str, torch.Tensor]], Callable[[], dict[str, Any]]]

A tuple of 3 functions: the loss function, a function to retrieve the loss components, and a function to retrieve the hyperparameters. These functions may be called at most once each after a forward pass through both models.
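
The calling pattern for the returned tuple might look like the sketch below. It uses a hypothetical stand-in, toy_loss_factory, that mimics only the shape of the three returned functions. It uses plain floats instead of torch.Tensor, and its toy arithmetic and the "toy_component" key are invented for illustration. The real factory needs a noisy model and a forward pass through both models before each round of calls.

```python
from typing import Any, Callable

LossFn = Callable[[float], float]
ComponentsFn = Callable[[], dict[str, float]]
HyperparamsFn = Callable[[], dict[str, Any]]


def toy_loss_factory(alpha: float) -> tuple[LossFn, ComponentsFn, HyperparamsFn]:
    """Hypothetical stand-in: mimics the tuple shape, not the real computation."""
    pending: dict[str, float] = {}

    def loss_fn(base_loss: float) -> float:
        # Toy rule (an assumption): scale the input and record it as a component.
        value = (1.0 - alpha) * base_loss
        pending["toy_component"] = value
        return value

    def get_components() -> dict[str, float]:
        # Hand over the recorded components and clear them, so that each
        # training step reads them at most once.
        components = dict(pending)
        pending.clear()
        return components

    def get_hyperparameters() -> dict[str, Any]:
        return {"alpha": alpha}

    return loss_fn, get_components, get_hyperparameters


loss_fn, get_components, get_hyperparameters = toy_loss_factory(alpha=0.25)

history = []
for step, base_loss in enumerate([2.0, 1.0]):
    # With the real factory, a forward pass through both models goes here.
    loss = loss_fn(base_loss)                # call once per step
    components = get_components()            # call once per step
    hyperparameters = get_hyperparameters()  # call once per step
    history.append((step, loss, components, hyperparameters))
```

Unpacking the tuple once and then calling each function a single time per training step keeps usage within the "at most once each after a forward pass" constraint described above.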