distillation
Functions:
Name | Description |
---|---|
`distillation_loss_factory` | Create a loss function to perform distillation training on a Stained Glass Transform causal language model using a teacher causal language model. |
distillation_loss_factory
```python
distillation_loss_factory(
    noisy_model: NoiseMaskedNoisyTransformerModel[
        CausalModelT, ..., TransformerCloak[Any]
    ],
    distillation_layer_index: int | None,
    alpha: float,
    std_log_ratio_loss_weight: float,
    input_embedding_similarity_loss_weight: float,
    distillation_layer_cosine_distance_loss_weight: float,
    distillation_layer_l2_distance_loss_weight: float = 0.0,
    inverse_l2_power_law_catalyst_loss_weight: float = 0.0,
    inverse_l2_power_law_characteristic_length: float = 1.0,
    inverse_l2_power_law_exponent_of_decay: float = 1.0,
) -> tuple[
    Callable[[torch.Tensor], torch.Tensor],
    Callable[[], ComponentLossesDict],
    Callable[[], HyperparametersDict],
]
```
Create a loss function to perform distillation training on a Stained Glass Transform causal language model using a teacher causal language model.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`noisy_model` | `NoiseMaskedNoisyTransformerModel[CausalModelT, ..., TransformerCloak[Any]]` | The model containing both the causal language base model and the Stained Glass Transform. | *required* |
`distillation_layer_index` | `int \| None` | The index of the decoder layer in the base model at which to perform distillation. | *required* |
`alpha` | `float` | The interpolation factor between the distillation loss (maximizing model similarity) and the Stained Glass Transform loss (maximizing transformation strength). Should be in the range [0, 1], where 0 corresponds to higher model similarity and 1 corresponds to higher transformation strength. | *required* |
`std_log_ratio_loss_weight` | `float` | The weight of the loss component which aims to maximize the Stained Glass Transform's standard deviations (transformation strength). | *required* |
`input_embedding_similarity_loss_weight` | `float` | The weight of the loss component which aims to minimize the similarity of the input embeddings. | *required* |
`distillation_layer_cosine_distance_loss_weight` | `float` | The weight of the loss component which aims to maximize the similarity of the distillation layer embeddings. | *required* |
`distillation_layer_l2_distance_loss_weight` | `float` | The weight of a subcomponent of the loss component which aims to maximize the similarity of the distillation layer embeddings. | `0.0` |
`inverse_l2_power_law_catalyst_loss_weight` | `float` | The weight of the inverse l2 power law catalyst loss. This argument is experimental. | `0.0` |
`inverse_l2_power_law_characteristic_length` | `float` | The parameter scaling the distance between embeddings when computing the inverse l2 power law catalyst loss. This allows the user to account for the data-dependent distances between embeddings, so that the loss is O(1) when the distances are approximately the same value as the scale. For embeddings, this is recommended to be the 5th percentile of the Voronoi l2 distances between embeddings. This argument is experimental. | `1.0` |
`inverse_l2_power_law_exponent_of_decay` | `float` | The value to inverse exponentiate the computed distance by during the computation of the inverse l2 power law catalyst loss. This value must be positive. This argument is experimental. | `1.0` |
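How `alpha` interpolates between the two loss regimes, and how the characteristic length and decay exponent might shape the inverse l2 power law term, can be illustrated with a simplified sketch. This is plain Python with hypothetical formulas chosen to match the descriptions above; the actual computation inside `distillation_loss_factory` may differ:

```python
def interpolated_loss(
    distillation_loss: float,
    transform_loss: float,
    alpha: float,
) -> float:
    """Hypothetical alpha interpolation: 0 favors model similarity
    (pure distillation loss), 1 favors transformation strength."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be in the range [0, 1]")
    return (1.0 - alpha) * distillation_loss + alpha * transform_loss


def inverse_l2_power_law(
    l2_distance: float,
    characteristic_length: float = 1.0,
    exponent_of_decay: float = 1.0,
) -> float:
    """Hypothetical inverse power law: roughly O(1) when the distance is
    near the characteristic length, decaying as the distance grows."""
    if exponent_of_decay <= 0.0:
        raise ValueError("exponent_of_decay must be positive")
    return (l2_distance / characteristic_length) ** (-exponent_of_decay)
```

For example, with `alpha = 0.0` the interpolated loss reduces to the distillation loss alone, and a distance equal to the characteristic length yields a power-law value of exactly 1.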
Returns:

Type | Description |
---|---|
`tuple[Callable[[torch.Tensor], torch.Tensor], Callable[[], ComponentLossesDict], Callable[[], HyperparametersDict]]` | A tuple of 3 functions: the loss function, a function to retrieve the loss components, and a function to retrieve the hyperparameters. These functions may be called at most once each after a forward pass through both models. |
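The return contract (a loss callable plus two getters that report on the most recent loss computation) can be mimicked with a self-contained stand-in. The names and caching behavior here are illustrative only, not the library's implementation:

```python
from typing import Callable


def mock_loss_factory(
    weight: float,
) -> tuple[Callable[[float], float], Callable[[], dict], Callable[[], dict]]:
    """Toy stand-in mirroring the factory's return shape: a loss
    function, a component-losses getter, and a hyperparameters getter."""
    components: dict[str, float] = {}

    def loss(prediction_error: float) -> float:
        # Record the weighted component so the getter can report it.
        components["weighted_error"] = weight * prediction_error
        return components["weighted_error"]

    def get_component_losses() -> dict:
        return dict(components)

    def get_hyperparameters() -> dict:
        return {"weight": weight}

    return loss, get_component_losses, get_hyperparameters
```

As with the real factory, the getters only report meaningful values after the loss function has been evaluated (analogous to calling them after a forward pass through both models).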
Raises:

Type | Description |
---|---|
`ValueError` | If `inverse_l2_power_law_exponent_of_decay` is non-positive. |
Warning
This API is experimental and subject to change: the distillation loss factory API is currently under development.