
parameterizations

CloakStandardDeviationParameterization

Bases: ScaledStandardDeviationParameterization

A parameterization of rhos tensors (on the domain of all real numbers) as standard deviation tensors (on the open interval from min_scale to max_scale).

min_scale must be strictly less than max_scale, and both must be nonnegative real numbers.

Added in version 0.11.0.
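As a rough illustration, a bounded mapping of this kind can be sketched with a scaled hyperbolic tangent. This sketch assumes the form `std = min_scale + (max_scale - min_scale) * (tanh(rhos / shallow) + 1) / 2`, which matches the description (asymptotic bounds, open interval, temperature-like `shallow`) but is not confirmed to be the library's exact formula. It uses plain Python floats rather than `torch.Tensor`, and the function name `cloak_std` is hypothetical.

```python
import math


def cloak_std(rho: float, min_scale: float = 0.0001, max_scale: float = 2.0, shallow: float = 1.0) -> float:
    """Hypothetical scalar sketch: map a real-valued rho into (min_scale, max_scale).

    Assumed form (not taken from the library source):
    min_scale + (max_scale - min_scale) * (tanh(rho / shallow) + 1) / 2
    """
    # tanh(rho / shallow) lies in (-1, 1), so this lies in (0, 1).
    unit = (math.tanh(rho / shallow) + 1.0) / 2.0
    # Rescale the unit interval onto (min_scale, max_scale).
    return min_scale + (max_scale - min_scale) * unit


for rho in (-5.0, -1.0, 0.0, 1.0, 5.0):
    print(f"rho={rho:+.1f} -> std={cloak_std(rho):.6f}")
```

Under this assumed form, rho = 0.0 maps to the midpoint of the two bounds. Note that with floating point, very large |rho| can saturate tanh to exactly ±1, so the bounds are only asymptotic in exact arithmetic.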

max_scale property

max_scale: Tensor

The upper bound of the standard deviations.

min_scale property

min_scale: Tensor

The lower bound of the standard deviations.

__init__

__init__(scale: tuple[float, float] | Tensor = (0.0001, 2.0), shallow: float | Tensor = 1.0) -> None

Construct a layer to perform the reparameterization.

Parameters:

scale (tuple[float, float] | Tensor, default (0.0001, 2.0)): The asymptotic minimum and maximum values of the parameterized standard deviations.

shallow (float | Tensor, default 1.0): A temperature-like parameter that controls the spread of the parameterization, affecting both the magnitude of the parameterized standard deviations and their rate of change with respect to rhos. shallow > 1.0 yields a more gradual rate of change, with higher standard deviations for rhos < 0.0 (typical at initialization) and lower standard deviations for rhos > 0.0. shallow < 1.0 yields the opposite; as shallow goes to 0.0, the parameterization approaches a step function, which is not trainable.
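The effect of shallow described above can be checked numerically under the same assumed tanh form. This is a hypothetical scalar sketch (the names `cloak_std` and `cloak_slope` are illustrative, not library functions): a larger shallow should raise the standard deviation for negative rhos, lower it for positive rhos, and flatten the slope at rho = 0.

```python
import math


def cloak_std(rho, min_scale=0.0001, max_scale=2.0, shallow=1.0):
    # Assumed form: min_scale + (max_scale - min_scale) * (tanh(rho / shallow) + 1) / 2
    return min_scale + (max_scale - min_scale) * (math.tanh(rho / shallow) + 1.0) / 2.0


def cloak_slope(rho, min_scale=0.0001, max_scale=2.0, shallow=1.0):
    # Analytic derivative of the assumed form with respect to rho:
    # (max_scale - min_scale) / (2 * shallow) * (1 - tanh(rho / shallow) ** 2)
    t = math.tanh(rho / shallow)
    return (max_scale - min_scale) / (2.0 * shallow) * (1.0 - t * t)


for shallow in (0.5, 1.0, 2.0):
    print(
        f"shallow={shallow}: std(-1)={cloak_std(-1.0, shallow=shallow):.4f}, "
        f"std(+1)={cloak_std(1.0, shallow=shallow):.4f}, "
        f"slope(0)={cloak_slope(0.0, shallow=shallow):.4f}"
    )
```

The slope at rho = 0 scales as 1 / shallow under this form, which is one way to read "temperature-like": small shallow sharpens the transition toward a step, large shallow spreads it out.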

Raises:

ValueError: If shallow is not a scalar.
ValueError: If shallow is nonpositive.

forward

forward(rhos: Tensor) -> torch.Tensor

Apply the parameterization to a rhos tensor.

Parameters:

rhos (Tensor, required): A rhos tensor.

Returns:

torch.Tensor: A standard deviation tensor.

inverse

inverse(std: Tensor) -> torch.Tensor

Apply the inverse parameterization to a standard deviation tensor.

Parameters:

std (Tensor, required): A standard deviation tensor.

Returns:

torch.Tensor: A rhos tensor.
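Under the same assumed tanh form (not confirmed against the library source), the inverse follows by solving for rho: `rho = shallow * atanh(2 * (std - min_scale) / (max_scale - min_scale) - 1)`. It is only defined for std strictly inside (min_scale, max_scale), which matches the open interval stated for this class. A hypothetical scalar sketch with a round-trip check, where `cloak_std` and `cloak_rho` are illustrative names:

```python
import math


def cloak_std(rho, min_scale=0.0001, max_scale=2.0, shallow=1.0):
    # Assumed forward form (not confirmed against the library source).
    return min_scale + (max_scale - min_scale) * (math.tanh(rho / shallow) + 1.0) / 2.0


def cloak_rho(std, min_scale=0.0001, max_scale=2.0, shallow=1.0):
    # Inverse of the assumed forward form; std must lie strictly between the bounds.
    if not min_scale < std < max_scale:
        raise ValueError("std must lie in the open interval (min_scale, max_scale)")
    # Map std back onto (-1, 1), then undo tanh and the temperature scaling.
    return shallow * math.atanh(2.0 * (std - min_scale) / (max_scale - min_scale) - 1.0)


for rho in (-3.0, -0.5, 0.0, 0.5, 3.0):
    recovered = cloak_rho(cloak_std(rho, shallow=2.0), shallow=2.0)
    print(f"rho={rho:+.2f} recovered={recovered:+.6f}")
```

Because the forward map is strictly increasing and continuous, it is a bijection from the real line onto the open interval, which is what makes this inverse well defined.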

DirectStandardDeviationParameterization

Bases: ScaledStandardDeviationParameterization

A direct parameterization of rhos tensors as standard deviation tensors (clamped into the closed interval from min_scale to max_scale).
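The clamping behavior can be sketched in a few lines. This is a hypothetical scalar version on plain floats (`direct_std` is an illustrative name); on tensors, `torch.clamp(input, min=..., max=...)` performs the same element-wise operation, though whether the library uses that exact call is an assumption.

```python
def direct_std(rho: float, min_scale: float = 0.0001, max_scale: float = 2.0) -> float:
    """Hypothetical scalar sketch: clamp rho into the closed interval [min_scale, max_scale]."""
    return min(max(rho, min_scale), max_scale)


for rho in (-1.0, 0.0, 0.5, 2.0, 3.0):
    print(f"rho={rho:+.1f} -> std={direct_std(rho)}")
```

Values already inside the bounds pass through unchanged, while every value outside collapses onto the nearest bound; that many-to-one collapse is why this parameterization has no inverse.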

max_scale property

max_scale: Tensor

The upper bound of the standard deviations.

min_scale property

min_scale: Tensor

The lower bound of the standard deviations.

__init__

__init__(scale: tuple[float, float] | Tensor = (0.0001, 2.0)) -> None

Construct a layer to perform the reparameterization.

Parameters:

scale (tuple[float, float] | Tensor, default (0.0001, 2.0)): The minimum and maximum values of the parameterized standard deviations.

Raises:

ValueError: If scale is not a 2-vector.
ValueError: If min_scale is negative.
ValueError: If min_scale is not strictly less than max_scale.

forward

forward(rhos: Tensor) -> torch.Tensor

Clamp a rhos tensor into the closed domain of min_scale to max_scale.

Parameters:

rhos (Tensor, required): A rhos tensor.

Returns:

torch.Tensor: A standard deviation tensor.

inverse

inverse(std: Tensor) -> NoReturn

The direct standard deviation parameterization is not bijective (clamping maps many rhos to the same standard deviation), so it has no inverse.

Parameters:

std (Tensor, required): A standard deviation tensor.

Raises:

TypeError: If called.
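Because clamping sends every rho at or below min_scale to the same standard deviation, distinct inputs produce identical outputs, and no inverse can recover the original rho. The sketch below demonstrates such a collision and mirrors the documented behavior of raising `TypeError` when `inverse` is called. The class name `DirectSketch` is hypothetical, and it uses floats instead of tensors.

```python
from typing import NoReturn


class DirectSketch:
    """Hypothetical float-based stand-in for a clamping parameterization."""

    def __init__(self, min_scale: float = 0.0001, max_scale: float = 2.0) -> None:
        self.min_scale = min_scale
        self.max_scale = max_scale

    def forward(self, rho: float) -> float:
        # Clamp into the closed interval [min_scale, max_scale].
        return min(max(rho, self.min_scale), self.max_scale)

    def inverse(self, std: float) -> NoReturn:
        # Clamping is many-to-one, so there is no unique rho to return.
        raise TypeError("The direct parameterization is not invertible.")


param = DirectSketch()
# Two distinct rhos collapse to the same standard deviation.
print(param.forward(-5.0) == param.forward(-0.1))  # True
try:
    param.inverse(1.0)
except TypeError as err:
    print(f"TypeError: {err}")
```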

ScaledStandardDeviationParameterization

Bases: StandardDeviationParameterization

Defines the common structures necessary to parameterize rhos tensors (on the domain of all real numbers) as standard deviation tensors (on the domain of min_scale to max_scale).

min_scale must be strictly less than max_scale, and both must be nonnegative real numbers.

max_scale property

max_scale: Tensor

The upper bound of the standard deviations.

min_scale property

min_scale: Tensor

The lower bound of the standard deviations.

__init__

__init__(scale: tuple[float, float] | Tensor = (0.0001, 2.0)) -> None

Construct a layer to perform the reparameterization.

Parameters:

scale (tuple[float, float] | Tensor, default (0.0001, 2.0)): The minimum and maximum values of the parameterized standard deviations.

Raises:

ValueError: If scale is not a 2-vector.
ValueError: If min_scale is negative.
ValueError: If min_scale is not strictly less than max_scale.
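The three documented checks on scale can be sketched as a standalone validator. The helper name `validate_scale` is hypothetical, it works on plain Python sequences rather than tensors, and the error messages are illustrative rather than the library's own.

```python
from collections.abc import Sequence


def validate_scale(scale: Sequence[float]) -> tuple[float, float]:
    """Hypothetical validator mirroring the documented ValueError conditions."""
    if len(scale) != 2:
        raise ValueError("scale must be a 2-vector of (min_scale, max_scale)")
    min_scale, max_scale = float(scale[0]), float(scale[1])
    if min_scale < 0.0:
        raise ValueError("min_scale must be nonnegative")
    if not min_scale < max_scale:
        raise ValueError("min_scale must be strictly less than max_scale")
    return min_scale, max_scale


print(validate_scale((0.0001, 2.0)))
for bad in ((1.0,), (-0.1, 2.0), (2.0, 2.0)):
    try:
        validate_scale(bad)
    except ValueError as err:
        print(f"{bad}: {err}")
```

No separate check on max_scale is needed: min_scale >= 0 together with min_scale < max_scale already forces max_scale to be positive.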

forward abstractmethod

forward(rhos: Tensor) -> torch.Tensor

Apply the parameterization to a rhos tensor.

Parameters:

rhos (Tensor, required): A rhos tensor.

Returns:

torch.Tensor: A standard deviation tensor.

inverse abstractmethod

inverse(std: Tensor) -> torch.Tensor

Apply the inverse parameterization to a standard deviation tensor.

Parameters:

std (Tensor, required): A standard deviation tensor.

Returns:

torch.Tensor: A rhos tensor.

StandardDeviationParameterization

Bases: Module, ABC

Defines the interface for the reparameterization of rhos tensors (on the domain of all real numbers) as standard deviation tensors (on the domain of nonnegative real numbers) of the applied transformation.

Rhos can be learned directly or estimated as the output of a neural network.

The derivative of this parameterization defines the rate of change of the standard deviations with respect to rhos. Combined with a so-called "noise loss", which penalizes standard deviations that deviate from the target distribution, and a task loss (e.g., classification or token prediction), it defines a complete training objective and loss landscape for transform layers.

Added in version 0.11.0.
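The role of the derivative can be made explicit with the chain rule. Writing $\sigma = f(\rho)$ for the (element-wise) parameterization and assuming, for illustration only, a combined objective $\mathcal{L} = \mathcal{L}_{\text{task}} + \lambda\,\mathcal{L}_{\text{noise}}$ with a hypothetical weighting $\lambda \ge 0$ (how the two losses are actually combined is not specified here), the gradient that reaches rhos is:

```latex
\frac{\partial \mathcal{L}}{\partial \rho}
  = \frac{\partial \mathcal{L}}{\partial \sigma}\,\frac{d\sigma}{d\rho}
  = \left( \frac{\partial \mathcal{L}_{\text{task}}}{\partial \sigma}
         + \lambda\,\frac{\partial \mathcal{L}_{\text{noise}}}{\partial \sigma} \right) f'(\rho)
```

Wherever $f'(\rho)$ is close to zero (for example, a clamp outside its bounds or a saturated bounded mapping), rhos receive almost no gradient from either loss, which is why the shape of the parameterization matters for training.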

forward abstractmethod

forward(rhos: Tensor) -> torch.Tensor

Apply the parameterization to a rhos tensor.

Parameters:

rhos (Tensor, required): A rhos tensor.

Returns:

torch.Tensor: A standard deviation tensor.

inverse abstractmethod

inverse(std: Tensor) -> torch.Tensor

Apply the inverse parameterization to a standard deviation tensor.

Parameters:

std (Tensor, required): A standard deviation tensor.

Returns:

torch.Tensor: A rhos tensor.
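To illustrate the interface contract (a `forward` that maps all real rhos to nonnegative standard deviations, and an `inverse` that maps them back), here is a framework-free sketch using Python's `abc` module. The real base class also derives from `torch.nn.Module` and operates on tensors; the names `StdParamSketch` and `SoftplusStdSketch`, and the choice of softplus as the mapping, are hypothetical and not part of the library.

```python
import abc
import math


class StdParamSketch(abc.ABC):
    """Hypothetical stand-in for the abstract interface: real rho <-> nonnegative std."""

    @abc.abstractmethod
    def forward(self, rho: float) -> float:
        """Map a real-valued rho to a nonnegative standard deviation."""

    @abc.abstractmethod
    def inverse(self, std: float) -> float:
        """Map a standard deviation back to a rho."""


class SoftplusStdSketch(StdParamSketch):
    """Hypothetical concrete subclass using softplus(rho) = log(1 + exp(rho))."""

    def forward(self, rho: float) -> float:
        # max(rho, 0) + log1p(exp(-|rho|)) is a numerically stable softplus.
        return max(rho, 0.0) + math.log1p(math.exp(-abs(rho)))

    def inverse(self, std: float) -> float:
        # Inverse softplus: log(exp(std) - 1), defined only for std > 0.
        if std <= 0.0:
            raise ValueError("std must be strictly positive")
        return math.log(math.expm1(std))


param = SoftplusStdSketch()
for rho in (-2.0, 0.0, 3.0):
    std = param.forward(rho)
    print(f"rho={rho:+.1f} std={std:.6f} recovered={param.inverse(std):+.6f}")
```

Attempting to instantiate `StdParamSketch` directly raises `TypeError`, since its abstract methods are unimplemented; only subclasses that define both `forward` and `inverse` can be constructed.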