
jigsaw

Functions:

Name Description
compute_jigsaw_ce_losses

Compute the cross-entropy (CE) losses for Jigsaw training, including performance, protection, and guardrails losses.

compute_jigsaw_l2_losses

Compute the L2 losses for Jigsaw training at the distillation layer.

compute_masked_cross_entropy

Calculate the cross-entropy loss between the output logits and the target labels.

compute_masked_normalized_l2_loss

Calculate the L2 loss at the distillation layer.

jigsaw_loss_factory

Create the jigsaw loss function to perform jigsaw training on a Stained Glass Transform causal language model using a teacher causal language model.

jigsaw_sgt_loss_factory

Create a loss function to perform distillation training on a Stained Glass Transform causal language model using a teacher causal language model.

compute_jigsaw_ce_losses

compute_jigsaw_ce_losses(
    activations: ActivationsDict,
    batch_size: int,
    sequence_length: int,
    vocab_size: int,
    attention_mask: Tensor,
    loss_mask: Tensor,
    losses: ComponentLossesDict,
    jigsaw_ce_protection_max_loss: float | None,
    jigsaw_ce_protection_loss_scaling_factor: float,
) -> ComponentLossesDict

Compute the cross-entropy (CE) losses for Jigsaw training, including performance, protection, and guardrails losses.

Parameters:

Name Type Description Default

activations

ActivationsDict

Dictionary containing model activations, including logits.

required

batch_size

int

The batch size.

required

sequence_length

int

The sequence length for a given batch.

required

vocab_size

int

The size of the vocabulary.

required

attention_mask

Tensor

The attention mask to apply to the loss.

required

loss_mask

Tensor

The mask to apply to the protection and guardrails losses.

required

losses

ComponentLossesDict

Dictionary to store computed loss components.

required

jigsaw_ce_protection_max_loss

float | None

Maximum value at which to cap the protection cross-entropy loss. If None, no cap is applied.

required

jigsaw_ce_protection_loss_scaling_factor

float

Scaling factor for the sigmoid cross-entropy protection loss.

required

Returns:

Name Type Description
dict ComponentLossesDict

The updated losses dictionary with computed CE loss components.
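The parameter descriptions suggest how the three CE components are routed: the performance loss uses attention_mask, while the protection and guardrails losses use loss_mask. The sketch below illustrates that routing under stated assumptions and is not the library's implementation. It assumes that the teacher's predictions on clean inputs serve as target labels, that the comparisons follow the jigsaw_sgt_loss_factory descriptions (performance: student on transformed inputs; protection: teacher on transformed inputs; guardrails: student on clean inputs), and that the optional protection cap has the form max_loss * sigmoid(scaling_factor * ce). The function names and dictionary keys are hypothetical.

```python
from __future__ import annotations

import torch
import torch.nn.functional as F


def _masked_ce(logits: torch.Tensor, labels: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Mean token-level cross-entropy over positions where mask is nonzero."""
    vocab_size = logits.shape[-1]
    per_token = F.cross_entropy(
        logits.reshape(-1, vocab_size), labels.reshape(-1), reduction="none"
    )
    weights = mask.reshape(-1).to(per_token.dtype)
    return (per_token * weights).sum() / weights.sum().clamp_min(1.0)


def jigsaw_ce_components_sketch(
    teacher_clean_logits: torch.Tensor,  # (batch, seq, vocab)
    teacher_noisy_logits: torch.Tensor,
    student_clean_logits: torch.Tensor,
    student_noisy_logits: torch.Tensor,
    attention_mask: torch.Tensor,  # (batch, seq)
    loss_mask: torch.Tensor,  # (batch, seq)
    protection_max_loss: float | None = None,
    protection_scaling_factor: float = 1.0,
) -> dict[str, torch.Tensor]:
    # Assumption: the teacher's predictions on clean inputs act as target labels.
    targets = teacher_clean_logits.argmax(dim=-1)
    # Performance: the student on transformed inputs should match the clean teacher.
    performance = _masked_ce(student_noisy_logits, targets, attention_mask)
    # Protection: the base (teacher) model on transformed inputs should NOT match.
    # Training aims to maximize this term, so it is optionally capped.
    protection = _masked_ce(teacher_noisy_logits, targets, loss_mask)
    if protection_max_loss is not None:
        protection = protection_max_loss * torch.sigmoid(protection_scaling_factor * protection)
    # Guardrails: the student on clean inputs should NOT match the clean teacher.
    guardrails = _masked_ce(student_clean_logits, targets, loss_mask)
    return {"performance": performance, "protection": protection, "guardrails": guardrails}
```

The sketch returns raw components. Signing and weighting them (for example with the jigsaw_ce_*_loss_weight arguments) is left to the caller, just as the real function stores components in a ComponentLossesDict.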

compute_jigsaw_l2_losses

compute_jigsaw_l2_losses(
    activations: ActivationsDict,
    super_batch_size: int,
    loss_mask: Tensor,
    losses: ComponentLossesDict,
) -> ComponentLossesDict

Compute the L2 losses for Jigsaw training at the distillation layer.

Parameters:

Name Type Description Default

activations

ActivationsDict

Dictionary containing model activations, including distillation layer embeddings.

required

super_batch_size

int

The total batch size (typically twice the actual batch size, due to the clean/noisy split).

required

loss_mask

Tensor

The mask to apply to the loss.

required

losses

ComponentLossesDict

Dictionary to store computed loss components.

required

Returns:

Name Type Description
dict ComponentLossesDict

The updated losses dictionary with computed L2 loss components.
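Because super_batch_size is typically twice the real batch size, each distillation-layer tensor has to be separated into its clean and noisy halves before any L2 comparison. The sketch below shows that split followed by a masked per-token distance. It assumes, hypothetically, that the clean half comes first along dimension 0 and that loss_mask has the shape of the real batch. split_super_batch is an illustrative helper, not part of the documented API.

```python
from __future__ import annotations

import torch


def split_super_batch(
    activations: torch.Tensor, super_batch_size: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """Split a super batch into clean and noisy halves (hypothetical helper).

    Assumes the clean half comes first along dimension 0.
    """
    if super_batch_size % 2 != 0:
        raise ValueError("super_batch_size must be even to split clean/noisy halves")
    if activations.shape[0] != super_batch_size:
        raise ValueError("leading dimension must equal super_batch_size")
    clean, noisy = torch.split(activations, super_batch_size // 2, dim=0)
    return clean, noisy


# Example: a super batch of 4 sequences (2 clean + 2 noisy), 3 tokens, hidden size 8.
hidden = torch.randn(4, 3, 8)
clean, noisy = split_super_batch(hidden, super_batch_size=4)
loss_mask = torch.tensor([[1, 1, 0], [1, 0, 0]])  # applies to the real batch of 2
# Per-token squared L2 distance, averaged over unmasked positions only.
sq_dist = (clean - noisy).pow(2).sum(dim=-1)
weights = loss_mask.to(sq_dist.dtype)
masked_mean = (sq_dist * weights).sum() / weights.sum()
```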

compute_masked_cross_entropy

compute_masked_cross_entropy(
    output_logits: Tensor,
    target_labels: Tensor,
    batch_size: int,
    sequence_length: int,
    mask: Tensor,
    max_loss: float | None = None,
    ce_scaling_factor: float = 1.0,
) -> torch.Tensor

Calculate the cross-entropy loss between the output logits and the target labels.

The cross-entropy loss can produce aggressive (large) gradients. Supplying a max_loss value passes the cross-entropy through a sigmoid scaled by max_loss, which keeps the resulting loss below max_loss.

Parameters:

Name Type Description Default

output_logits

Tensor

The output logits from the model.

required

target_labels

Tensor

The target labels.

required

batch_size

int

The batch size.

required

sequence_length

int

The sequence length for a given batch.

required

mask

Tensor

The mask to apply to the loss.

required

max_loss

float | None

The maximum value at which to cap the cross-entropy loss. If None, no cap is applied.

None

ce_scaling_factor

float

A scaling factor for the sigmoid cross-entropy loss. Defaults to 1.0.

1.0

Returns:

Type Description
torch.Tensor

The cross entropy loss.
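The masking and optional sigmoid cap described above can be sketched as follows. This is a minimal illustration, not the library's implementation. It assumes that output_logits has shape (batch_size, sequence_length, vocab_size), that target_labels and mask have shape (batch_size, sequence_length), that the loss is averaged over positions where mask is nonzero, and that the cap is max_loss * sigmoid(ce_scaling_factor * ce) applied to that mean. The name masked_cross_entropy_sketch is hypothetical.

```python
from __future__ import annotations

import torch
import torch.nn.functional as F


def masked_cross_entropy_sketch(
    output_logits: torch.Tensor,  # (batch_size, sequence_length, vocab_size)
    target_labels: torch.Tensor,  # (batch_size, sequence_length), integer token ids
    batch_size: int,
    sequence_length: int,
    mask: torch.Tensor,  # (batch_size, sequence_length), 1 = keep, 0 = ignore
    max_loss: float | None = None,
    ce_scaling_factor: float = 1.0,
) -> torch.Tensor:
    """Hypothetical sketch of a masked CE with an optional sigmoid cap."""
    vocab_size = output_logits.shape[-1]
    # Per-token cross-entropy, flattened to (batch_size * sequence_length,).
    per_token = F.cross_entropy(
        output_logits.reshape(batch_size * sequence_length, vocab_size),
        target_labels.reshape(batch_size * sequence_length),
        reduction="none",
    )
    weights = mask.reshape(-1).to(per_token.dtype)
    # Mean over unmasked positions only (clamp avoids division by zero).
    ce = (per_token * weights).sum() / weights.sum().clamp_min(1.0)
    if max_loss is None:
        return ce
    # Assumed form of the cap: a sigmoid scaled by max_loss keeps the result
    # below max_loss, and its gradient flattens as the CE grows.
    return max_loss * torch.sigmoid(ce_scaling_factor * ce)
```

Since the cross-entropy is non-negative, this assumed form maps it into [max_loss / 2, max_loss). Applying the sigmoid per token instead of to the mean would be an equally plausible variant.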

compute_masked_normalized_l2_loss

compute_masked_normalized_l2_loss(
    embeddings: Tensor,
    target_embeddings: Tensor,
    mask: Tensor,
) -> torch.Tensor

Calculate the L2 loss at the distillation layer.

Parameters:

Name Type Description Default

embeddings

Tensor

The activation embeddings output by the student model at the distillation layer.

required

target_embeddings

Tensor

The target activation embeddings output by the teacher model at the distillation layer.

required

mask

Tensor

The mask to apply to the loss.

required

Returns:

Type Description
torch.Tensor

The L2 loss.
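The documentation does not define the normalization used by compute_masked_normalized_l2_loss. One common choice is a relative L2 distance, ||e - t|| / ||t|| per token, averaged over unmasked positions. The sketch below uses that choice as an assumption; the helper name and the eps parameter are hypothetical.

```python
import torch


def masked_normalized_l2_sketch(
    embeddings: torch.Tensor,  # (batch, seq, hidden), student activations
    target_embeddings: torch.Tensor,  # (batch, seq, hidden), teacher activations
    mask: torch.Tensor,  # (batch, seq), 1 = keep, 0 = ignore
    eps: float = 1e-8,
) -> torch.Tensor:
    # Per-token L2 distance divided by the target norm, so the loss does not
    # depend on the overall scale of the hidden states (assumed normalization).
    diff_norm = torch.linalg.vector_norm(embeddings - target_embeddings, dim=-1)
    target_norm = torch.linalg.vector_norm(target_embeddings, dim=-1)
    relative = diff_norm / (target_norm + eps)
    weights = mask.to(relative.dtype)
    # Mean over unmasked tokens only.
    return (relative * weights).sum() / weights.sum().clamp_min(1.0)
```

Under this choice, identical activations give a loss of 0, and a student that outputs twice the teacher's activations gives a loss of about 1, whatever the hidden-state scale.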

jigsaw_loss_factory

Create the jigsaw loss function to perform jigsaw training on a Stained Glass Transform causal language model using a teacher causal language model.

Parameters:

Name Type Description Default

noisy_teacher_model

NoiseMaskedNoisyTransformerModel[CausalModelT, ..., TransformerCloak[Any]]

The teacher model containing both the causal language base model and the Stained Glass Transform.

required

noisy_student_model

PeftNoiseMaskedNoisyTransformerModel[CausalModelT, ..., TransformerCloak[Any]]

The student model containing both the causal language model with PEFT adapters and the Stained Glass Transform for Jigsaw training.

required

distillation_layer_index

int | None

The index of the decoder layer in the base model at which to perform distillation. When None, distillation is performed at the logits layer.

required

alpha

float

The interpolation factor between the distillation loss (maximizing model similarity) and the Stained Glass Transform loss (maximizing transformation strength). Should be in the range [0, 1], where 0 corresponds to higher model similarity and 1 corresponds to higher transformation strength.

required

std_log_ratio_loss_weight

float

The weight of the loss component which aims to maximize the Stained Glass Transform's standard deviations (transformation strength).

required

input_embedding_similarity_loss_weight

float

The weight of the loss component which aims to minimize the similarity of the input embeddings.

required

distillation_layer_cosine_distance_loss_weight

float

The weight of the loss component which aims to maximize the similarity of the distillation layer embeddings.

required

distillation_layer_l2_distance_loss_weight

float

The weight of a subcomponent of the loss component which aims to maximize the similarity of the distillation layer embeddings.

0.0

distillation_protection_l2_loss_weight

float

The weight of the distillation protection L2 loss, which is used to ensure that the base model does not learn to perform well on the Stained Glass Transform.

0.0

apply_distillation_protection_loss_mask

bool

Whether to apply the loss_mask to the distillation protection loss.

True

clean_embedding_performance_on_student_model_loss_weight

float

The weight of the loss component used to ensure that clean embeddings do not perform well on the PEFT model.

0.0

inverse_l2_power_law_catalyst_loss_weight

float

The weight of the inverse l2 power law catalyst loss. This argument is experimental.

0.0

inverse_l2_power_law_characteristic_length

float

The parameter modifying the distance between embeddings when computing the inverse l2 power law catalyst loss. This allows the user to account for the data-dependent distances between embeddings, so that the loss is O(1) when the distances are approximately equal to this scale. For embeddings, this is recommended to be the 5th percentile of the Voronoi l2 distances between embeddings. See scripts/compute_percentile_voronoi_l2_distance.py for an example of how to compute this value. This argument is experimental.

1.0

inverse_l2_power_law_exponent_of_decay

float

The decay exponent of the inverse power law applied to the computed distance when computing the inverse l2 power law catalyst loss (the distance is raised to the negative of this value). This value must be positive. This argument is experimental.

1.0

Returns:

Type Description
tuple[Callable[[torch.Tensor], torch.Tensor], Callable[[], ComponentLossesDict], Callable[[], HyperparametersDict]]

A tuple of 3 functions: the loss function, a function to retrieve the loss components, and a function to retrieve the hyperparameters. These functions may be called at most once each after a forward pass through both models.
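The returned triple follows a closure pattern. Presumably the three callables share state recorded while both models run their forward pass. The stand-in below is a minimal sketch of that contract and of one plausible reading of the alpha interpolation, (1 - alpha) * distillation + alpha * sgt. Everything in it is hypothetical: make_jigsaw_style_closures, the captured dictionary, the component keys, and treating the loss function's tensor argument as a loss mask. The real loss function's input is not documented here.

```python
from __future__ import annotations

from typing import Callable

import torch


def make_jigsaw_style_closures(
    alpha: float, captured: dict[str, torch.Tensor]
) -> tuple[
    Callable[[torch.Tensor], torch.Tensor],
    Callable[[], dict[str, torch.Tensor]],
    Callable[[], dict[str, float]],
]:
    """Hypothetical stand-in mirroring the (loss, components, hyperparameters) contract.

    `captured` stands in for whatever the real factory records during the
    forward pass through both models (an assumed mechanism).
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha should be in [0, 1]")
    components: dict[str, torch.Tensor] = {}

    def loss_fn(loss_mask: torch.Tensor) -> torch.Tensor:
        # Masked mean of the per-token distillation loss captured earlier.
        weights = loss_mask.to(captured["distillation"].dtype)
        distillation = (captured["distillation"] * weights).sum() / weights.sum().clamp_min(1.0)
        sgt = captured["sgt"]
        components["distillation_loss"] = distillation.detach()
        components["sgt_loss"] = sgt.detach()
        # alpha = 0 favours model similarity, alpha = 1 favours transformation strength.
        return (1.0 - alpha) * distillation + alpha * sgt

    def get_components() -> dict[str, torch.Tensor]:
        return dict(components)

    def get_hyperparameters() -> dict[str, float]:
        return {"alpha": alpha}

    return loss_fn, get_components, get_hyperparameters
```

In a training step, the loss function would be called once after the forward pass, and the two getters would then be read for logging. This matches the "at most once each" usage described above.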

Raises:

Type Description
ValueError

If inverse_l2_power_law_exponent_of_decay is not positive.
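One plausible reading of the three inverse l2 power law parameters is a term of the form mean((characteristic_length / d) ** exponent_of_decay). This equals 1 when the distance d matches the characteristic length, which is consistent with the O(1) behaviour described for inverse_l2_power_law_characteristic_length. The sketch below rests on that assumption and is not the library's formula. The helper name and the choice of which embeddings are compared are hypothetical. It also shows the positivity check behind the ValueError.

```python
import torch


def inverse_l2_power_law_sketch(
    embeddings: torch.Tensor,  # (..., hidden)
    other_embeddings: torch.Tensor,  # (..., hidden), e.g. clean vs. transformed (assumed)
    characteristic_length: float = 1.0,
    exponent_of_decay: float = 1.0,
    eps: float = 1e-12,
) -> torch.Tensor:
    if exponent_of_decay <= 0:
        raise ValueError("exponent_of_decay must be positive")
    # Row-wise L2 distance, clamped to avoid dividing by zero for identical rows.
    distance = torch.linalg.vector_norm(embeddings - other_embeddings, dim=-1).clamp_min(eps)
    # (length / distance) ** p is 1 at distance == length and decays as distance grows.
    return ((characteristic_length / distance) ** exponent_of_decay).mean()
```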

Warning

This API is experimental and subject to change: The jigsaw loss factory API is currently under development and is subject to change.

jigsaw_sgt_loss_factory

Create a loss function to perform distillation training on a Stained Glass Transform causal language model using a teacher causal language model.

Parameters:

Name Type Description Default

noisy_teacher_model

NoiseMaskedNoisyTransformerModel[CausalModelT, ..., TransformerCloak[Any]]

The teacher model containing both the causal language base model and the Stained Glass Transform.

required

noisy_student_model

PeftNoiseMaskedNoisyTransformerModel[CausalModelT, ..., TransformerCloak[Any]]

The student model containing both the causal language model with PEFT adapters and the Stained Glass Transform for Jigsaw training.

required

absolute_cosine_similarity_penalty_weight

float

The weight of the loss component which aims to minimize the absolute value of the similarity of the clean and transformed input embeddings. Increasing this makes the optimization prefer higher transformation strength.

0.0

jigsaw_ce_performance_loss_weight

float

The weight of the loss component which aims to minimize the difference between the clean logits from the teacher model and the noisy logits from the student model. Optimizing this loss will ensure that the LLM with Jigsaw adapters performs well with the Jigsaw-SGT.

0.0

jigsaw_ce_protection_loss_weight

float

The weight of the loss component which aims to maximize the difference between the clean logits from the teacher model and the noisy logits from the teacher model. Optimizing this loss will ensure that the base LLM does not perform well with the Jigsaw-SGT.

0.0

jigsaw_ce_protection_max_loss

float | None

The maximum value for the loss to cap the jigsaw protection cross entropy loss. Defaults to None, which means no cap is applied.

None

jigsaw_ce_protection_loss_scaling_factor

float

A scaling factor for the sigmoid jigsaw protection cross-entropy loss. Defaults to 1.0.

1.0

mutual_information_loss_weight

float

The weight of the loss component which aims to maximize the mutual information between the clean and transformed input embeddings. Increasing this makes the optimization prefer higher transformation strength.

0.0

median_norm_penalty_weight

float

The weight of the loss component which aims to regularize the mean-shifted input embeddings to have norms near median_norm.

0.0

jigsaw_clean_embeddings_guardrails_loss_weight

float

The weight of the loss component which aims to maximize the difference between the clean logits from the teacher model and the clean logits from the student model. Optimizing this loss will ensure that the LLM with Jigsaw adapters performs poorly with clean embeddings.

0.0

distillation_layer_index

int | None

The index of the decoder layer in the base model at which to perform distillation. When None, distillation is performed at the logits layer.

None

jigsaw_l2_performance_loss_weight

float

The weight of the loss component which aims to minimize the difference between the clean activations from the teacher model and the noisy activations from the student model at the distillation layer. Optimizing this loss will ensure that the LLM with Jigsaw adapters performs well with the Jigsaw-SGT.

0.0

jigsaw_l2_protection_loss_weight

float

The weight of the loss component which aims to maximize the difference between the clean activations from the teacher model and the noisy activations from the teacher model at the distillation layer. Optimizing this loss will ensure that the base LLM does not perform well with the Jigsaw-SGT.

0.0

combine_noise_masks_for_mutual_information_sub_batches

bool

Whether to use only the features that are unmasked in the noise masks of both sub-batches when computing the mutual information loss.

True

Returns:

Type Description
tuple[Callable[[torch.Tensor], torch.Tensor], Callable[[], ComponentLossesDict], Callable[[], HyperparametersDict]]

A tuple of 3 functions: the loss function, a function to retrieve the loss components, and a function to retrieve the hyperparameters. These functions may be called at most once each after a forward pass through both models.

Raises:

Type Description
ValueError

If the distillation layer index is None and the jigsaw l2 performance or protection loss weights are not zero.

ValueError

If the noisy teacher or student model is truncated and offloaded and any of the jigsaw cross entropy losses are not zero.
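The two ValueError conditions above amount to an argument check like the following. This is a sketch, not the library's code. The truncation state is passed in as a plain boolean (model_is_truncated_and_offloaded, a hypothetical parameter) because how the factory detects it is not documented here. Counting jigsaw_clean_embeddings_guardrails_loss_weight among the cross-entropy weights is an assumption, based on compute_jigsaw_ce_losses computing the guardrails loss.

```python
from __future__ import annotations


def validate_jigsaw_sgt_arguments(
    distillation_layer_index: int | None,
    jigsaw_l2_performance_loss_weight: float = 0.0,
    jigsaw_l2_protection_loss_weight: float = 0.0,
    jigsaw_ce_performance_loss_weight: float = 0.0,
    jigsaw_ce_protection_loss_weight: float = 0.0,
    jigsaw_clean_embeddings_guardrails_loss_weight: float = 0.0,
    model_is_truncated_and_offloaded: bool = False,  # hypothetical flag
) -> None:
    """Hypothetical re-statement of the two documented ValueError conditions."""
    # Condition 1: L2 losses are computed at a decoder layer, so one must be chosen.
    l2_weights = (jigsaw_l2_performance_loss_weight, jigsaw_l2_protection_loss_weight)
    if distillation_layer_index is None and any(w != 0.0 for w in l2_weights):
        raise ValueError(
            "jigsaw L2 losses need a distillation layer; set distillation_layer_index"
        )
    # Condition 2: cross-entropy losses are rejected for truncated, offloaded models.
    ce_weights = (
        jigsaw_ce_performance_loss_weight,
        jigsaw_ce_protection_loss_weight,
        jigsaw_clean_embeddings_guardrails_loss_weight,
    )
    if model_is_truncated_and_offloaded and any(w != 0.0 for w in ce_weights):
        raise ValueError(
            "jigsaw cross-entropy losses are unavailable for truncated, offloaded models"
        )
```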

Warning

This API is experimental and subject to change: The jigsaw sgt loss factory API is currently under development and is subject to change.