distillation
Functions:
Name | Description |
---|---|
`distillation_loss_factory` | Create a loss function to perform distillation training on a Stained Glass Transform causal language model using a teacher causal language model. |
distillation_loss_factory
distillation_loss_factory(
    noisy_model: NoiseMaskedNoisyTransformerModel[
        CausalModelT, ..., TransformerCloak[Any]
    ],
    distillation_layer_index: int | None,
    alpha: float,
    std_log_ratio_loss_weight: float,
    input_embedding_similarity_loss_weight: float,
    distillation_layer_cosine_distance_loss_weight: float,
    distillation_layer_l2_distance_loss_weight: float = 0.0,
) -> tuple[
    Callable[[torch.Tensor], torch.Tensor],
    Callable[[], ComponentLossesDict],
    Callable[[], HyperparametersDict],
]
Create a loss function to perform distillation training on a Stained Glass Transform causal language model using a teacher causal language model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`noisy_model` | `NoiseMaskedNoisyTransformerModel[CausalModelT, ..., TransformerCloak[Any]]` | The model containing both the causal language base model and the Stained Glass Transform. | required |
`distillation_layer_index` | `int \| None` | The index of the decoder layer in the base model at which to perform distillation. When ... | required |
`alpha` | `float` | The interpolation factor between the distillation loss (maximizing model similarity) and the Stained Glass Transform loss (maximizing transformation strength). Should be in the range [0, 1], where 0 corresponds to higher model similarity and 1 corresponds to higher transformation strength. | required |
`std_log_ratio_loss_weight` | `float` | The weight of the loss component which aims to maximize the Stained Glass Transform's standard deviations (transformation strength). | required |
`input_embedding_similarity_loss_weight` | `float` | The weight of the loss component which aims to minimize the similarity of the input embeddings. | required |
`distillation_layer_cosine_distance_loss_weight` | `float` | The weight of the loss component which aims to maximize the similarity of the distillation layer embeddings. | required |
`distillation_layer_l2_distance_loss_weight` | `float` | The weight of a subcomponent of the loss component which aims to maximize the similarity of the distillation layer embeddings. | `0.0` |
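
As a rough illustration of how `alpha` and the two distance weights could interact, the sketch below computes a cosine distance and an L2 distance between two vectors and then blends a similarity term with a transformation-strength term by linear interpolation. The exact formula used by `distillation_loss_factory` is not stated in this reference, so the helper names (`cosine_distance`, `l2_distance`, `blend_losses`) and the `(1 - alpha) * ... + alpha * ...` form are assumptions made for illustration. Plain Python lists stand in for tensors.

```python
import math


def cosine_distance(a: list[float], b: list[float]) -> float:
    """Return 1 - cosine similarity of two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b)


def l2_distance(a: list[float], b: list[float]) -> float:
    """Return the Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def blend_losses(distillation_loss: float, transform_loss: float, alpha: float) -> float:
    """Hypothetical interpolation: alpha=0 keeps only the distillation term,
    alpha=1 keeps only the transformation-strength term."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha should be in the range [0, 1]")
    return (1.0 - alpha) * distillation_loss + alpha * transform_loss


# Toy "teacher" and "student" activations at the distillation layer (made-up values).
teacher = [1.0, 0.0]
student = [0.0, 1.0]

# Assumed weighting of the two distance terms, mirroring the parameter names above.
cosine_weight, l2_weight = 1.0, 0.0
distillation_term = (
    cosine_weight * cosine_distance(teacher, student)
    + l2_weight * l2_distance(teacher, student)
)
total = blend_losses(distillation_term, transform_loss=4.0, alpha=0.25)
```

Under these assumptions, a small `alpha` lets the similarity term dominate, which matches the description of `alpha` in the table above.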
Returns:
Type | Description |
---|---|
`tuple[Callable[[torch.Tensor], torch.Tensor], Callable[[], ComponentLossesDict], Callable[[], HyperparametersDict]]` | A tuple of 3 functions: the loss function, a function to retrieve the loss components, and a function to retrieve the hyperparameters. These functions may be called at most once each after a forward pass through both models. |
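
The three-function return shape can be pictured with a small stand-in. The sketch below is not the library's implementation: `toy_loss_factory` and its dictionary keys are hypothetical, it works on floats rather than `torch.Tensor`, and what the real loss function expects as its tensor argument is not stated here. It only shows closures sharing state so that the getters can report on the most recent loss computation.

```python
from typing import Callable


def toy_loss_factory(alpha: float) -> tuple[
    Callable[[float], float],
    Callable[[], dict[str, float]],
    Callable[[], dict[str, float]],
]:
    """Hypothetical stand-in that mimics only the shape of the returned tuple."""
    state: dict[str, float] = {}

    def loss_fn(value: float) -> float:
        # Remember the input so the component getter can report it afterwards.
        state["input"] = value
        return (1.0 - alpha) * value

    def get_components() -> dict[str, float]:
        # Return a copy of whatever the last loss_fn call recorded.
        return dict(state)

    def get_hyperparameters() -> dict[str, float]:
        return {"alpha": alpha}

    return loss_fn, get_components, get_hyperparameters


# Unpack the tuple in the documented order, then call each function once per step.
loss_fn, get_components, get_hyperparameters = toy_loss_factory(alpha=0.25)
loss = loss_fn(2.0)
components = get_components()
hyperparameters = get_hyperparameters()
```

The toy does not enforce the call-at-most-once contract described above. With the real functions, calling each one a single time per forward pass, in this order, stays within that contract.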