metrics

Module for metrics calculation.

Functions:

Name Description
get_noise_metrics_factory

Create a function that returns norm and cosine similarity metrics of the transform components applied during the most recent forward pass.

get_top_k_support_mask

Return a boolean mask for the union of top-k indices from noisy and clean logits.

get_utility_metrics_factory

Create a function that returns next-token utility divergence metrics between clean and noisy logits from the latest forward pass.

percentage_changed_ids

Compute the percentage of token ids that differ between input_ids and reconstructed_ids.

get_noise_metrics_factory

get_noise_metrics_factory(
    noise_layer: BaseNoiseLayer, dim: int = -1
) -> Callable[[], dict[str, torch.Tensor]]

Create a function that returns norm and cosine similarity metrics of the transform components applied during the most recent forward pass.

The statistics computed include:
  • mean_domination_percent: mean_norm / (mean_norm + unshifted_noise_norm)
  • mean_domination_ratio: mean_norm / unshifted_noise_norm
  • mean_cosine_similarity: cosine_similarity(input, mean)
  • noise_domination_percent: mean_shifted_noise_norm / (input_norm + mean_shifted_noise_norm)
  • noise_domination_ratio: mean_shifted_noise_norm / input_norm
  • noise_cosine_similarity: cosine_similarity(input, mean_shifted_noise)
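The formulas above can be sketched directly in PyTorch. This is a minimal illustration, not the library's implementation: the helper name `domination_stats` is hypothetical, the L2 norm is assumed, and the four tensors are passed in explicitly because how the noise layer stores them is not described here.

```python
import torch
import torch.nn.functional as F


def domination_stats(
    inputs: torch.Tensor,
    mean: torch.Tensor,
    unshifted_noise: torch.Tensor,
    mean_shifted_noise: torch.Tensor,
    dim: int = -1,
) -> dict[str, torch.Tensor]:
    """Hypothetical helper that evaluates the listed formulas on raw tensors."""
    # Assumption: all norms are L2 norms taken over `dim`.
    input_norm = torch.linalg.vector_norm(inputs, dim=dim)
    mean_norm = torch.linalg.vector_norm(mean, dim=dim)
    unshifted_noise_norm = torch.linalg.vector_norm(unshifted_noise, dim=dim)
    mean_shifted_noise_norm = torch.linalg.vector_norm(mean_shifted_noise, dim=dim)
    return {
        "mean_domination_percent": mean_norm / (mean_norm + unshifted_noise_norm),
        "mean_domination_ratio": mean_norm / unshifted_noise_norm,
        "mean_cosine_similarity": F.cosine_similarity(inputs, mean, dim=dim),
        "noise_domination_percent": mean_shifted_noise_norm
        / (input_norm + mean_shifted_noise_norm),
        "noise_domination_ratio": mean_shifted_noise_norm / input_norm,
        "noise_cosine_similarity": F.cosine_similarity(
            inputs, mean_shifted_noise, dim=dim
        ),
    }


inputs = torch.tensor([[3.0, 4.0]])
mean = torch.tensor([[0.0, 5.0]])
unshifted_noise = torch.tensor([[5.0, 0.0]])
# Assumption for this toy example only: shifted noise = mean + unshifted noise.
stats = domination_stats(inputs, mean, unshifted_noise, mean + unshifted_noise)
```

A "percent" value near 1 means the numerator term dominates the sum of the two norms. A "ratio" is unbounded above.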

Parameters:

Name Type Description Default

noise_layer

BaseNoiseLayer

The BaseNoiseLayer to compute the statistics for.

required

dim

int

The dimension to compute the norms and cosine similarities over. Must be negative.

-1

Returns:

Type Description
Callable[[], dict[str, torch.Tensor]]

A function that returns the norm and cosine similarity statistics of the transform components applied during the most recent forward pass.

Raises:

Type Description
ValueError

If dim is not negative.

Examples:

>>> import torch
>>> from torch import nn
>>> from stainedglass_core import model as sg_model, noise_layer as sg_noise_layer
>>> base_model = nn.Conv2d(3, 4, kernel_size=2)
>>> noisy_model = sg_model.NoisyModel(
...     sg_noise_layer.CloakNoiseLayer1,
...     base_model,
...     target_parameter="input",
... )
>>> get_noise_metrics = get_noise_metrics_factory(noisy_model.noise_layer, dim=-3)
>>> input = torch.ones(4, 3, 2, 2)
>>> noise_mask = torch.tensor([[True, False], [False, True]])
>>> output = noisy_model(input, noise_mask=noise_mask)
>>> noise_metrics = get_noise_metrics()
>>> noise_metrics
{'input_norm': tensor(...), 'mean_norm': tensor(...), 'unshifted_noise_norm': tensor(...), 'mean_shifted_noise_norm': tensor(...), 'mean_domination_percent': tensor(...), 'mean_domination_ratio': tensor(...), 'mean_cosine_similarity': tensor(...), 'noise_domination_percent': tensor(...), 'noise_domination_ratio': tensor(...), 'noise_cosine_similarity': tensor(...)}
>>> {value.shape for value in noise_metrics.values()}
{torch.Size([8])}

Added in version v3.38.0.

get_top_k_support_mask

get_top_k_support_mask(
    noisy_logits: Tensor,
    clean_logits: Tensor,
    num_logits: int,
) -> torch.Tensor

Return a boolean mask for the union of top-k indices from noisy and clean logits.

This function computes the top-k indices (along the vocabulary dimension) from both noisy_logits and clean_logits. It then constructs a support mask that indicates which indices are present in either set of top-k values. The result is a boolean tensor with the same shape as the logits, where True entries correspond to indices included in the union of the supports, and False otherwise.
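One way to build such a union mask is with `torch.topk` and `scatter_`. The function below is a sketch under the assumption that the vocabulary is the last dimension; its name is hypothetical and it is not taken from the library source.

```python
import torch


def top_k_support_mask_sketch(
    noisy_logits: torch.Tensor, clean_logits: torch.Tensor, num_logits: int
) -> torch.Tensor:
    """Hypothetical re-implementation of the behaviour described above."""
    # Top-k indices along the (assumed last) vocabulary dimension.
    noisy_idx = noisy_logits.topk(num_logits, dim=-1).indices
    clean_idx = clean_logits.topk(num_logits, dim=-1).indices
    # Start from all-False and mark both index sets; overlaps are harmless.
    mask = torch.zeros_like(noisy_logits, dtype=torch.bool)
    mask.scatter_(-1, noisy_idx, True)
    mask.scatter_(-1, clean_idx, True)
    return mask


noisy = torch.tensor([[0.1, 2.0, 0.3, 1.5]])
clean = torch.tensor([[3.0, 0.2, 0.1, 2.5]])
mask = top_k_support_mask_sketch(noisy, clean, num_logits=2)
# Noisy top-2 is {1, 3}, clean top-2 is {0, 3}, so the union is {0, 1, 3}.
```

Each position therefore has between `num_logits` and `2 * num_logits` True entries, depending on how much the two top-k sets overlap.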

Parameters:

Name Type Description Default

noisy_logits

Tensor

The logits produced with the Stained Glass Transform (SGT) applied.

required

clean_logits

Tensor

The logits produced by the base model without SGT applied.

required

num_logits

int

The number of top logits to include in the support from each tensor.

required

Returns:

Type Description
torch.Tensor

A boolean tensor of the same shape as noisy_logits where True marks the union of top-k indices from both noisy_logits and clean_logits.

Added in version v3.38.0.

get_utility_metrics_factory

get_utility_metrics_factory(
    noisy_model: NoiseMaskedNoisyTransformerModel[
        CausalModelT, ..., TransformerCloak
    ],
    num_logits: int | None = None,
    max_batch_size: int | None = None,
    max_sequence_length: int | None = None,
) -> Callable[..., dict[str, torch.Tensor]]

Create a function that returns next-token utility divergence metrics between clean and noisy logits from the latest forward pass.

Hooks are attached to noisy_model to capture its attention_mask input and the underlying base model's output.logits. The factory assumes the noisy model produces a super-batch whose first half corresponds to clean logits and second half to noisy logits. For every captured forward pass, the returned callable computes four f-divergences between the per-token clean and noisy distributions, reduces them to a per-sequence value weighted by the number of attended tokens, and finally averages across the batch.
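The super-batch split and the attention-weighted reduction can be sketched as follows. This is an illustration only: the function name is hypothetical, total variation stands in for all four divergences, and the attention mask is assumed to have shape (batch, sequence) matching one half of the super-batch.

```python
import torch


def weighted_total_variation(
    logits: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
    """Hypothetical sketch: logits is (2 * batch, seq, vocab), mask is (batch, seq)."""
    # Per the description: first half is clean, second half is noisy.
    clean_logits, noisy_logits = logits.chunk(2, dim=0)
    p = clean_logits.softmax(dim=-1)
    q = noisy_logits.softmax(dim=-1)
    # Per-token divergence, shape (batch, seq).
    per_token = 0.5 * (p - q).abs().sum(dim=-1)
    mask = attention_mask.to(per_token.dtype)
    # Per-sequence mean over attended tokens only; clamp avoids division by zero.
    per_sequence = (per_token * mask).sum(dim=-1) / mask.sum(dim=-1).clamp_min(1)
    # Final average across the batch gives a scalar.
    return per_sequence.mean()
```

Padded positions contribute nothing to the result, so changing logits where the attention mask is 0 leaves the metric unchanged.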

The divergences computed are:
  • Jeffreys divergence (stainedglass_core.loss.divergences.jefferys_divergence)
  • Jensen-Shannon divergence (stainedglass_core.loss.divergences.jensen_shannon_divergence)
  • Total variation (stainedglass_core.loss.divergences.total_variation)
  • Squared Hellinger distance (stainedglass_core.loss.divergences.squared_hellinger_distance)
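For reference, the textbook definitions of these four divergences can be written in a few lines of PyTorch. The functions below are hypothetical stand-ins, not the `stainedglass_core.loss.divergences` implementations, and the library's normalization conventions (for example the factor of 1/2 in the squared Hellinger distance) may differ.

```python
import torch


def kl_divergence(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    # torch.xlogy(p, x) is p * log(x), defined as 0 wherever p == 0.
    return (torch.xlogy(p, p) - torch.xlogy(p, q)).sum(dim=-1)


def jeffreys(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    # Symmetrized KL: KL(p || q) + KL(q || p).
    return kl_divergence(p, q) + kl_divergence(q, p)


def jensen_shannon(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    # Average KL to the midpoint distribution m.
    m = 0.5 * (p + q)
    return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)


def total_variation(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    return 0.5 * (p - q).abs().sum(dim=-1)


def squared_hellinger(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    # Assumed convention: 1/2 * sum (sqrt(p) - sqrt(q))^2, bounded by 1.
    return 0.5 * (p.sqrt() - q.sqrt()).pow(2).sum(dim=-1)


p = torch.tensor([0.5, 0.5])
q = torch.tensor([0.25, 0.75])
```

All four are symmetric in `p` and `q` and are zero when the two distributions coincide. The Jeffreys divergence is unbounded, while the other three are bounded under these conventions.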

When num_logits is provided, each divergence is restricted to the union of the top-num_logits indices selected from the noisy and clean logits at every position (see get_top_k_support_mask). When num_logits is None, the divergences are computed over the full vocabulary.

To bound peak memory, the logits and attention mask are split along the batch and sequence-length dimensions via stainedglass_core.utils.torch.split_2d using max_batch_size and max_sequence_length. Per-split divergence values are reassembled and averaged using attention-mask token counts as weights.
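The chunked computation can be illustrated with plain `torch.Tensor.split`. This sketch does not use `stainedglass_core.utils.torch.split_2d`, whose signature is not shown here; the function names are hypothetical and total variation again stands in for the divergence. It accumulates masked sums and token counts per sequence, which is equivalent to a token-count-weighted average of per-chunk means.

```python
import torch


def per_token_tv(clean_logits: torch.Tensor, noisy_logits: torch.Tensor) -> torch.Tensor:
    p = clean_logits.softmax(dim=-1)
    q = noisy_logits.softmax(dim=-1)
    return 0.5 * (p - q).abs().sum(dim=-1)


def chunked_tv(
    clean_logits: torch.Tensor,
    noisy_logits: torch.Tensor,
    attention_mask: torch.Tensor,
    max_batch_size: int,
    max_sequence_length: int,
) -> torch.Tensor:
    """Hypothetical sketch: only one (batch, seq) tile of softmax outputs is alive at a time."""
    per_sequence = []
    batch_chunks = zip(
        clean_logits.split(max_batch_size, dim=0),
        noisy_logits.split(max_batch_size, dim=0),
        attention_mask.split(max_batch_size, dim=0),
    )
    for clean_b, noisy_b, mask_b in batch_chunks:
        weighted_sum = clean_b.new_zeros(clean_b.shape[0])
        token_count = clean_b.new_zeros(clean_b.shape[0])
        seq_chunks = zip(
            clean_b.split(max_sequence_length, dim=1),
            noisy_b.split(max_sequence_length, dim=1),
            mask_b.split(max_sequence_length, dim=1),
        )
        for clean_s, noisy_s, mask_s in seq_chunks:
            mask_f = mask_s.to(clean_s.dtype)
            # Accumulate the masked sum and the number of attended tokens.
            weighted_sum += (per_token_tv(clean_s, noisy_s) * mask_f).sum(dim=-1)
            token_count += mask_f.sum(dim=-1)
        per_sequence.append(weighted_sum / token_count.clamp_min(1))
    return torch.cat(per_sequence).mean()
```

Because the sums and counts are accumulated before dividing, the result matches the unchunked value up to floating-point error regardless of the chunk sizes chosen.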

Parameters:

Name Type Description Default

noisy_model

NoiseMaskedNoisyTransformerModel[CausalModelT, ..., TransformerCloak]

The noisy transformer model used both for hook attachment and as the source of attention_mask (from its forward arguments) and logits (from its base model's output).

required

num_logits

int | None

The k used to form the per-position top-k support mask. If None, divergences are computed over the full vocabulary without any support masking.

None

max_batch_size

int | None

Maximum batch size per chunk when splitting the logits/attention mask for memory-efficient processing. If None, the batch dimension is not split.

None

max_sequence_length

int | None

Maximum sequence length per chunk when splitting the logits/attention mask for memory-efficient processing. If None, the sequence dimension is not split.

None

Returns:

Type Description
Callable[..., dict[str, torch.Tensor]]

A hook-wrapped callable that, when invoked by the training loop, returns a dict mapping each metric name to a scalar tensor holding the attention-mask-weighted mean divergence across the batch. The keys are:
  • next_token_utility/jefferys
  • next_token_utility/jensen_shannon
  • next_token_utility/total_variation
  • next_token_utility/squared_hellinger_distance

Added in version v3.38.0.

percentage_changed_ids

percentage_changed_ids(
    input_ids: Tensor,
    reconstructed_ids: Tensor,
    noise_mask: Tensor,
) -> torch.Tensor

Compute the percentage of token ids that differ between input_ids and reconstructed_ids.

Parameters:

Name Type Description Default

input_ids

Tensor

The original token ids.

required

reconstructed_ids

Tensor

The token ids reconstructed from the transformed embeddings of input_ids.

required

noise_mask

Tensor

The mask that selects the elements of input_ids that were transformed. The percentage changed is only computed over the elements selected by this mask.

required

Returns:

Type Description
torch.Tensor

The percentage of token ids that differ between input_ids and reconstructed_ids, only considering the elements selected by noise_mask.

Examples:

>>> import torch
>>> input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]])
>>> reconstructed_ids = torch.tensor([[1, 2, 3], [1, 2, 6]])
>>> noise_mask = torch.tensor([[True, False, True], [True, True, True]])
>>> percentage_changed_ids(input_ids, reconstructed_ids, noise_mask)
tensor([0.0000, 0.6667])
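
The example output above is consistent with a per-sequence fraction in [0, 1] rather than a value scaled to 100. A minimal sketch that reproduces it, assuming the reduction is over the last dimension (the function name is hypothetical and this is not the library's code):

```python
import torch


def changed_fraction(
    input_ids: torch.Tensor,
    reconstructed_ids: torch.Tensor,
    noise_mask: torch.Tensor,
) -> torch.Tensor:
    """Hypothetical sketch: fraction of masked ids that changed, per sequence."""
    # Count a position only if it differs AND is selected by the mask.
    changed = (input_ids != reconstructed_ids) & noise_mask
    # Integer tensors divide to floats; a row with an empty mask yields nan.
    return changed.sum(dim=-1) / noise_mask.sum(dim=-1)


input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]])
reconstructed_ids = torch.tensor([[1, 2, 3], [1, 2, 6]])
noise_mask = torch.tensor([[True, False, True], [True, True, True]])
fraction = changed_fraction(input_ids, reconstructed_ids, noise_mask)
```

In the first row none of the two masked ids changed, and in the second row two of the three masked ids changed, giving 0 and 2/3.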