distillation
distillation_loss_factory
    distillation_loss_factory(
        noisy_model: NoiseMaskedNoisyTransformerModel[
            CausalModelT, ..., TransformerCloak[Any]
        ],
        distillation_layer_index: int,
        alpha: float,
        std_log_ratio_loss_weight: float,
        input_embedding_similarity_loss_weight: float,
        distillation_layer_distance_loss_weight: float,
    ) -> tuple[
        Callable[[torch.Tensor], torch.Tensor],
        Callable[[], dict[str, torch.Tensor]],
        Callable[[], dict[str, Any]],
    ]
Create a loss function to perform distillation training on a Stained Glass Transform causal language model using a teacher causal language model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`noisy_model` | `NoiseMaskedNoisyTransformerModel[CausalModelT, ..., TransformerCloak[Any]]` | The model containing both the causal language base model and the Stained Glass Transform. | required |
`distillation_layer_index` | `int` | The index of the decoder layer in the base model at which to perform distillation. | required |
`alpha` | `float` | The interpolation factor between the distillation loss (maximizing model similarity) and the Stained Glass Transform loss (maximizing transformation strength). Should be in the range [0, 1], where 0 corresponds to higher model similarity and 1 corresponds to higher transformation strength. | required |
`std_log_ratio_loss_weight` | `float` | The weight of the loss component that aims to maximize the Stained Glass Transform's standard deviations (transformation strength). | required |
`input_embedding_similarity_loss_weight` | `float` | The weight of the loss component that aims to minimize the similarity of the input embeddings. | required |
`distillation_layer_distance_loss_weight` | `float` | The weight of the loss component that aims to maximize the similarity of the distillation layer embeddings. | required |
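The parameter descriptions suggest how `alpha` and the three weights might fit together. The sketch below is a hypothetical illustration, not the library's implementation: it assumes each loss component arrives already oriented so that lower is better, that the Stained Glass Transform loss is the weighted sum of the std-log-ratio and input-embedding-similarity terms, that the distillation loss is the weighted distillation-layer distance, and that `alpha` linearly interpolates between the two. The function `combine_distillation_losses` is invented for illustration and uses plain floats instead of tensors.

```python
def combine_distillation_losses(
    std_log_ratio_loss: float,
    input_embedding_similarity_loss: float,
    distillation_layer_distance_loss: float,
    alpha: float,
    std_log_ratio_loss_weight: float,
    input_embedding_similarity_loss_weight: float,
    distillation_layer_distance_loss_weight: float,
) -> float:
    """Hypothetical combination of the loss components (not the library's code).

    Assumes every component is already oriented so that lower is better.
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f"alpha must be in [0, 1], got {alpha}")
    # Transformation-strength side: emphasized as alpha approaches 1.
    transform_loss = (
        std_log_ratio_loss_weight * std_log_ratio_loss
        + input_embedding_similarity_loss_weight * input_embedding_similarity_loss
    )
    # Model-similarity side: emphasized as alpha approaches 0.
    distillation_loss = (
        distillation_layer_distance_loss_weight * distillation_layer_distance_loss
    )
    # Linear interpolation between the two objectives.
    return (1.0 - alpha) * distillation_loss + alpha * transform_loss


if __name__ == "__main__":
    total = combine_distillation_losses(
        2.0, 0.5, 4.0,
        alpha=0.25,
        std_log_ratio_loss_weight=1.0,
        input_embedding_similarity_loss_weight=2.0,
        distillation_layer_distance_loss_weight=0.5,
    )
    print(total)  # 0.75 * 2.0 + 0.25 * 3.0 = 2.25
```

Under this assumption, `alpha=0` reduces to the weighted distillation-layer distance alone and `alpha=1` to the weighted transform terms alone, matching the range described for `alpha`.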
Returns:
Type | Description |
---|---|
`tuple[Callable[[torch.Tensor], torch.Tensor], Callable[[], dict[str, torch.Tensor]], Callable[[], dict[str, Any]]]` | A tuple of 3 functions: the loss function, a function to retrieve the loss components, and a function to retrieve the hyperparameters. These functions may be called at most once each after a forward pass through both models. |
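The return value follows a closure-factory pattern: three functions that share private state. The sketch below is a hypothetical, torch-free stand-in (`make_loss_closures` and its single `weight` parameter are invented for illustration) showing one plausible way such closures could be wired together and why retrieval might be limited to once per forward pass: here `get_components` consumes the values cached by the most recent loss call. It is an assumption about the pattern, not the library's implementation.

```python
from typing import Any, Callable


def make_loss_closures(weight: float) -> tuple[
    Callable[[float], float],
    Callable[[], dict[str, float]],
    Callable[[], dict[str, Any]],
]:
    """Hypothetical factory mirroring the documented three-function return shape."""
    # Mutable state shared by all three closures.
    state: dict[str, float] = {}

    def loss_fn(component: float) -> float:
        # Cache the raw component and the weighted loss for later logging.
        loss = weight * component
        state["component"] = component
        state["loss"] = loss
        return loss

    def get_components() -> dict[str, float]:
        # Return and clear the cached values, so they are read at most once.
        if not state:
            raise RuntimeError("no loss computed since the last retrieval")
        components = dict(state)
        state.clear()
        return components

    def get_hyperparameters() -> dict[str, Any]:
        # Static configuration captured when the factory was called.
        return {"weight": weight}

    return loss_fn, get_components, get_hyperparameters


if __name__ == "__main__":
    loss_fn, get_components, get_hyperparameters = make_loss_closures(0.5)
    loss = loss_fn(4.0)  # one call per forward pass
    print(loss, get_components(), get_hyperparameters())
```

In a training loop, the caller would compute the loss once after the forward pass, backpropagate it, and then read the components and hyperparameters for logging before the next step.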