jigsaw
Functions:
| Name | Description | 
|---|---|
| compute_jigsaw_ce_losses | Compute the cross-entropy (CE) losses for Jigsaw training, including performance, protection, and guardrails losses. | 
| compute_jigsaw_l2_losses | Compute the L2 losses for Jigsaw training at the distillation layer. | 
| compute_masked_cross_entropy | Calculate the cross-entropy loss between the output logits and the target labels. | 
| compute_masked_normalized_l2_loss | Calculate the l2 loss at the distillation layer. | 
| jigsaw_loss_factory | Create the jigsaw loss function to perform jigsaw training on a Stained Glass Transform causal language model using a teacher causal language model. | 
| jigsaw_sgt_loss_factory | Create a loss function to perform distillation training on a Stained Glass Transform causal language model using a teacher causal language model. | 
compute_jigsaw_ce_losses(
    activations: ActivationsDict,
    batch_size: int,
    sequence_length: int,
    vocab_size: int,
    attention_mask: Tensor,
    loss_mask: Tensor,
    losses: ComponentLossesDict,
    jigsaw_ce_protection_max_loss: float | None,
    jigsaw_ce_protection_loss_scaling_factor: float,
) -> ComponentLossesDict
Compute the cross-entropy (CE) losses for Jigsaw training, including performance, protection, and guardrails losses.
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
| activations | ActivationsDict | Dictionary containing model activations, including logits. | required | 
| batch_size | int | The batch size. | required | 
| sequence_length | int | The sequence length for a given batch. | required | 
| vocab_size | int | The size of the vocabulary. | required | 
| attention_mask | Tensor | The attention mask to apply to the loss. | required | 
| loss_mask | Tensor | The mask to apply to the protection and guardrails losses. | required | 
| losses | ComponentLossesDict | Dictionary to store computed loss components. | required | 
| jigsaw_ce_protection_max_loss | float \| None | Maximum value to cap the protection cross-entropy loss. | required | 
| jigsaw_ce_protection_loss_scaling_factor | float | Scaling factor for the sigmoid cross-entropy protection loss. | required | 
Returns:
| Name | Type | Description | 
|---|---|---|
| dict | ComponentLossesDict | The updated losses dictionary with computed CE loss components. | 
compute_jigsaw_l2_losses(
    activations: ActivationsDict,
    super_batch_size: int,
    loss_mask: Tensor,
    losses: ComponentLossesDict,
) -> ComponentLossesDict
Compute the L2 losses for Jigsaw training at the distillation layer.
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
| activations | ActivationsDict | Dictionary containing model activations, including distillation layer embeddings. | required | 
| super_batch_size | int | The total batch size (typically twice the actual batch size due to clean/noisy split). | required | 
| loss_mask | Tensor | The mask to apply to the loss. | required | 
| losses | ComponentLossesDict | Dictionary to store computed loss components. | required | 
Returns:
| Name | Type | Description | 
|---|---|---|
| dict | ComponentLossesDict | The updated losses dictionary with computed L2 loss components. | 
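The note on `super_batch_size` can be illustrated with a small sketch. It assumes, purely for illustration, that the clean rows occupy the first half of the super batch and the noisy rows the second half; the actual ordering is not stated here, and `split_super_batch` is a hypothetical helper, not part of this module.

```python
def split_super_batch(rows, super_batch_size):
    """Split a super batch into its (clean, noisy) halves.

    ASSUMPTION: clean rows come first and noisy rows second. This ordering is
    chosen for illustration only and is not confirmed by the documentation.
    """
    if super_batch_size % 2 != 0:
        raise ValueError("super_batch_size must be even to split into clean/noisy halves")
    half = super_batch_size // 2  # the "actual" batch size
    return rows[:half], rows[half:super_batch_size]


clean_rows, noisy_rows = split_super_batch(["c0", "c1", "n0", "n1"], super_batch_size=4)
```

With an actual batch size of 2, the super batch holds 4 rows, and each half is compared position by position when the L2 components are computed.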
compute_masked_cross_entropy(
    output_logits: Tensor,
    target_labels: Tensor,
    batch_size: int,
    sequence_length: int,
    mask: Tensor,
    max_loss: float | None = None,
    ce_scaling_factor: float = 1.0,
) -> torch.Tensor
Calculate the cross-entropy loss between the output logits and the target labels.
The cross-entropy loss can produce aggressive (large) gradients. Supplying a max_loss value feeds
the cross-entropy to a sigmoid that is scaled by max_loss.
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
| output_logits | Tensor | The output logits from the model. | required | 
| target_labels | Tensor | The target labels. | required | 
| batch_size | int | The batch size. | required | 
| sequence_length | int | The sequence length for a given batch. | required | 
| mask | Tensor | The mask to apply to the loss. | required | 
| max_loss | float \| None | The maximum value for the loss to cap the cross entropy loss. | None | 
| ce_scaling_factor | float | A scaling factor for the sigmoid cross-entropy loss. Defaults to 1.0. | 1.0 | 
Returns:
| Type | Description | 
|---|---|
| torch.Tensor | The cross entropy loss. | 
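The masking and capping described above can be sketched in plain Python. This is a minimal illustration, not the library's implementation: `masked_cross_entropy` is a hypothetical name, it works on nested lists rather than tensors, and the capping formula `max_loss * sigmoid(ce_scaling_factor * loss)` is an assumption about how the sigmoid is applied.

```python
import math


def masked_cross_entropy(logits, labels, mask, max_loss=None, ce_scaling_factor=1.0):
    """Mean cross-entropy over unmasked positions, optionally squashed by a sigmoid.

    logits: one list of class logits per position.
    labels: one integer class id per position.
    mask:   one 0/1 flag per position; only flagged positions contribute.
    """
    total, count = 0.0, 0
    for position_logits, label, keep in zip(logits, labels, mask):
        if not keep:
            continue
        # Numerically stable log-sum-exp for the softmax normaliser.
        peak = max(position_logits)
        log_norm = peak + math.log(sum(math.exp(x - peak) for x in position_logits))
        total += log_norm - position_logits[label]
        count += 1
    loss = total / max(count, 1)
    if max_loss is None:
        return loss
    # ASSUMPTION: cap as max_loss * sigmoid(ce_scaling_factor * loss).
    return max_loss / (1.0 + math.exp(-ce_scaling_factor * loss))


uncapped = masked_cross_entropy([[0.0, 0.0], [5.0, 0.0]], [0, 1], [1, 0])
capped = masked_cross_entropy([[3.0, 0.0]], [1], [1], max_loss=10.0)
```

Under this assumed form the capped value never exceeds `max_loss`, and its gradient shrinks as the raw cross-entropy grows, which is the motivation given above for supplying `max_loss`.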
compute_masked_normalized_l2_loss(
    embeddings: Tensor,
    target_embeddings: Tensor,
    mask: Tensor,
) -> torch.Tensor
Calculate the l2 loss at the distillation layer.
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
| embeddings | Tensor | The activation embeddings output by the student model at the distillation layer. | required | 
| target_embeddings | Tensor | The target activation embeddings from the teacher model at the distillation layer. | required | 
| mask | Tensor | The mask to apply to the loss. | required | 
Returns:
| Type | Description | 
|---|---|
| torch.Tensor | The l2 loss. | 
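A masked, normalized L2 loss can be sketched as follows. The normalization used by the library is not specified here, so this sketch assumes each position's student-teacher distance is divided by the norm of the teacher embedding; `masked_normalized_l2` and its `eps` argument are hypothetical.

```python
import math


def masked_normalized_l2(embeddings, target_embeddings, mask, eps=1e-8):
    """Mean over unmasked positions of ||student - teacher|| / ||teacher||.

    ASSUMPTION: normalizing by the teacher norm is an illustrative choice,
    not necessarily what compute_masked_normalized_l2_loss does.
    """
    total, count = 0.0, 0
    for student, teacher, keep in zip(embeddings, target_embeddings, mask):
        if not keep:
            continue
        distance = math.sqrt(sum((s - t) ** 2 for s, t in zip(student, teacher)))
        scale = math.sqrt(sum(t * t for t in teacher))
        total += distance / (scale + eps)  # eps guards against a zero teacher norm
        count += 1
    return total / max(count, 1)


identical = masked_normalized_l2([[3.0, 4.0]], [[3.0, 4.0]], [1])
far_apart = masked_normalized_l2([[0.0, 0.0], [9.0, 9.0]], [[3.0, 4.0], [1.0, 1.0]], [1, 0])
```

Dividing by the teacher norm makes the loss scale-free, so layers with large activations do not dominate purely because of their magnitude.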
jigsaw_loss_factory(
    noisy_teacher_model: NoiseMaskedNoisyTransformerModel[
        CausalModelT, ..., TransformerCloak[Any]
    ],
    noisy_student_model: PeftNoiseMaskedNoisyTransformerModel[
        CausalModelT, ..., TransformerCloak[Any]
    ],
    distillation_layer_index: int | None,
    alpha: float,
    std_log_ratio_loss_weight: float,
    input_embedding_similarity_loss_weight: float,
    distillation_layer_cosine_distance_loss_weight: float,
    distillation_layer_l2_distance_loss_weight: float = 0.0,
    distillation_protection_l2_loss_weight: float = 0.0,
    apply_distillation_protection_loss_mask: bool = True,
    clean_embedding_performance_on_student_model_loss_weight: float = 0.0,
    inverse_l2_power_law_catalyst_loss_weight: float = 0.0,
    inverse_l2_power_law_characteristic_length: float = 1.0,
    inverse_l2_power_law_exponent_of_decay: float = 1.0,
) -> tuple[
    Callable[[torch.Tensor], torch.Tensor],
    Callable[[], ComponentLossesDict],
    Callable[[], HyperparametersDict],
]
Create the jigsaw loss function to perform jigsaw training on a Stained Glass Transform causal language model using a teacher causal language model.
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
| noisy_teacher_model | NoiseMaskedNoisyTransformerModel[CausalModelT, ..., TransformerCloak[Any]] | The teacher model containing both the causal language base model and the Stained Glass Transform. | required | 
| noisy_student_model | PeftNoiseMaskedNoisyTransformerModel[CausalModelT, ..., TransformerCloak[Any]] | The student model containing both the causal language model with PEFT adapters and the Stained Glass Transform for Jigsaw Training. | required | 
| distillation_layer_index | int \| None | The index of the decoder layer in the base model at which to perform distillation. When  | required | 
| alpha | float | The interpolation factor between the distillation loss (maximizing model similarity) and the Stained Glass Transform loss (maximizing transformation strength). Should be in the range [0, 1], where 0 corresponds to higher model similarity and 1 corresponds to higher transformation strength. | required | 
| std_log_ratio_loss_weight | float | The weight of the loss component which aims to maximize the Stained Glass Transform's standard deviations (transformation strength). | required | 
| input_embedding_similarity_loss_weight | float | The weight of the loss component which aims to minimize the similarity of the input embeddings. | required | 
| distillation_layer_cosine_distance_loss_weight | float | The weight of the loss component which aims to maximize the similarity of the distillation layer embeddings. | required | 
| distillation_layer_l2_distance_loss_weight | float | The weight of a subcomponent of the loss component which aims to maximize the similarity of the distillation layer embeddings. | 0.0 | 
| distillation_protection_l2_loss_weight | float | The weight of the loss component used to ensure that the base model does not learn to perform well on the Stained Glass Transform. | 0.0 | 
| apply_distillation_protection_loss_mask | bool | Whether to apply the  | True | 
| clean_embedding_performance_on_student_model_loss_weight | float | The weight of the loss component used to ensure that clean embeddings don't perform well on the PEFT model. | 0.0 | 
| inverse_l2_power_law_catalyst_loss_weight | float | The weight of the inverse l2 power law catalyst loss. This argument is experimental. | 0.0 | 
| inverse_l2_power_law_characteristic_length | float | The parameter modifying the distance between embeddings when computing the inverse l2 power law catalyst loss. This lets the user account for the data-dependent distances between embeddings, so that the loss is O(1) when the distances are approximately the same value as the scale. For embeddings, this is recommended to be the 5th percentile of the Voronoi l2 distances between embeddings. See  | 1.0 | 
| inverse_l2_power_law_exponent_of_decay | float | The value to inverse exponentiate the computed distance by during the computation of the inverse l2 power law catalyst loss. This value must be positive. This argument is experimental. | 1.0 | 
Returns:
| Type | Description | 
|---|---|
| tuple[Callable[[torch.Tensor], torch.Tensor], Callable[[], ComponentLossesDict], Callable[[], HyperparametersDict]] | A tuple of 3 functions: the loss function, a function to retrieve the loss components, and a function to retrieve the hyperparameters. These functions may be called at most once each after a forward pass through both models. | 
Raises:
| Type | Description | 
|---|---|
| ValueError | If the exponent of decay is not positive. | 
Warning
This API is experimental and subject to change: The jigsaw loss factory API is currently under development and is subject to change.
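The three-callable return shape and the role of `alpha` can be sketched with a toy factory. Everything below is hypothetical: `toy_loss_factory` takes two precomputed scalars instead of a `torch.Tensor`, and the linear interpolation `(1 - alpha) * distillation + alpha * transform` is an assumption consistent with the description of `alpha`, not a confirmed formula.

```python
def toy_loss_factory(alpha):
    """Return (loss_function, get_components, get_hyperparameters) closures."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be in the range [0, 1]")
    components = {}

    def loss_function(distillation_loss, transform_loss):
        # Record the individual terms so they can be logged after the step.
        components["distillation_loss"] = distillation_loss
        components["transform_loss"] = transform_loss
        # ASSUMPTION: alpha = 0 keeps only the distillation term (model
        # similarity); alpha = 1 keeps only the transform-strength term.
        return (1.0 - alpha) * distillation_loss + alpha * transform_loss

    def get_components():
        return dict(components)

    def get_hyperparameters():
        return {"alpha": alpha}

    return loss_function, get_components, get_hyperparameters


loss_fn, get_components, get_hyperparameters = toy_loss_factory(alpha=0.25)
total = loss_fn(2.0, 4.0)
logged = get_components()
```

Returning closures rather than a single scalar lets a training loop log each component and the hyperparameters alongside the total, without recomputing anything.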
jigsaw_sgt_loss_factory(
    noisy_teacher_model: NoiseMaskedNoisyTransformerModel[
        CausalModelT, ..., TransformerCloak[Any]
    ],
    noisy_student_model: PeftNoiseMaskedNoisyTransformerModel[
        CausalModelT, ..., TransformerCloak[Any]
    ],
    absolute_cosine_similarity_penalty_weight: float = 0.0,
    jigsaw_ce_performance_loss_weight: float = 0.0,
    jigsaw_ce_protection_loss_weight: float = 0.0,
    jigsaw_ce_protection_max_loss: float | None = None,
    jigsaw_ce_protection_loss_scaling_factor: float = 1.0,
    mutual_information_loss_weight: float = 0.0,
    median_norm_penalty_weight: float = 0.0,
    jigsaw_clean_embeddings_guardrails_loss_weight: float = 0.0,
    distillation_layer_index: int | None = None,
    jigsaw_l2_performance_loss_weight: float = 0.0,
    jigsaw_l2_protection_loss_weight: float = 0.0,
    combine_noise_masks_for_mutual_information_sub_batches: bool = True,
) -> tuple[
    Callable[[torch.Tensor], torch.Tensor],
    Callable[[], ComponentLossesDict],
    Callable[[], HyperparametersDict],
]
Create a loss function to perform distillation training on a Stained Glass Transform causal language model using a teacher causal language model.
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
| noisy_teacher_model | NoiseMaskedNoisyTransformerModel[CausalModelT, ..., TransformerCloak[Any]] | The teacher model containing both the causal language base model and the Stained Glass Transform. | required | 
| noisy_student_model | PeftNoiseMaskedNoisyTransformerModel[CausalModelT, ..., TransformerCloak[Any]] | The student model containing both the causal language model with PEFT adapters and the Stained Glass Transform for Jigsaw Training. | required | 
| absolute_cosine_similarity_penalty_weight | float | The weight of the loss component which aims to minimize the absolute value of the similarity of the clean and transformed input embeddings. Increasing this makes the optimization prefer higher transformation strength. | 0.0 | 
| jigsaw_ce_performance_loss_weight | float | The weight of the loss component which aims to minimize the difference between the clean logits from the teacher model and the noisy logits from the student model. Optimizing this loss will ensure that the LLM with Jigsaw adapters performs well with the Jigsaw-SGT. | 0.0 | 
| jigsaw_ce_protection_loss_weight | float | The weight of the loss component which aims to maximize the difference between the clean logits from the teacher model and the noisy logits from the teacher model. Optimizing this loss will ensure that the base LLM does not perform well with the Jigsaw-SGT. | 0.0 | 
| jigsaw_ce_protection_max_loss | float \| None | The maximum value for the loss to cap the jigsaw protection cross entropy loss. Defaults to  | None | 
| jigsaw_ce_protection_loss_scaling_factor | float | A scaling factor for the sigmoid jigsaw protection cross-entropy loss. Defaults to 1.0. | 1.0 | 
| mutual_information_loss_weight | float | The weight of the loss component which aims to maximize the mutual information between the clean and transformed input embeddings. Increasing this makes the optimization prefer higher transformation strength. | 0.0 | 
| median_norm_penalty_weight | float | The weight of the loss component which aims to regularize the mean-shifted input embeddings to have norms near  | 0.0 | 
| jigsaw_clean_embeddings_guardrails_loss_weight | float | The weight of the loss component which aims to maximize the difference between the clean logits from the teacher model and the clean logits from the student model. Optimizing this loss will ensure that the LLM with Jigsaw adapters performs poorly with clean embeddings. | 0.0 | 
| distillation_layer_index | int \| None | The index of the decoder layer in the base model at which to perform distillation. When  | None | 
| jigsaw_l2_performance_loss_weight | float | The weight of the loss component which aims to minimize the difference between the clean activations from the teacher model and the noisy activations from the student model at the distillation layer. Optimizing this loss will ensure that the LLM with Jigsaw adapters performs well with the Jigsaw-SGT. | 0.0 | 
| jigsaw_l2_protection_loss_weight | float | The weight of the loss component which aims to maximize the difference between the clean activations from the teacher model and the noisy activations from the teacher model at the distillation layer. Optimizing this loss will ensure that the base LLM does not perform well with the Jigsaw-SGT. | 0.0 | 
| combine_noise_masks_for_mutual_information_sub_batches | bool | Whether to only use the features in the noise masks for the two sub-batches which are simultaneously unmasked when computing the mutual information loss. | True | 
Returns:
| Type | Description | 
|---|---|
| tuple[Callable[[torch.Tensor], torch.Tensor], Callable[[], ComponentLossesDict], Callable[[], HyperparametersDict]] | A tuple of 3 functions: the loss function, a function to retrieve the loss components, and a function to retrieve the hyperparameters. These functions may be called at most once each after a forward pass through both models. | 
Raises:
| Type | Description | 
|---|---|
| ValueError | If the distillation layer index is  | 
| ValueError | If the noisy teacher or student model is truncated and offloaded and any of the jigsaw cross entropy losses are not zero. | 
Warning
This API is experimental and subject to change: The jigsaw sgt loss factory API is currently under development and is subject to change.
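The second `ValueError` condition above can be checked before the factory is called. The sketch below is a hypothetical pre-flight helper, not part of this module: the keyword names come from the signature above, the weight values are arbitrary placeholders, and treating the guardrails weight as one of the "jigsaw cross entropy losses" is an assumption based on the description of `compute_jigsaw_ce_losses`.

```python
# ASSUMPTION: these three weights are the "jigsaw cross entropy losses".
CE_WEIGHT_NAMES = (
    "jigsaw_ce_performance_loss_weight",
    "jigsaw_ce_protection_loss_weight",
    "jigsaw_clean_embeddings_guardrails_loss_weight",
)


def check_ce_weights(loss_kwargs, is_truncated_and_offloaded):
    """Reject non-zero CE weights when a model is truncated and offloaded."""
    nonzero = [name for name in CE_WEIGHT_NAMES if loss_kwargs.get(name, 0.0) != 0.0]
    if is_truncated_and_offloaded and nonzero:
        raise ValueError(f"CE loss weights must be zero for truncated models: {nonzero}")
    return loss_kwargs


# Placeholder weights for illustration only; these are not recommended values.
loss_kwargs = check_ce_weights(
    {
        "jigsaw_ce_performance_loss_weight": 1.0,
        "jigsaw_ce_protection_loss_weight": 0.5,
        "jigsaw_ce_protection_max_loss": 10.0,
    },
    is_truncated_and_offloaded=False,
)
```

The validated dictionary could then be unpacked into `jigsaw_sgt_loss_factory(noisy_teacher_model, noisy_student_model, **loss_kwargs)`.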