jigsaw
Functions:
Name | Description |
---|---|
`jigsaw_loss_factory` | Create the jigsaw loss function to perform jigsaw training on a Stained Glass Transform causal language model using a teacher causal language model. |
jigsaw_loss_factory
```python
jigsaw_loss_factory(
    noisy_teacher_model: NoiseMaskedNoisyTransformerModel[
        CausalModelT, ..., TransformerCloak[Any]
    ],
    noisy_student_model: PeftNoiseMaskedNoisyTransformerModel[
        CausalModelT, ..., TransformerCloak[Any]
    ],
    distillation_layer_index: int | None,
    alpha: float,
    std_log_ratio_loss_weight: float,
    input_embedding_similarity_loss_weight: float,
    distillation_layer_cosine_distance_loss_weight: float,
    distillation_layer_l2_distance_loss_weight: float = 0.0,
    distillation_protection_l2_loss_weight: float = 0.0,
    apply_distillation_protection_loss_mask: bool = True,
    clean_embedding_performance_on_student_model_loss_weight: float = 0.0,
    inverse_l2_power_law_catalyst_loss_weight: float = 0.0,
    inverse_l2_power_law_characteristic_length: float = 1.0,
    inverse_l2_power_law_exponent_of_decay: float = 1.0,
) -> tuple[
    Callable[[torch.Tensor], torch.Tensor],
    Callable[[], ComponentLossesDict],
    Callable[[], HyperparametersDict],
]
```
Create the jigsaw loss function to perform jigsaw training on a Stained Glass Transform causal language model using a teacher causal language model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`noisy_teacher_model` | `NoiseMaskedNoisyTransformerModel[CausalModelT, ..., TransformerCloak[Any]]` | The teacher model containing both the causal language base model and the Stained Glass Transform. | required |
`noisy_student_model` | `PeftNoiseMaskedNoisyTransformerModel[CausalModelT, ..., TransformerCloak[Any]]` | The student model containing both the causal language model with peft adapters and the Stained Glass Transform for jigsaw training. | required |
`distillation_layer_index` | `int \| None` | The index of the decoder layer in the base model at which to perform distillation. | required |
`alpha` | `float` | The interpolation factor between the distillation loss (maximizing model similarity) and the Stained Glass Transform loss (maximizing transformation strength). Should be in the range [0, 1], where 0 corresponds to higher model similarity and 1 corresponds to higher transformation strength (see the sketch after this table). | required |
`std_log_ratio_loss_weight` | `float` | The weight of the loss component which aims to maximize the Stained Glass Transform's standard deviations (transformation strength). | required |
`input_embedding_similarity_loss_weight` | `float` | The weight of the loss component which aims to minimize the similarity of the input embeddings. | required |
`distillation_layer_cosine_distance_loss_weight` | `float` | The weight of the loss component which aims to maximize the similarity of the distillation layer embeddings. | required |
`distillation_layer_l2_distance_loss_weight` | `float` | The weight of a subcomponent of the loss component which aims to maximize the similarity of the distillation layer embeddings. | `0.0` |
`distillation_protection_l2_loss_weight` | `float` | The weight of the loss component used to ensure that the base model does not learn to perform well on the Stained Glass Transform. | `0.0` |
`apply_distillation_protection_loss_mask` | `bool` | Whether to apply the loss mask when computing the distillation protection loss. | `True` |
`clean_embedding_performance_on_student_model_loss_weight` | `float` | The weight of the loss component used to ensure that clean embeddings don't perform well on the PEFT model. | `0.0` |
`inverse_l2_power_law_catalyst_loss_weight` | `float` | The weight of the inverse l2 power law catalyst loss. This argument is experimental. | `0.0` |
`inverse_l2_power_law_characteristic_length` | `float` | The parameter that scales the distance between embeddings when computing the inverse l2 power law catalyst loss. This allows the user to account for the data-dependent distances between embeddings, so that the loss is O(1) when the distances are approximately equal to this scale. For embeddings, this is recommended to be the 5th percentile of the Voronoi l2 distances between embeddings (see the sketch after this table). | `1.0` |
`inverse_l2_power_law_exponent_of_decay` | `float` | The value by which to inverse exponentiate the computed distance during the computation of the inverse l2 power law catalyst loss. This value must be positive. This argument is experimental. | `1.0` |
Returns:
Type | Description |
---|---|
`tuple[Callable[[torch.Tensor], torch.Tensor], Callable[[], ComponentLossesDict], Callable[[], HyperparametersDict]]` | A tuple of 3 functions: the loss function, a function to retrieve the loss components, and a function to retrieve the hyperparameters. These functions may be called at most once each after a forward pass through both models. |
Raises:
Type | Description |
---|---|
`ValueError` | If `inverse_l2_power_law_exponent_of_decay` is not positive. |
Warning
This API is experimental and subject to change: the jigsaw loss factory API is currently under development.
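Below is a minimal usage sketch. It assumes `teacher_model` and `student_model` are already-constructed `NoiseMaskedNoisyTransformerModel` and `PeftNoiseMaskedNoisyTransformerModel` instances, that `jigsaw_loss_factory` has been imported from this module, and that `batch` is a tokenized causal language modeling batch of PyTorch tensors. The tensor passed to the returned loss function and the training-loop wiring are illustrative assumptions, not part of this API reference.

```python
# Hypothetical, pre-built objects (construction is out of scope for this sketch):
#   teacher_model: NoiseMaskedNoisyTransformerModel[...]
#   student_model: PeftNoiseMaskedNoisyTransformerModel[...]
#   batch: dict of tokenized tensors, e.g. {"input_ids": ..., "attention_mask": ..., "labels": ...}

loss_fn, get_component_losses, get_hyperparameters = jigsaw_loss_factory(
    noisy_teacher_model=teacher_model,
    noisy_student_model=student_model,
    distillation_layer_index=12,  # illustrative choice of decoder layer
    alpha=0.5,
    std_log_ratio_loss_weight=1.0,
    input_embedding_similarity_loss_weight=1.0,
    distillation_layer_cosine_distance_loss_weight=1.0,
)

# Run a forward pass through both models first; the returned functions may be
# called at most once each after a forward pass through both models.
teacher_model(**batch)
student_model(**batch)

loss = loss_fn(batch["labels"])            # assumed input tensor; verify against your training setup
component_losses = get_component_losses()  # ComponentLossesDict, e.g. for logging
hyperparameters = get_hyperparameters()    # HyperparametersDict, e.g. for logging

loss.backward()
```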