distillation
Functions:
Name | Description |
---|---|
distillation_loss_factory | Create a loss function to perform distillation training on a Stained Glass Transform causal language model using a teacher causal language model. |
mutual_information_distillation_loss_factory | Create a loss function to perform distillation training on a Stained Glass Transform causal language model using a teacher causal language model. |
cross_entropy_distillation_factory
cross_entropy_distillation_factory(
noisy_model: NoiseMaskedNoisyTransformerModel[
CausalModelT, ..., TransformerCloak[Any]
],
distillation_cross_entropy_loss_weight: float,
absolute_cosine_similarity_penalty_weight: float,
median_norm_penalty_weight: float,
) -> tuple[
Callable[[torch.Tensor], torch.Tensor],
Callable[[], ComponentLossesDict],
Callable[[], HyperparametersDict],
]
Create a loss function to perform distillation training on a Stained Glass Transform causal language model using a teacher causal language model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
noisy_model | NoiseMaskedNoisyTransformerModel[CausalModelT, ..., TransformerCloak[Any]] | The model containing both the causal language base model and the Stained Glass Transform. | required |
distillation_cross_entropy_loss_weight | float | The weight of the loss component which aims to minimize the difference between the clean and transformed logits. | required |
absolute_cosine_similarity_penalty_weight | float | The weight of the loss component which aims to minimize the absolute value of the similarity of the clean and transformed input embeddings. | required |
median_norm_penalty_weight | float | The weight of the loss component which aims to regularize the mean-shifted input embeddings to have norms near | required |
Returns:
Type | Description |
---|---|
tuple[Callable[[torch.Tensor], torch.Tensor], Callable[[], ComponentLossesDict], Callable[[], HyperparametersDict]] | A tuple of 3 functions: the loss function, a function to retrieve the loss components, and a function to retrieve the hyperparameters. These functions may be called at most once each after a forward pass through both models. |
Warning
This API is experimental and subject to change: The distillation loss factory API is currently under development and is subject to change.
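The shape of the returned tuple can be illustrated without the library itself. The following is a minimal sketch, not the library's implementation: `make_loss_closures` is a hypothetical stand-in for a factory, plain floats stand in for `torch.Tensor` values, and clearing the cached components on read is only an assumption about why each function may be called at most once per forward pass.

```python
from typing import Callable


def make_loss_closures(weight: float) -> tuple[
    Callable[[float], float],
    Callable[[], dict[str, float]],
    Callable[[], dict[str, float]],
]:
    """Hypothetical stand-in for a distillation loss factory (illustration only)."""
    components: dict[str, float] = {}

    def loss_fn(base_loss: float) -> float:
        # Cache the weighted component so it can be read back after the loss call.
        components["weighted_loss"] = weight * base_loss
        return components["weighted_loss"]

    def get_components() -> dict[str, float]:
        # Assumption: cached values are handed out once, then cleared until
        # the next forward pass and loss call repopulate them.
        snapshot = dict(components)
        components.clear()
        return snapshot

    def get_hyperparameters() -> dict[str, float]:
        return {"weight": weight}

    return loss_fn, get_components, get_hyperparameters


# One training step: compute the loss, then read each accessor once.
loss_fn, get_components, get_hyperparameters = make_loss_closures(weight=0.5)
loss = loss_fn(2.0)
component_losses = get_components()
hyperparameters = get_hyperparameters()
```

In a training loop, the three calls would be repeated after every forward pass, with the component losses and hyperparameters typically forwarded to a logger.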
distillation_loss_factory
distillation_loss_factory(
noisy_model: NoiseMaskedNoisyTransformerModel[
CausalModelT, ..., TransformerCloak[Any]
],
distillation_layer_index: int | None,
alpha: float,
std_log_ratio_loss_weight: float,
input_embedding_similarity_loss_weight: float,
distillation_layer_cosine_distance_loss_weight: float,
distillation_layer_l2_distance_loss_weight: float = 0.0,
inverse_l2_power_law_catalyst_loss_weight: float = 0.0,
inverse_l2_power_law_characteristic_length: float = 1.0,
inverse_l2_power_law_exponent_of_decay: float = 1.0,
) -> tuple[
Callable[[torch.Tensor], torch.Tensor],
Callable[[], ComponentLossesDict],
Callable[[], HyperparametersDict],
]
Create a loss function to perform distillation training on a Stained Glass Transform causal language model using a teacher causal language model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
noisy_model | NoiseMaskedNoisyTransformerModel[CausalModelT, ..., TransformerCloak[Any]] | The model containing both the causal language base model and the Stained Glass Transform. | required |
distillation_layer_index | int \| None | The index of the decoder layer in the base model at which to perform distillation. When | required |
alpha | float | The interpolation factor between the distillation loss (maximizing model similarity) and the Stained Glass Transform loss (maximizing transformation strength). Should be in the range [0, 1], where 0 corresponds to higher model similarity and 1 corresponds to higher transformation strength. | required |
std_log_ratio_loss_weight | float | The weight of the loss component which aims to maximize the Stained Glass Transform's standard deviations (transformation strength). | required |
input_embedding_similarity_loss_weight | float | The weight of the loss component which aims to minimize the similarity of the input embeddings. | required |
distillation_layer_cosine_distance_loss_weight | float | The weight of the loss component which aims to maximize the similarity of the distillation layer embeddings. | required |
distillation_layer_l2_distance_loss_weight | float | The weight of a subcomponent of the loss component which aims to maximize the similarity of the distillation layer embeddings. | 0.0 |
inverse_l2_power_law_catalyst_loss_weight | float | The weight of the inverse l2 power law catalyst loss. This argument is experimental. | 0.0 |
inverse_l2_power_law_characteristic_length | float | The scale applied to the distance between embeddings when computing the inverse l2 power law catalyst loss. It lets the user account for data-dependent distances between embeddings, so that the loss is O(1) when the distances are approximately equal to the scale. For embeddings, the recommended value is the 5th percentile of the Voronoi l2 distances between embeddings. See | 1.0 |
inverse_l2_power_law_exponent_of_decay | float | The value by which to inverse exponentiate the computed distance during the computation of the inverse l2 power law catalyst loss. This value must be positive. This argument is experimental. | 1.0 |
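The role of `alpha` can be made concrete with a small sketch. The reference does not state the exact formula, so the linear interpolation below is an assumption, and `interpolate_losses` is a hypothetical helper rather than part of the library. It only illustrates the documented endpoints: 0 favors model similarity and 1 favors transformation strength.

```python
def interpolate_losses(
    distillation_loss: float, transform_loss: float, alpha: float
) -> float:
    """Blend two loss terms linearly (assumed form, for illustration only)."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f"alpha must be in [0, 1], got {alpha}")
    # alpha = 0 keeps only the distillation (model similarity) term;
    # alpha = 1 keeps only the Stained Glass Transform (strength) term.
    return (1.0 - alpha) * distillation_loss + alpha * transform_loss


similarity_only = interpolate_losses(2.0, 10.0, alpha=0.0)
blended = interpolate_losses(2.0, 10.0, alpha=0.25)
strength_only = interpolate_losses(2.0, 10.0, alpha=1.0)
```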
Returns:
Type | Description |
---|---|
tuple[Callable[[torch.Tensor], torch.Tensor], Callable[[], ComponentLossesDict], Callable[[], HyperparametersDict]] | A tuple of 3 functions: the loss function, a function to retrieve the loss components, and a function to retrieve the hyperparameters. These functions may be called at most once each after a forward pass through both models. |
Raises:
Type | Description |
---|---|
ValueError | If the exponent is not positive. |
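To show how the characteristic length and the exponent of decay could interact, here is a minimal sketch of an inverse power law over an l2 distance. The exact functional form used by the library is not given in this reference, so `inverse_l2_power_law` is a hypothetical helper and the formula `(distance / length) ** (-exponent)` is an assumption. It is chosen only because it is O(1) when the distance matches the scale, as the parameter description says, and it rejects a non-positive exponent in line with the requirement that the value be positive.

```python
import math
from typing import Sequence


def inverse_l2_power_law(
    a: Sequence[float],
    b: Sequence[float],
    characteristic_length: float = 1.0,
    exponent_of_decay: float = 1.0,
) -> float:
    """Assumed form of an inverse l2 power law penalty (illustration only)."""
    if exponent_of_decay <= 0:
        raise ValueError("exponent_of_decay must be positive")
    # Euclidean (l2) distance between the two embeddings.
    distance = math.dist(a, b)
    # Rescale by the characteristic length, then apply the inverse power.
    # Identical embeddings (distance 0) would raise ZeroDivisionError here.
    return (distance / characteristic_length) ** (-exponent_of_decay)


# Distance 5.0 equals the characteristic length, so the value is exactly 1.
at_scale = inverse_l2_power_law((0.0, 0.0), (3.0, 4.0), 5.0, 2.0)
# The same embeddings measured against a smaller scale give a smaller value.
beyond_scale = inverse_l2_power_law((0.0, 0.0), (3.0, 4.0), 2.5, 2.0)
```

Under this assumed form, embeddings that sit closer together than the characteristic length are penalized more heavily, and the penalty decays as they move apart.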
mutual_information_cross_entropy_distillation_factory
mutual_information_cross_entropy_distillation_factory(
noisy_model: NoiseMaskedNoisyTransformerModel[
CausalModelT, ..., TransformerCloak[Any]
],
mutual_information_loss_weight: float,
distillation_cross_entropy_loss_weight: float,
absolute_cosine_similarity_penalty_weight: float,
median_norm_penalty_weight: float,
combine_noise_masks_for_mutual_information_sub_batches: bool = True,
) -> tuple[
Callable[[torch.Tensor], torch.Tensor],
Callable[[], ComponentLossesDict],
Callable[[], HyperparametersDict],
]
Create a loss function to perform distillation training on a Stained Glass Transform causal language model using a teacher causal language model.
Note
Using this loss requires the micro batch size to be even.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
noisy_model | NoiseMaskedNoisyTransformerModel[CausalModelT, ..., TransformerCloak[Any]] | The model containing both the causal language base model and the Stained Glass Transform. | required |
mutual_information_loss_weight | float | The weight of the mutual information loss. Increasing this makes the optimization prefer higher transformation strength. | required |
distillation_cross_entropy_loss_weight | float | The weight of the loss component which aims to minimize the difference between the clean and transformed logits. | required |
absolute_cosine_similarity_penalty_weight | float | The weight of the loss component which aims to minimize the absolute value of the similarity of the clean and transformed input embeddings. | required |
median_norm_penalty_weight | float | The weight of the loss component which aims to regularize the mean-shifted input embeddings to have norms near | required |
combine_noise_masks_for_mutual_information_sub_batches | bool | Whether to only use the features in the noise masks for the two sub-batches which are simultaneously unmasked when computing the mutual information loss. | True |
Returns:
Type | Description |
---|---|
tuple[Callable[[torch.Tensor], torch.Tensor], Callable[[], ComponentLossesDict], Callable[[], HyperparametersDict]] | A tuple of 3 functions: the loss function, a function to retrieve the loss components, and a function to retrieve the hyperparameters. These functions may be called at most once each after a forward pass through both models. |
Warning
This API is experimental and subject to change: The distillation loss factory API is currently under development and is subject to change.
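The even micro batch size requirement noted above is consistent with the batch being divided into two sub-batches of equal size. The sketch below shows one way such a check and split could look. `split_into_sub_batches` is a hypothetical helper, and splitting into a first and a second half is an assumption, since this reference does not say how the library forms its sub-batches.

```python
from typing import Sequence, TypeVar

T = TypeVar("T")


def split_into_sub_batches(batch: Sequence[T]) -> tuple[list[T], list[T]]:
    """Split a micro batch into two equal halves (assumed scheme)."""
    if len(batch) == 0 or len(batch) % 2 != 0:
        raise ValueError(
            f"micro batch size must be even and non-zero, got {len(batch)}"
        )
    half = len(batch) // 2
    return list(batch[:half]), list(batch[half:])


first_half, second_half = split_into_sub_batches(["a", "b", "c", "d"])
```

Validating the micro batch size before training starts, for example when the data loader is configured, surfaces a misconfiguration earlier than a failure inside the loss.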
mutual_information_distillation_loss_factory
mutual_information_distillation_loss_factory(
noisy_model: NoiseMaskedNoisyTransformerModel[
CausalModelT, ..., TransformerCloak[Any]
],
distillation_layer_index: int | None,
alpha: float,
distillation_layer_l2_distance_loss_weight: float,
mutual_information_loss_weight: float,
inverse_l2_power_law_catalyst_loss_weight: float = 0.0,
inverse_l2_power_law_characteristic_length: float = 1.0,
inverse_l2_power_law_exponent_of_decay: float = 1.0,
input_embedding_similarity_loss_weight: float = 0.0,
combine_noise_masks_for_mutual_information_sub_batches: bool = True,
) -> tuple[
Callable[[torch.Tensor], torch.Tensor],
Callable[[], ComponentLossesDict],
Callable[[], HyperparametersDict],
]
Create a loss function to perform distillation training on a Stained Glass Transform causal language model using a teacher causal language model.
Note: The minimum batch size for the mutual information distillation loss is 2.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
noisy_model | NoiseMaskedNoisyTransformerModel[CausalModelT, ..., TransformerCloak[Any]] | The model containing both the causal language base model and the Stained Glass Transform. | required |
distillation_layer_index | int \| None | The index of the decoder layer in the base model at which to perform distillation. When | required |
alpha | float | The interpolation factor between the distillation loss (maximizing model similarity) and the Stained Glass Transform loss (maximizing transformation strength). Should be in the range [0, 1], where 0 corresponds to higher model similarity and 1 corresponds to higher transformation strength. | required |
distillation_layer_l2_distance_loss_weight | float | The weight of a subcomponent of the loss component which aims to maximize the similarity of the distillation layer embeddings. | required |
mutual_information_loss_weight | float | The weight of the mutual information loss. Increasing this makes the optimization prefer higher transformation strength. | required |
inverse_l2_power_law_catalyst_loss_weight | float | Experimental parameter. | 0.0 |
inverse_l2_power_law_characteristic_length | float | Experimental parameter. | 1.0 |
inverse_l2_power_law_exponent_of_decay | float | Experimental parameter. | 1.0 |
input_embedding_similarity_loss_weight | float | The weight of the loss component which aims to minimize the similarity of the input embeddings. | 0.0 |
combine_noise_masks_for_mutual_information_sub_batches | bool | Whether to only use the features in the noise masks for the two sub-batches which are simultaneously unmasked when computing the mutual information loss. | True |
Returns:
Type | Description |
---|---|
tuple[Callable[[torch.Tensor], torch.Tensor], Callable[[], ComponentLossesDict], Callable[[], HyperparametersDict]] | A tuple of 3 functions: the loss function, a function to retrieve the loss components, and a function to retrieve the hyperparameters. These functions may be called at most once each after a forward pass through both models. |
Raises:
Type | Description |
---|---|
ValueError | If the exponent is not positive. |
Added in version v0.135.0 to support distillation training with mutual information loss.
Warning
This API is experimental and subject to change: The distillation loss factory API is currently under development and is subject to change.
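The effect of `combine_noise_masks_for_mutual_information_sub_batches` can be pictured as an elementwise intersection of two masks. The sketch below is an illustration under stated assumptions, not the library's code: `combine_noise_masks` is a hypothetical helper, masks are plain lists of booleans rather than tensors, and `True` is assumed to mark a position that is unmasked.

```python
def combine_noise_masks(mask_a: list[bool], mask_b: list[bool]) -> list[bool]:
    """Keep only positions unmasked in both sub-batches (assumed semantics)."""
    if len(mask_a) != len(mask_b):
        raise ValueError("noise masks must have the same length")
    # Elementwise logical AND: a position survives only if both masks allow it.
    return [a and b for a, b in zip(mask_a, mask_b)]


mask_a = [True, True, False, True]
mask_b = [True, False, False, True]
combined = combine_noise_masks(mask_a, mask_b)

# Select the features of one sub-batch at the jointly unmasked positions.
features_a = [0.1, 0.2, 0.3, 0.4]
kept_features = [f for f, keep in zip(features_a, combined) if keep]
```

With the flag disabled, each sub-batch would presumably be filtered by its own mask instead of the intersection.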