distillation
distillation_loss_factory
distillation_loss_factory(noisy_model: NoiseMaskedNoisyTransformerModel[CausalModelT, ..., TransformerCloak[Any]], distillation_layer_index: int, alpha: float, std_log_ratio_loss_weight: float, input_embedding_similarity_loss_weight: float, distillation_layer_distance_loss_weight: float) -> tuple[Callable[[torch.Tensor], torch.Tensor], Callable[[], dict[str, torch.Tensor]], Callable[[], dict[str, Any]]]
Create a loss function to perform distillation training on a Stained Glass Transform causal language model using a teacher causal language model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `noisy_model` | `NoiseMaskedNoisyTransformerModel[CausalModelT, ..., TransformerCloak[Any]]` | The model containing both the causal language base model and the Stained Glass Transform. | required |
| `distillation_layer_index` | `int` | The index of the decoder layer in the base model at which to perform distillation. | required |
| `alpha` | `float` | The interpolation factor between the distillation loss (maximizing model similarity) and the Stained Glass Transform loss (maximizing transformation strength). Should be in the range [0, 1], where 0 favors model similarity and 1 favors transformation strength. | required |
| `std_log_ratio_loss_weight` | `float` | The weight of the loss component that aims to maximize the Stained Glass Transform's standard deviations (transformation strength). | required |
| `input_embedding_similarity_loss_weight` | `float` | The weight of the loss component that aims to minimize the similarity of the input embeddings. | required |
| `distillation_layer_distance_loss_weight` | `float` | The weight of the loss component that aims to maximize the similarity of the distillation layer embeddings. | required |
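How `alpha` and the three weights combine is not spelled out above. The sketch below assumes one plausible scheme: the distillation layer distance term forms the "model similarity" side, the standard-deviation and input-embedding terms form the "transformation strength" side, and `alpha` linearly interpolates between the two. This grouping is an assumption for illustration, not the library's implementation; `LossComponents` and `combine_losses` are hypothetical names, and plain floats stand in for tensors.

```python
from dataclasses import dataclass


@dataclass
class LossComponents:
    """Hypothetical scalar loss terms; the real factory works with tensors."""

    distillation_layer_distance: float  # distance between student and teacher hidden states
    std_log_ratio: float  # term whose minimization pushes transform std devs up
    input_embedding_similarity: float  # similarity of transformed vs. original embeddings


def combine_losses(
    c: LossComponents,
    alpha: float,
    std_log_ratio_loss_weight: float,
    input_embedding_similarity_loss_weight: float,
    distillation_layer_distance_loss_weight: float,
) -> float:
    """Assumed weighting scheme, not taken from the library source."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be in [0, 1]")
    # Model-similarity side: a smaller distance at the distillation layer is better.
    distillation_loss = distillation_layer_distance_loss_weight * c.distillation_layer_distance
    # Transformation-strength side: weighted sum of the two Stained Glass Transform terms.
    transform_loss = (
        std_log_ratio_loss_weight * c.std_log_ratio
        + input_embedding_similarity_loss_weight * c.input_embedding_similarity
    )
    # alpha = 0 keeps only the distillation side; alpha = 1 keeps only the transform side.
    return (1.0 - alpha) * distillation_loss + alpha * transform_loss
```

With all weights set to 1, `alpha=0.5` simply averages the two sides, which makes it easy to see how moving `alpha` trades model similarity against transformation strength.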
Returns:
| Type | Description |
|---|---|
| `tuple[Callable[[torch.Tensor], torch.Tensor], Callable[[], dict[str, torch.Tensor]], Callable[[], dict[str, Any]]]` | A tuple of 3 functions: the loss function, a function to retrieve the loss components, and a function to retrieve the hyperparameters. These functions may be called at most once each after a forward pass through both models. |
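Because the three returned functions share state, a training step would compute the loss first and only then read the components and hyperparameters. The pure-Python sketch below mimics that contract. `make_loss_closures`, its fixed `0.5` placeholder transform term, and the "clear on read" behavior of `get_losses` are hypothetical illustrations of the documented "at most once per forward pass" rule, not the library's code; the float argument stands in for whatever tensor the real loss function receives.

```python
from typing import Any, Callable


def make_loss_closures(
    alpha: float,
) -> tuple[
    Callable[[float], float],
    Callable[[], dict[str, float]],
    Callable[[], dict[str, Any]],
]:
    """Hypothetical stand-in for the factory's (loss, components, hyperparameters) triple."""
    # State shared by the closures; in this sketch it is filled by loss_function.
    pending: dict[str, float] = {}

    def loss_function(distillation_term: float) -> float:
        transform_term = 0.5  # placeholder for the Stained Glass Transform loss
        pending["distillation_loss"] = distillation_term
        pending["transform_loss"] = transform_term
        return (1.0 - alpha) * distillation_term + alpha * transform_term

    def get_losses() -> dict[str, float]:
        # Clear the cached values so a second read without a new loss computation fails loudly.
        if not pending:
            raise RuntimeError("no loss components recorded since the last retrieval")
        components = dict(pending)
        pending.clear()
        return components

    def get_hyperparameters() -> dict[str, Any]:
        return {"alpha": alpha}

    return loss_function, get_losses, get_hyperparameters


loss_fn, get_losses, get_hparams = make_loss_closures(alpha=0.25)
step_loss = loss_fn(2.0)  # 0.75 * 2.0 + 0.25 * 0.5
logged = get_losses()  # read once per step, e.g. for logging
```

Clearing the components on read is one way to surface accidental double reads; the real functions may instead return stale values or fail differently.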