metrics
Module for metrics calculation.
Functions:
| Name | Description |
|---|---|
| get_noise_metrics_factory | Create a function that returns norm and cosine similarity metrics of the transform components applied during the most recent forward pass. |
| get_top_k_support_mask | Return a boolean mask for the union of top-k indices from noisy and clean logits. |
| get_utility_metrics_factory | Create a function that returns next-token utility divergence metrics between clean and noisy logits from the latest forward pass. |
| percentage_changed_ids | Compute the percentage of token ids that differ between input_ids and reconstructed_ids. |
get_noise_metrics_factory
get_noise_metrics_factory(
noise_layer: BaseNoiseLayer, dim: int = -1
) -> Callable[[], dict[str, torch.Tensor]]
Create a function that returns norm and cosine similarity metrics of the transform components applied during the most recent forward pass.
The statistics computed include:
- mean_domination_percent: mean_norm / (mean_norm + unshifted_noise_norm)
- mean_domination_ratio: mean_norm / unshifted_noise_norm
- mean_cosine_similarity: cosine_similarity(input, mean)
- noise_domination_percent: mean_shifted_noise_norm / (input_norm + mean_shifted_noise_norm)
- noise_domination_ratio: mean_shifted_noise_norm / input_norm
- noise_cosine_similarity: cosine_similarity(input, mean_shifted_noise)
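The formulas above can be illustrated on small vectors. The following is a minimal sketch in plain Python, not the library's implementation: the helper names `norm`, `cosine_similarity`, and `noise_metrics` are hypothetical, plain lists stand in for tensors, and it assumes that the mean-shifted noise is the sum of the mean and the unshifted noise.

```python
import math


def norm(v):
    """Euclidean norm of a vector given as a list of floats."""
    return math.sqrt(sum(x * x for x in v))


def cosine_similarity(a, b):
    """Cosine of the angle between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (norm(a) * norm(b))


def noise_metrics(input_vec, mean, unshifted_noise):
    # Assumption (not confirmed by the docs): mean-shifted noise = mean + unshifted noise.
    mean_shifted_noise = [m + n for m, n in zip(mean, unshifted_noise)]

    input_norm = norm(input_vec)
    mean_norm = norm(mean)
    unshifted_noise_norm = norm(unshifted_noise)
    mean_shifted_noise_norm = norm(mean_shifted_noise)

    return {
        "mean_domination_percent": mean_norm / (mean_norm + unshifted_noise_norm),
        "mean_domination_ratio": mean_norm / unshifted_noise_norm,
        "mean_cosine_similarity": cosine_similarity(input_vec, mean),
        "noise_domination_percent": mean_shifted_noise_norm
        / (input_norm + mean_shifted_noise_norm),
        "noise_domination_ratio": mean_shifted_noise_norm / input_norm,
        "noise_cosine_similarity": cosine_similarity(input_vec, mean_shifted_noise),
    }


metrics = noise_metrics(
    input_vec=[3.0, 4.0], mean=[0.0, 5.0], unshifted_noise=[5.0, 0.0]
)
```

In this toy case the mean and the unshifted noise have equal norms, so `mean_domination_percent` is 0.5 and `mean_domination_ratio` is 1.0.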
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| noise_layer | BaseNoiseLayer | The | required |
| dim | int | The dimension to compute the norms and cosine similarities over. Must be negative. | -1 |
Returns:
| Type | Description |
|---|---|
| Callable[[], dict[str, torch.Tensor]] | A function that returns the norm and cosine similarity statistics of the transform components applied during the most recent forward pass. |
Raises:
| Type | Description |
|---|---|
| ValueError | If |
Examples:
>>> import torch
>>> from torch import nn
>>> from stainedglass_core import model as sg_model, noise_layer as sg_noise_layer
>>> base_model = nn.Conv2d(3, 4, kernel_size=2)
>>> noisy_model = sg_model.NoisyModel(
... sg_noise_layer.CloakNoiseLayer1,
... base_model,
... target_parameter="input",
... )
>>> get_noise_metrics = get_noise_metrics_factory(noisy_model.noise_layer, dim=-3)
>>> input = torch.ones(4, 3, 2, 2)
>>> noise_mask = torch.tensor([[True, False], [False, True]])
>>> output = noisy_model(input, noise_mask=noise_mask)
>>> noise_metrics = get_noise_metrics()
>>> noise_metrics
{'input_norm': tensor(...), 'mean_norm': tensor(...), 'unshifted_noise_norm': tensor(...), 'mean_shifted_noise_norm': tensor(...), 'mean_domination_percent': tensor(...), 'mean_domination_ratio': tensor(...), 'mean_cosine_similarity': tensor(...), 'noise_domination_percent': tensor(...), 'noise_domination_ratio': tensor(...), 'noise_cosine_similarity': tensor(...)}
>>> {value.shape for value in noise_metrics.values()}
{torch.Size([8])}
Added in version v3.38.0.
get_top_k_support_mask
get_top_k_support_mask(
noisy_logits: Tensor,
clean_logits: Tensor,
num_logits: int,
) -> torch.Tensor
Return a boolean mask for the union of top-k indices from noisy and clean logits.
This function computes the top-k indices (along the vocabulary dimension) from both noisy_logits and clean_logits. It then constructs a support mask that indicates which indices are present in either set of top-k values. The result is a boolean tensor with the same shape as the logits, where True entries correspond to indices included in the union of the supports, and False otherwise.
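To make the union-of-supports idea concrete, here is a minimal sketch for a single token position, using plain Python lists in place of tensors. It is not the library's implementation: `top_k_indices` and `top_k_support_mask` are hypothetical helpers, and the tie-breaking rule (lower index wins) is an assumption chosen only to keep the example deterministic.

```python
def top_k_indices(values, k):
    """Indices of the k largest values.

    Assumption: ties are broken in favor of the lower index.
    """
    order = sorted(range(len(values)), key=lambda i: (-values[i], i))
    return set(order[:k])


def top_k_support_mask(noisy_logits, clean_logits, num_logits):
    """Boolean mask marking indices in the top-k of either logit vector."""
    support = top_k_indices(noisy_logits, num_logits) | top_k_indices(
        clean_logits, num_logits
    )
    return [i in support for i in range(len(noisy_logits))]


noisy = [0.1, 2.0, 0.3, 1.5, -1.0]  # top-2 indices: 1 and 3
clean = [3.0, 0.2, 0.1, 1.0, -0.5]  # top-2 indices: 0 and 3
mask = top_k_support_mask(noisy, clean, num_logits=2)
# mask == [True, True, False, True, False]
```

Because the two top-k sets can overlap, the number of True entries per position lies between num_logits and 2 * num_logits.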
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| noisy_logits | Tensor | The logits produced from SGT. | required |
| clean_logits | Tensor | The logits produced from Base Model excluding SGT. | required |
| num_logits | int | The number of top logits to include in the support from each tensor. | required |
Returns:
| Type | Description |
|---|---|
| torch.Tensor | A boolean tensor of the same shape as noisy_logits where True marks the union of top-k indices from both noisy_logits and clean_logits. |
Added in version v3.38.0.
get_utility_metrics_factory
get_utility_metrics_factory(
noisy_model: NoiseMaskedNoisyTransformerModel[
CausalModelT, ..., TransformerCloak
],
num_logits: int | None = None,
max_batch_size: int | None = None,
max_sequence_length: int | None = None,
) -> Callable[..., dict[str, torch.Tensor]]
Create a function that returns next-token utility divergence metrics between clean and noisy logits from the latest forward pass.
Hooks are attached to noisy_model to capture its attention_mask input and the underlying base model's output.logits. The
factory assumes the noisy model produces a super-batch whose first half corresponds to clean logits and second half to noisy logits.
For every captured forward pass, the returned callable computes four f-divergences between the per-token clean and noisy
distributions, reduces them to a per-sequence value weighted by the number of attended tokens, and finally averages across the
batch.
The divergences computed are:
- Jeffreys divergence (stainedglass_core.loss.divergences.jefferys_divergence)
- Jensen-Shannon divergence (stainedglass_core.loss.divergences.jensen_shannon_divergence)
- Total variation (stainedglass_core.loss.divergences.total_variation)
- Squared Hellinger distance (stainedglass_core.loss.divergences.squared_hellinger_distance)
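For reference, the four divergences can be written out on a pair of discrete distributions. This is a sketch using common textbook definitions in plain Python. The functions in stainedglass_core.loss.divergences may use different normalization constants, logarithm bases, or numerical safeguards, so treat the function bodies below as assumptions rather than the library's code.

```python
import math


def kl(p, q):
    """Kullback-Leibler divergence KL(p || q) in nats.

    Terms with p_i == 0 contribute zero; q_i must be positive wherever p_i > 0.
    """
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def jeffreys_divergence(p, q):
    # Symmetrized KL: KL(p || q) + KL(q || p). Some texts include a factor of 1/2.
    return kl(p, q) + kl(q, p)


def jensen_shannon_divergence(p, q):
    # Average KL of each distribution to their midpoint mixture.
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


def total_variation(p, q):
    # Half the L1 distance between the distributions.
    return 0.5 * sum(abs(pi - qi) for pi, qi in zip(p, q))


def squared_hellinger_distance(p, q):
    # Half the squared L2 distance between the square-root distributions.
    return 0.5 * sum((math.sqrt(pi) - math.sqrt(qi)) ** 2 for pi, qi in zip(p, q))


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    shift = max(logits)
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


clean = softmax([2.0, 1.0, 0.0])
noisy = softmax([1.5, 1.2, 0.3])

divergences = {
    "jeffreys": jeffreys_divergence(clean, noisy),
    "jensen_shannon": jensen_shannon_divergence(clean, noisy),
    "total_variation": total_variation(clean, noisy),
    "squared_hellinger": squared_hellinger_distance(clean, noisy),
}
```

Under these definitions all four quantities are symmetric in their arguments and vanish when the clean and noisy distributions coincide. Total variation and squared Hellinger distance are bounded by 1, and the Jensen-Shannon divergence by log 2.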
When num_logits is provided, each divergence is restricted to the union of the top-num_logits indices selected from the noisy
and clean logits at every position (see get_top_k_support_mask). When num_logits is None, the divergences are computed over
the full vocabulary.
To bound peak memory, the logits and attention mask are split along the batch and sequence-length dimensions via
stainedglass_core.utils.torch.split_2d using max_batch_size and max_sequence_length. Per-split divergence values are
reassembled and averaged using attention-mask token counts as weights.
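The role of the token-count weights can be seen in a small sketch. It assumes one reading of the reduction described above: a per-sequence mean over attended tokens, followed by an unweighted mean over the batch. Only splitting along the sequence dimension is shown, the helper names are hypothetical, and this does not use stainedglass_core.utils.torch.split_2d.

```python
def masked_mean(values, mask):
    """Mean of the values whose mask entry is 1, plus the attended-token count."""
    count = sum(mask)
    total = sum(v for v, m in zip(values, mask) if m)
    return (total / count if count else 0.0), count


def chunked_sequence_mean(values, mask, max_sequence_length):
    """Recombine per-chunk masked means, weighting each by its token count."""
    weighted_sum, total_count = 0.0, 0
    for start in range(0, len(values), max_sequence_length):
        stop = start + max_sequence_length
        chunk_mean, count = masked_mean(values[start:stop], mask[start:stop])
        weighted_sum += chunk_mean * count
        total_count += count
    return weighted_sum / total_count if total_count else 0.0


def batch_metric(batch_values, batch_mask, max_sequence_length):
    """Average the per-sequence values across the batch."""
    per_sequence = [
        chunked_sequence_mean(values, mask, max_sequence_length)
        for values, mask in zip(batch_values, batch_mask)
    ]
    return sum(per_sequence) / len(per_sequence)


# Per-token divergence values; padded positions (mask 0) hold junk values.
batch_values = [[0.2, 0.4, 0.6, 9.9], [1.0, 3.0, 9.9, 9.9]]
batch_mask = [[1, 1, 1, 0], [1, 1, 0, 0]]

result = batch_metric(batch_values, batch_mask, max_sequence_length=3)
```

Weighting each chunk by its attended-token count makes the result independent of the chunk size: the same value comes out for max_sequence_length of 1, 2, 3, or 4, and padded positions never contribute.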
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| noisy_model | NoiseMaskedNoisyTransformerModel[CausalModelT, ..., TransformerCloak] | The noisy transformer model used both for hook attachment and as the source of | required |
| num_logits | int \| None | The | None |
| max_batch_size | int \| None | Maximum batch size per chunk when splitting the logits/attention mask for memory-efficient processing. If | None |
| max_sequence_length | int \| None | Maximum sequence length per chunk when splitting the logits/attention mask for memory-efficient processing. If | None |
Returns:
| Type | Description |
|---|---|
| Callable[..., dict[str, torch.Tensor]] | A hook-wrapped callable that, when invoked by the training loop, returns a dict mapping each metric name (for example "next_token_utility/jefferys") to a scalar tensor: the attention-mask-weighted mean divergence across the batch. |
Added in version v3.38.0.
percentage_changed_ids
percentage_changed_ids(
input_ids: Tensor,
reconstructed_ids: Tensor,
noise_mask: Tensor,
) -> torch.Tensor
Compute the percentage of token ids that differ between input_ids and reconstructed_ids.
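The computation can be sketched on plain Python lists. This is an illustration only, resting on several assumptions that the documentation above does not confirm: that noise_mask selects which positions are compared, that the result is on a 0 to 100 scale, and that an empty selection yields 0. The helper name `percentage_changed` is hypothetical.

```python
def percentage_changed(input_ids, reconstructed_ids, noise_mask):
    """Percentage of masked-in positions whose token id changed.

    Assumptions: noise_mask selects the positions to compare, the result is
    scaled to 0-100, and an empty selection returns 0.0.
    """
    selected = [
        (original, reconstructed)
        for original, reconstructed, keep in zip(
            input_ids, reconstructed_ids, noise_mask
        )
        if keep
    ]
    if not selected:
        return 0.0
    changed = sum(original != reconstructed for original, reconstructed in selected)
    return 100.0 * changed / len(selected)


input_ids = [5, 7, 9, 11]
reconstructed_ids = [5, 8, 9, 12]
noise_mask = [True, True, True, False]

percent = percentage_changed(input_ids, reconstructed_ids, noise_mask)
```

Here one of the three selected positions differs, so the sketch yields roughly 33.3. The changed id at the final position is ignored because the mask excludes it.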
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input_ids | Tensor | The original token ids. | required |
| reconstructed_ids | Tensor | The token ids reconstructed from the transformed embeddings of | required |
| noise_mask | Tensor | The mask that selects the elements of | required |
Returns:
| Type | Description |
|---|---|
| torch.Tensor | The percentage of token ids that differ between input_ids and reconstructed_ids. |
Examples: