
parameterizations

Module for noise parameterizations.

Classes:

| Name | Description |
| --- | --- |
| BoundedMeanParameterization | Bound a cloak's spatial mean field to [-bound, +bound] via tanh. |
| CloakStandardDeviationParameterization | Parameterizes rhos tensors (on the domain of all real numbers) as standard deviation tensors (on the open domain of min_scale to max_scale). |
| DirectStandardDeviationParameterization | A direct parameterization of rhos tensors as standard deviation tensors (clamped into the closed domain of min_scale to max_scale). |
| ScaledStandardDeviationParameterization | Defines the common structures necessary to parameterize rhos tensors (on the domain of all real numbers) as standard deviation tensors (on the domain of min_scale to max_scale). |
| StandardDeviationParameterization | Defines the interface for the reparameterization of rhos tensors (on the domain of all real numbers) as standard deviation tensors (on the domain of nonnegative real numbers) of the applied transformation. |

BoundedMeanParameterization

Bases: Module

Bound a cloak's spatial mean field to [-bound, +bound] via tanh.

Computes bound * tanh((rhos + init_shift) / bound). With init_shift = 0.0, the output is centered at zero when rhos = 0, matching the typical near-zero initialization of a prediction head's raw output. init_shift is provided so that callers whose upstream estimator biases rhos away from zero can recenter the output at initialization without coupling the parameterization to any specific estimator convention.

The map rhos -> bound * tanh((rhos + init_shift) / bound) is smooth and monotonically increasing. Its gradient, 1 - tanh^2((rhos + init_shift) / bound), saturates toward zero as |rhos + init_shift| grows large (consistent with tanh), so far from rhos = -init_shift the parameterization becomes nearly flat. Mathematically the image is the open interval (-bound, +bound); in finite precision, however, torch.tanh can return exactly +/-1, so at extreme inputs the output may equal the bounds rather than only approach them.
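A minimal scalar sketch of the map above in plain Python, assuming only the documented expression bound * tanh((rhos + init_shift) / bound). `bounded_mean` and `bounded_mean_grad` are hypothetical helper names; the library itself applies the same map elementwise to torch tensors. The sketch checks the centering at zero, the init_shift recentering, the gradient saturation, and the finite-precision attainment of the bound (here with float64 `math.tanh` rather than `torch.tanh`).

```python
import math


def bounded_mean(rho: float, bound: float, init_shift: float = 0.0) -> float:
    """Hypothetical scalar version of bound * tanh((rho + init_shift) / bound)."""
    if bound <= 0:
        raise ValueError("bound must be positive")
    return bound * math.tanh((rho + init_shift) / bound)


def bounded_mean_grad(rho: float, bound: float, init_shift: float = 0.0) -> float:
    """d/d(rho) of bounded_mean, which simplifies to 1 - tanh^2((rho + init_shift) / bound)."""
    t = math.tanh((rho + init_shift) / bound)
    return 1.0 - t * t


# Centered at zero when rho = 0 and init_shift = 0.
print(bounded_mean(0.0, bound=2.0))                   # 0.0
# init_shift recenters a biased input: forward(-init_shift) == 0.
print(bounded_mean(-0.5, bound=2.0, init_shift=0.5))  # 0.0
# Near the center the map has slope 1...
print(bounded_mean_grad(0.0, bound=2.0))              # 1.0
# ...while far from it the gradient saturates toward zero.
print(bounded_mean_grad(20.0, bound=2.0) < 1e-7)      # True
# tanh rounds to exactly 1.0 for large inputs, so the bound itself is reached.
print(bounded_mean(100.0, bound=2.0))                 # 2.0
```

Dividing by bound inside the tanh keeps the slope at the center equal to 1, so for small inputs the parameterization behaves like the identity and only compresses values as they approach the bound.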

Why bound the mean at all? Without a bound, the mean estimator's raw output is unconstrained. In the LLM's RMSNorm-saturated regime the gradient on the mean magnitude can vanish, and the estimator then drifts without limit under bf16 + AdamW noise. The tanh bound restores a hard structural limit on the mean magnitude.

Mirrors CloakStandardDeviationParameterization on the mean side: a tanh-based parameterization with a structural bound on the output.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| bound | float | Magnitude bound on the parameterized mean field. The output lies in [-bound, +bound]. Must be positive. | required |
| init_shift | float | A scalar added to rhos inside forward. Used to cancel an upstream estimator's negative bias so that the output at initialization is near zero (forward(-init_shift) = 0 exactly). Defaults to 0.0. | 0.0 |

Added in version v3.41.0. A tanh-based mean-side parameterization mirroring `CloakStandardDeviationParameterization` on the std side. Smoothly bounds the cloak's spatial mean field within `(-bound, +bound)` to prevent unbounded drift under bf16 + RMSNorm-saturated gradients.

Methods:

| Name | Description |
| --- | --- |
| forward | Apply bound * tanh((rhos + init_shift) / bound). |
| reset_parameters | Reinitialize parameters and buffers. |

forward

forward(rhos: Tensor) -> torch.Tensor

Apply bound * tanh((rhos + init_shift) / bound).

Parameters:

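| Name | Type | Description | Default |
| --- | --- | --- | --- |
| rhos | Tensor | A rhos tensor. | required |

Returns:

| Type | Description |
| --- | --- |
| torch.Tensor | A bounded mean tensor of the same shape as rhos, with values in [-bound, +bound]; the bounds themselves are reached only where tanh saturates in finite precision. |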

reset_parameters

reset_parameters() -> None

Reinitialize parameters and buffers.

This parameterization has no learned parameters or buffers, so this is a no-op. It exists for API compatibility with StandardDeviationParameterization.

CloakStandardDeviationParameterization

Bases: ScaledStandardDeviationParameterization

Parameterizes rhos tensors (on the domain of all real numbers) as standard deviation tensors (on the open domain of min_scale to max_scale).

min_scale must be strictly less than max_scale, and both must be nonnegative real numbers.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| scale | tuple[float, float] | The asymptotic minimum and maximum values of the parameterized standard deviations. | (0.0001, 2.0) |
| shallow | float | A temperature-like parameter that controls the spread of the parameterization: both the magnitude of the parameterized standard deviations and their rate of change with respect to rhos. shallow > 1.0 yields a more gradual rate of change, with higher standard deviations for rhos < 0.0 (typical at initialization) and lower standard deviations for rhos > 0.0. shallow < 1.0 yields the opposite, approaching a step function as shallow goes to 0.0 (not trainable). | 1.0 |
| init_shift | float | A scalar added to rhos inside forward before the tanh. Use it when the upstream estimator biases its raw output away from zero and the parameterization should still map that biased input to the point on its sigmoid-shaped curve that an unbiased input of zero would reach. Specifically, forward(-init_shift) produces the midpoint (min_scale + max_scale) / 2. Defaults to 0.0, which gives the historical formula (1 + tanh(rhos / shallow)) / 2 * (max_scale - min_scale) + min_scale. | 0.0 |
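The effect of shallow and init_shift can be checked numerically with a scalar sketch in plain Python. It assumes the forward map is (1 + tanh((rhos + init_shift) / shallow)) / 2 * (max_scale - min_scale) + min_scale, i.e. that init_shift is added before dividing by shallow, which is the placement under which forward(-init_shift) lands exactly on the midpoint. `cloak_std` is a hypothetical name, not the library's method.

```python
import math


def cloak_std(
    rho: float,
    scale: tuple[float, float] = (0.0001, 2.0),
    shallow: float = 1.0,
    init_shift: float = 0.0,
) -> float:
    """Hypothetical scalar Cloak map (assumed form, see lead-in).

    std = (1 + tanh((rho + init_shift) / shallow)) / 2 * (max - min) + min
    """
    min_scale, max_scale = scale
    t = math.tanh((rho + init_shift) / shallow)
    return (1.0 + t) / 2.0 * (max_scale - min_scale) + min_scale


# shallow > 1.0: higher std for rho < 0.0 and lower std for rho > 0.0.
print(cloak_std(-1.0, shallow=4.0) > cloak_std(-1.0, shallow=1.0))  # True
print(cloak_std(1.0, shallow=4.0) < cloak_std(1.0, shallow=1.0))    # True

# As shallow -> 0.0 the map approaches a step: just below and just above
# zero the output sits near min_scale and max_scale respectively.
print(cloak_std(-0.1, shallow=0.01), cloak_std(0.1, shallow=0.01))

# A biased input of -init_shift maps to the midpoint of the interval.
midpoint = (0.0001 + 2.0) / 2
print(math.isclose(cloak_std(-0.5, init_shift=0.5), midpoint))  # True
```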

Methods:

| Name | Description |
| --- | --- |
| `__init__` | |
| forward | Apply the parameterization to a rhos tensor. |
| inverse | Apply the inverse parameterization to a standard deviation tensor. |
| reset_parameters | Reinitialize parameters and buffers. |

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| max_scale | Tensor | The upper bound of the standard deviations. |
| min_scale | Tensor | The lower bound of the standard deviations. |

max_scale property

max_scale: Tensor

The upper bound of the standard deviations.

min_scale property

min_scale: Tensor

The lower bound of the standard deviations.

__init__

__init__(
    scale: tuple[float, float] = (0.0001, 2.0),
    shallow: float = 1.0,
    init_shift: float = 0.0,
) -> None

Changed in version v3.42.0: Added the `init_shift` constructor argument so the parameterization can absorb an upstream estimator's input bias.

forward

forward(rhos: Tensor) -> torch.Tensor

Apply the parameterization to a rhos tensor.

With a nonzero init_shift, the input is shifted by init_shift before the tanh, so that forward(-init_shift) lands at the midpoint of the output interval.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| rhos | Tensor | A rhos tensor. | required |

Returns:

| Type | Description |
| --- | --- |
| torch.Tensor | A standard deviation tensor. |

inverse

inverse(std: Tensor) -> torch.Tensor

Apply the inverse parameterization to a standard deviation tensor.

With a nonzero init_shift, init_shift is subtracted from the result so that the round trip inverse(forward(rhos)) recovers rhos, up to floating-point error.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| std | Tensor | A standard deviation tensor. | required |

Returns:

| Type | Description |
| --- | --- |
| torch.Tensor | A rhos tensor. |
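Solving the forward formula for rhos gives the inverse: rescale std from (min_scale, max_scale) back to (-1, 1), apply atanh, multiply by shallow, and subtract init_shift. The plain-Python sketch below assumes the forward map (1 + tanh((rhos + init_shift) / shallow)) / 2 * (max_scale - min_scale) + min_scale; `cloak_std` and `cloak_std_inverse` are hypothetical names used for illustration.

```python
import math


def cloak_std(rho, scale=(0.0001, 2.0), shallow=1.0, init_shift=0.0):
    """Assumed forward map: (1 + tanh((rho + init_shift) / shallow)) / 2 * (max - min) + min."""
    min_scale, max_scale = scale
    t = math.tanh((rho + init_shift) / shallow)
    return (1.0 + t) / 2.0 * (max_scale - min_scale) + min_scale


def cloak_std_inverse(std, scale=(0.0001, 2.0), shallow=1.0, init_shift=0.0):
    """Solve the forward map for rho, then subtract init_shift."""
    min_scale, max_scale = scale
    # Rescale std from (min_scale, max_scale) back to tanh's range (-1, 1).
    u = 2.0 * (std - min_scale) / (max_scale - min_scale) - 1.0
    # math.atanh raises ValueError unless -1 < u < 1, i.e. for std at or
    # beyond the asymptotic bounds, which the forward map never reaches exactly.
    return shallow * math.atanh(u) - init_shift


for rho in (-2.0, -0.3, 0.0, 1.7):
    std = cloak_std(rho, shallow=2.0, init_shift=0.5)
    recovered = cloak_std_inverse(std, shallow=2.0, init_shift=0.5)
    print(rho, std, math.isclose(recovered, rho, abs_tol=1e-9))
```

The round trip is only approximate in floating point, and it degrades as std approaches either bound, where atanh is steep.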

reset_parameters

reset_parameters() -> None

Reinitialize parameters and buffers.

This method is useful for initializing tensors created on the meta device.

Raises:

| Type | Description |
| --- | --- |
| ValueError | If shallow is not a scalar. |
| ValueError | If shallow is nonpositive. |
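The two documented ValueError conditions on shallow can be expressed as a small check. This is a hypothetical plain-Python sketch: `validate_shallow` is an illustrative name, and "scalar" is approximated as "a single real number", whereas the library works with tensors and presumably checks their shape instead.

```python
from numbers import Real


def validate_shallow(shallow: object) -> float:
    """Hypothetical check mirroring the documented ValueErrors for shallow."""
    # "Not a scalar": anything other than a single real number (bool excluded,
    # since bool is a subclass of int in Python).
    if isinstance(shallow, bool) or not isinstance(shallow, Real):
        raise ValueError(f"shallow must be a scalar, got {shallow!r}")
    # "Nonpositive": written as `not shallow > 0` so that NaN is rejected too.
    if not shallow > 0:
        raise ValueError(f"shallow must be positive, got {shallow!r}")
    return float(shallow)


print(validate_shallow(0.5))
for bad in ([1.0, 2.0], 0.0, -1.0):
    try:
        validate_shallow(bad)
    except ValueError as err:
        print(err)
```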

DirectStandardDeviationParameterization

Bases: ScaledStandardDeviationParameterization

A direct parameterization of rhos tensors as standard deviation tensors (clamped into the closed domain of min_scale to max_scale).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| scale | tuple[float, float] | The minimum and maximum values of the parameterized standard deviations. | (0.0001, 2.0) |
| init_shift | float | A scalar added to rhos inside forward before the absolute value and clamp. Use this when the upstream estimator biases its raw output away from zero (e.g., subtracts a constant before forwarding to the parameterization) and direct-std learning should still initialize near min_scale instead of near the absolute value of init_shift. Defaults to 0.0, which gives the historical behavior. | 0.0 |

Methods:

| Name | Description |
| --- | --- |
| `__init__` | |
| forward | Clamp a rhos tensor into the closed domain of min_scale to max_scale. |
| inverse | The direct standard deviation parameterization is not bijective and so is non-invertible. |
| reset_parameters | Reinitialize parameters and buffers. |

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| max_scale | Tensor | The upper bound of the standard deviations. |
| min_scale | Tensor | The lower bound of the standard deviations. |

max_scale property

max_scale: Tensor

The upper bound of the standard deviations.

min_scale property

min_scale: Tensor

The lower bound of the standard deviations.

__init__

__init__(
    scale: tuple[float, float] = (0.0001, 2.0),
    init_shift: float = 0.0,
) -> None

Changed in version v3.41.0: Added the `init_shift` constructor argument for use with estimators that bias their raw output away from zero.

forward

forward(rhos: Tensor) -> torch.Tensor

Clamp a rhos tensor into the closed domain of min_scale to max_scale.

With a nonzero init_shift, the input is first shifted by init_shift before the absolute-value and clamp.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| rhos | Tensor | A rhos tensor. | required |

Returns:

| Type | Description |
| --- | --- |
| torch.Tensor | A standard deviation tensor. |
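A scalar sketch of the direct forward map in plain Python, assuming the order described above: add init_shift, take the absolute value, then clamp into [min_scale, max_scale]. `direct_std` is a hypothetical name, and the estimator that subtracts 1.0 is likewise an illustrative assumption.

```python
def direct_std(rho: float, scale: tuple[float, float] = (0.0001, 2.0), init_shift: float = 0.0) -> float:
    """Shift, take the absolute value, then clamp into [min_scale, max_scale]."""
    min_scale, max_scale = scale
    return min(max(abs(rho + init_shift), min_scale), max_scale)


# Suppose a (hypothetical) estimator subtracts 1.0, so its raw output at
# initialization is about -1.0.
print(direct_std(-1.0))                   # 1.0 (no shift: std starts near 1.0)
print(direct_std(-1.0, init_shift=1.0))   # 0.0001 (shifted to 0.0, clamped up to min_scale)
print(direct_std(5.0), direct_std(-5.0))  # 2.0 2.0 (clamped down to max_scale)
```

With the matching init_shift, the shifted input starts at zero and the clamp places the initial standard deviation at min_scale, which is the behavior the init_shift parameter description asks for.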

inverse

inverse(std: Tensor) -> NoReturn

The direct standard deviation parameterization is not bijective and so is non-invertible.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| std | Tensor | A standard deviation tensor. | required |

Raises:

| Type | Description |
| --- | --- |
| TypeError | If called. |
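Why no inverse can exist is easy to see numerically: the absolute value and both clamps map distinct rhos onto the same standard deviation. The plain-Python sketch below reuses the abs-then-clamp form described for forward; `direct_std` is a hypothetical name.

```python
def direct_std(rho, scale=(0.0001, 2.0), init_shift=0.0):
    """Shift, take the absolute value, then clamp into [min_scale, max_scale]."""
    min_scale, max_scale = scale
    return min(max(abs(rho + init_shift), min_scale), max_scale)


# Three ways distinct rhos collapse onto the same standard deviation:
print(direct_std(0.5) == direct_std(-0.5))     # True: abs() folds the sign
print(direct_std(3.0) == direct_std(10.0))     # True: both clamp to max_scale
print(direct_std(0.0) == direct_std(0.00005))  # True: both clamp to min_scale
```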

reset_parameters

reset_parameters() -> None

Reinitialize parameters and buffers.

This method is useful for initializing tensors created on the meta device.

Raises:

| Type | Description |
| --- | --- |
| ValueError | If scale is not a 2-vector. |
| ValueError | If min_scale is negative. |
| ValueError | If min_scale is not strictly less than max_scale. |

ScaledStandardDeviationParameterization

Bases: StandardDeviationParameterization

Defines the common structures necessary to parameterize rhos tensors (on the domain of all real numbers) as standard deviation tensors (on the domain of min_scale to max_scale).

min_scale must be strictly less than max_scale, and both must be nonnegative real numbers.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| scale | tuple[float, float] | The minimum and maximum values of the parameterized standard deviations. | (0.0001, 2.0) |

Methods:

| Name | Description |
| --- | --- |
| forward | Apply the parameterization to a rhos tensor. |
| inverse | Apply the inverse parameterization to a standard deviation tensor. |
| reset_parameters | Reinitialize parameters and buffers. |

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| max_scale | Tensor | The upper bound of the standard deviations. |
| min_scale | Tensor | The lower bound of the standard deviations. |

max_scale property

max_scale: Tensor

The upper bound of the standard deviations.

min_scale property

min_scale: Tensor

The lower bound of the standard deviations.

forward abstractmethod

forward(rhos: Tensor) -> torch.Tensor

Apply the parameterization to a rhos tensor.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| rhos | Tensor | A rhos tensor. | required |

Returns:

| Type | Description |
| --- | --- |
| torch.Tensor | A standard deviation tensor. |

inverse abstractmethod

inverse(std: Tensor) -> torch.Tensor

Apply the inverse parameterization to a standard deviation tensor.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| std | Tensor | A standard deviation tensor. | required |

Returns:

| Type | Description |
| --- | --- |
| torch.Tensor | A rhos tensor. |

reset_parameters

reset_parameters() -> None

Reinitialize parameters and buffers.

This method is useful for initializing tensors created on the meta device.

Raises:

| Type | Description |
| --- | --- |
| ValueError | If scale is not a 2-vector. |
| ValueError | If min_scale is negative. |
| ValueError | If min_scale is not strictly less than max_scale. |
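The three documented ValueError conditions on scale can be sketched as a plain-Python check. `validate_scale` is a hypothetical name and works on a Python sequence rather than the library's tensors. Note that the "both nonnegative" requirement needs only two checks: once min_scale >= 0 and min_scale < max_scale hold, max_scale is positive automatically.

```python
from collections.abc import Sequence


def validate_scale(scale: Sequence[float]) -> tuple[float, float]:
    """Hypothetical check mirroring the documented ValueErrors for scale."""
    if len(scale) != 2:
        raise ValueError(f"scale must be a 2-vector, got {len(scale)} elements")
    min_scale, max_scale = float(scale[0]), float(scale[1])
    if min_scale < 0.0:
        raise ValueError(f"min_scale must be nonnegative, got {min_scale}")
    if not min_scale < max_scale:
        raise ValueError(
            f"min_scale must be strictly less than max_scale, got ({min_scale}, {max_scale})"
        )
    # max_scale > min_scale >= 0, so max_scale is nonnegative as well.
    return min_scale, max_scale


print(validate_scale((0.0001, 2.0)))  # (0.0001, 2.0)
for bad in ((1.0,), (-0.1, 1.0), (2.0, 2.0), (3.0, 1.0)):
    try:
        validate_scale(bad)
    except ValueError as err:
        print(err)
```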

StandardDeviationParameterization

Bases: Module, ABC

Defines the interface for the reparameterization of rhos tensors (on the domain of all real numbers) as standard deviation tensors (on the domain of nonnegative real numbers) of the applied transformation.

Rhos can be learned directly or estimated as the output of a neural network.

The derivative of this parameterization defines the rate of change of the standard deviations with respect to rhos. Combined with a so-called "noise loss", which penalizes the standard deviations for deviating from the target distribution, and a task loss (e.g., classification or token prediction), it defines a complete training objective and loss landscape for transform layers.
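The shape of this interface can be illustrated with a framework-free sketch: an abstract base with forward, inverse, and reset_parameters, plus one concrete subclass that also exposes its derivative. This is a hypothetical plain-Python mirror, not the library's torch.nn.Module-based class, and `SoftplusStd` uses softplus, which is a swapped-in example and not one of the parameterizations documented here.

```python
import math
from abc import ABC, abstractmethod


class StdParameterizationSketch(ABC):
    """Hypothetical, framework-free mirror of the documented interface."""

    @abstractmethod
    def forward(self, rho: float) -> float:
        """Map an unconstrained rho to a nonnegative standard deviation."""

    @abstractmethod
    def inverse(self, std: float) -> float:
        """Map a standard deviation back to rho."""

    @abstractmethod
    def reset_parameters(self) -> None:
        """Reinitialize parameters and buffers."""


class SoftplusStd(StdParameterizationSketch):
    """Hypothetical subclass: std = softplus(rho) = log(1 + exp(rho))."""

    def forward(self, rho: float) -> float:
        # Numerically stable softplus: avoids overflow in exp for large rho.
        return max(rho, 0.0) + math.log1p(math.exp(-abs(rho)))

    def inverse(self, std: float) -> float:
        # softplus^-1(s) = log(exp(s) - 1), defined for s > 0.
        return math.log(math.expm1(std))

    def reset_parameters(self) -> None:
        pass  # No parameters or buffers to reinitialize.

    def derivative(self, rho: float) -> float:
        # d softplus / d rho = sigmoid(rho), computed without overflow.
        if rho >= 0.0:
            return 1.0 / (1.0 + math.exp(-rho))
        e = math.exp(rho)
        return e / (1.0 + e)


p = SoftplusStd()
std = p.forward(0.0)
print(std, p.inverse(std), p.derivative(0.0))
```

The derivative is what a gradient-based optimizer sees when the noise loss and task loss are backpropagated through the parameterization: where it is close to zero, changes in rho barely move the standard deviation.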

Methods:

| Name | Description |
| --- | --- |
| forward | Apply the parameterization to a rhos tensor. |
| inverse | Apply the inverse parameterization to a standard deviation tensor. |
| reset_parameters | Reinitialize parameters and buffers. |

forward abstractmethod

forward(rhos: Tensor) -> torch.Tensor

Apply the parameterization to a rhos tensor.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| rhos | Tensor | A rhos tensor. | required |

Returns:

| Type | Description |
| --- | --- |
| torch.Tensor | A standard deviation tensor. |

inverse abstractmethod

inverse(std: Tensor) -> torch.Tensor

Apply the inverse parameterization to a standard deviation tensor.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| std | Tensor | A standard deviation tensor. | required |

Returns:

| Type | Description |
| --- | --- |
| torch.Tensor | A rhos tensor. |

reset_parameters abstractmethod

reset_parameters() -> None

Reinitialize parameters and buffers.

This method is useful for initializing tensors created on the meta device.