
lightning

ReducedPrecisionFilter

Bases: Generic[_PrecisionT], Precision, Precision

Wraps any lightning.fabric.plugins.precision.Precision object and patches its convert_module so that the parameters and buffers of certain submodules are not cast to reduced precision, because doing so can cause convergence issues. These are, namely, the parameters and buffers of normalization layers and of stainedglass_core.noise_layer.BaseNoiseLayer.

Some module types may not converge well when trained in lower precision; see https://discuss.pytorch.org/t/training-with-half-precision/11815/2.
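The core idea, casting a model to a reduced dtype while leaving normalization layers untouched, can be approximated in plain PyTorch. This is a minimal sketch, not the library's implementation: cast_except and NORM_TYPES are hypothetical names, and the set of layers treated as "normalization layers" is an assumption.

```python
import torch
from torch import nn

# Hypothetical stand-in for "normalization layers"; the library's actual set may differ.
NORM_TYPES: tuple[type[nn.Module], ...] = (
    nn.BatchNorm1d,
    nn.BatchNorm2d,
    nn.BatchNorm3d,
    nn.LayerNorm,
    nn.GroupNorm,
)


def cast_except(
    module: nn.Module,
    dtype: torch.dtype,
    keep_types: tuple[type[nn.Module], ...] = NORM_TYPES,
) -> nn.Module:
    """Cast floating-point parameters and buffers to `dtype`, skipping `keep_types` subtrees."""
    if isinstance(module, keep_types):
        return module  # leave this module and all of its descendants in their original dtype
    # Cast only this module's own tensors; children are handled by the recursion below.
    for param in module.parameters(recurse=False):
        if param.is_floating_point():
            param.data = param.data.to(dtype)
    for name, buf in list(module.named_buffers(recurse=False)):
        if buf.is_floating_point():
            setattr(module, name, buf.to(dtype))  # re-registers the buffer under the same name
    for child in module.children():
        cast_except(child, dtype, keep_types)
    return module


model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.Linear(8, 2))
cast_except(model, torch.bfloat16)
```

Recursing manually, rather than calling model.to(dtype) and then casting the normalization layers back to float32, avoids a lossy round trip: the float32 to bfloat16 conversion could already have discarded precision in those layers' weights.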

__init__

__init__(_precision: _PrecisionT, full_precision_module_types: Iterable[type[Module]] | None = None, full_precision_names: Iterable[str] | None = None) -> None

Construct a ReducedPrecisionFilter object.

Parameters:

_precision (_PrecisionT, required): The Precision object to wrap.

full_precision_module_types (Iterable[type[Module]] | None, default None): Additional nn.Module types whose parameters and buffers should not be cast to reduced precision.

full_precision_names (Iterable[str] | None, default None): Regex patterns matching submodule, parameter, or buffer names that should not be cast to reduced precision. See patterns_to_regex for the pattern syntax.
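The text above does not spell out how full_precision_names is matched against names. The following is a minimal sketch, assuming each pattern must fully match (re.fullmatch) a dotted name or one of its ancestor submodule names, so that matching a submodule keeps all of its parameters and buffers in full precision. name_prefixes and keeps_full_precision are hypothetical helpers, and the sketch uses raw regexes rather than going through patterns_to_regex.

```python
import re
from collections.abc import Iterable


def name_prefixes(name: str) -> list[str]:
    """Return the dotted name and every ancestor submodule name, shortest first."""
    parts = name.split(".")
    return [".".join(parts[: i + 1]) for i in range(len(parts))]


def keeps_full_precision(name: str, patterns: Iterable[str]) -> bool:
    """Hypothetical check: does any pattern fully match the name or one of its ancestors?"""
    compiled = [re.compile(pattern) for pattern in patterns]
    return any(
        regex.fullmatch(prefix) is not None
        for prefix in name_prefixes(name)
        for regex in compiled
    )


patterns = [r"encoder\.layers\.\d+\.norm", r"head\.bias"]
names = [
    "encoder.layers.0.linear.weight",
    "encoder.layers.0.norm.weight",
    "encoder.layers.0.norm.bias",
    "head.weight",
    "head.bias",
]
kept = [name for name in names if keeps_full_precision(name, patterns)]
print(kept)  # ['encoder.layers.0.norm.weight', 'encoder.layers.0.norm.bias', 'head.bias']
```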

Raises:

TypeError: If _precision is an instance of ReducedPrecisionFilter.
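The constructor rejects a _precision that is itself a ReducedPrecisionFilter. A minimal sketch of such a guard on a Generic wrapper follows; Precision here is a hypothetical stand-in for the Lightning base class, and FilterSketch is a hypothetical name, not the library's class.

```python
from typing import Generic, TypeVar


class Precision:
    """Hypothetical stand-in for lightning.fabric.plugins.precision.Precision."""


_PrecisionT = TypeVar("_PrecisionT", bound=Precision)


class FilterSketch(Generic[_PrecisionT], Precision):
    """Hypothetical wrapper illustrating the nested-wrapper guard."""

    def __init__(self, _precision: _PrecisionT) -> None:
        # Refuse to wrap a wrapper; callers should pass the inner Precision object instead.
        if isinstance(_precision, FilterSketch):
            raise TypeError("_precision is already a FilterSketch; wrap the inner Precision instead.")
        self._precision = _precision


inner = Precision()
wrapped = FilterSketch(inner)
```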

convert_module

convert_module(module: _ModuleT) -> _ModuleT

Cast the parameters and buffers of the given module to the desired dtype, skipping the submodules, parameters, and buffers selected to stay in full precision.

Parameters:

module (_ModuleT, required): The module whose parameters and buffers should be cast to the desired dtype.

Raises:

TypeError: If self does not have the _desired_input_dtype or _desired_dtype attribute.
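Because the wrapped Precision object may expose its target dtype under either attribute name, the desired dtype has to be resolved before casting. A minimal sketch, assuming either attribute is sufficient and _desired_input_dtype is checked first (the actual lookup order is not documented here); resolve_desired_dtype and the three plugin classes are hypothetical, and strings stand in for torch dtypes to keep the example dependency-free.

```python
class InputDtypePrecision:
    """Hypothetical plugin exposing only _desired_input_dtype."""

    _desired_input_dtype = "bfloat16"


class DtypePrecision:
    """Hypothetical plugin exposing only _desired_dtype."""

    _desired_dtype = "float16"


class UnsupportedPrecision:
    """Hypothetical plugin exposing neither attribute."""


def resolve_desired_dtype(precision: object) -> object:
    # Assumed lookup order: prefer _desired_input_dtype, then fall back to _desired_dtype.
    for attr in ("_desired_input_dtype", "_desired_dtype"):
        if hasattr(precision, attr):
            return getattr(precision, attr)
    raise TypeError(
        f"{type(precision).__name__} has neither _desired_input_dtype nor _desired_dtype"
    )
```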

Returns:

_ModuleT: The cast module.