
precision

Classes:

Name Description
DesiredDtypePrecision

Protocol for Precision subclasses that have the convert_module method and the attribute _desired_dtype.

DesiredInputDtypePrecision

Protocol for Precision subclasses that have the convert_module method and the attribute _desired_input_dtype.

ReducedPrecisionFilter

Wraps any lightning.fabric.plugins.precision.Precision object, patching its convert_module to avoid casting certain submodules' parameters and buffers to reduced precision.

DesiredDtypePrecision

Bases: Protocol

Protocol for Precision subclasses that have the convert_module method and the attribute _desired_dtype.

TODO: Open a PR to lightning that updates DeepSpeedPrecision to conform to DesiredInputDtypePrecision.

DesiredInputDtypePrecision

Bases: Protocol

Protocol for Precision subclasses that have the convert_module method and the attribute _desired_input_dtype.
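As a rough illustration of what these two protocols describe, the sketch below defines structural stand-ins with typing.Protocol. The names HasDesiredDtype, HasDesiredInputDtype, and ToyHalfPrecision are hypothetical, a string stands in for a torch dtype, and the use of runtime_checkable is an assumption made so the check can run; the actual protocol definitions may differ.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class HasDesiredDtype(Protocol):
    """Hypothetical stand-in for DesiredDtypePrecision."""

    _desired_dtype: Any

    def convert_module(self, module: Any) -> Any: ...


@runtime_checkable
class HasDesiredInputDtype(Protocol):
    """Hypothetical stand-in for DesiredInputDtypePrecision."""

    _desired_input_dtype: Any

    def convert_module(self, module: Any) -> Any: ...


class ToyHalfPrecision:
    """Toy plugin; a plain string stands in for a torch dtype."""

    def __init__(self) -> None:
        self._desired_input_dtype = "float16"

    def convert_module(self, module: Any) -> Any:
        return module


plugin = ToyHalfPrecision()
# Structural checks: only the attribute and method names are inspected.
matches_input_protocol = isinstance(plugin, HasDesiredInputDtype)
matches_dtype_protocol = isinstance(plugin, HasDesiredDtype)
```

In this sketch the toy plugin satisfies only the input-dtype protocol, because it defines _desired_input_dtype but not _desired_dtype.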

ReducedPrecisionFilter

Bases: Generic[_PrecisionT], Precision, Precision

Wraps any lightning.fabric.plugins.precision.Precision object and patches its convert_module so that the parameters and buffers of certain submodules, namely normalization layers and stainedglass_core.noise_layer.BaseNoiseLayer, are not cast to reduced precision. Keeping these in full precision avoids convergence issues.

Some module classes may not converge well in lower precision; see https://discuss.pytorch.org/t/training-with-half-precision/11815/2.
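The filtering idea can be sketched without PyTorch. In the toy below, ToyModule and ToyLayerNorm are hypothetical stand-ins for nn.Module and a normalization layer, dtypes are plain strings, and convert_filtered is an illustrative helper, not the library's implementation. In particular, skipping a protected submodule's whole subtree is an assumption.

```python
from dataclasses import dataclass, field


@dataclass
class ToyModule:
    """Stand-in for torch.nn.Module; dtype is a plain string."""

    dtype: str = "float32"
    children: dict[str, "ToyModule"] = field(default_factory=dict)


class ToyLayerNorm(ToyModule):
    """Stand-in for a normalization layer."""


def convert_filtered(module, dtype, full_precision_types):
    """Cast module and its descendants to dtype, skipping protected types."""
    if isinstance(module, tuple(full_precision_types)):
        # Assumption: a protected submodule is left untouched, subtree included.
        return module
    module.dtype = dtype
    for child in module.children.values():
        convert_filtered(child, dtype, full_precision_types)
    return module


model = ToyModule(children={"linear": ToyModule(), "norm": ToyLayerNorm()})
convert_filtered(model, "bfloat16", [ToyLayerNorm])
```

After the call, the root and the linear child carry the reduced dtype, while the normalization stand-in keeps its original one.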

Methods:

Name Description
__init__

Construct a ReducedPrecisionFilter object.

convert_module

Cast the parameters and buffers of the given module to the desired dtype, avoiding certain submodules, parameters, and buffers.

__init__

__init__(
    _precision: _PrecisionT,
    full_precision_module_types: Iterable[type[Module]] | None = None,
    full_precision_names: Iterable[str] | None = None,
) -> None

Construct a ReducedPrecisionFilter object.

Parameters:

Name Type Description Default

_precision

_PrecisionT

The Precision object to wrap.

required

full_precision_module_types

Iterable[type[Module]] | None

Additional nn.Module types whose parameters and buffers should not be cast to reduced precision.

None

full_precision_names

Iterable[str] | None

Regex patterns matching the names of submodules, parameters, or buffers that should not be cast to reduced precision. See patterns_to_regex for the syntax of the patterns.

None
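To make the name-based filter concrete, here is a minimal sketch of selecting names by pattern. It uses the standard re module with full matching, which is an assumption: the pattern syntax accepted by patterns_to_regex is not shown here and may differ. The helper select_full_precision and the example names are hypothetical.

```python
import re


def select_full_precision(names, patterns):
    """Return the names that fully match at least one regex pattern."""
    compiled = [re.compile(pattern) for pattern in patterns]
    return [name for name in names if any(c.fullmatch(name) for c in compiled)]


# Hypothetical parameter and buffer names, in dotted form.
names = ["encoder.norm.weight", "encoder.linear.weight", "noise_layer.std"]
kept = select_full_precision(names, [r".*\.norm\..*", r"noise_layer\..*"])
```

Here kept holds the two names that would stay in full precision under these patterns, and encoder.linear.weight is left eligible for casting.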

Raises:

Type Description
TypeError

If _precision is an instance of ReducedPrecisionFilter.

convert_module

convert_module(module: _ModuleT) -> _ModuleT

Cast the parameters and buffers of the given module to the desired dtype, avoiding certain submodules, parameters, and buffers.

Parameters:

Name Type Description Default

module

_ModuleT

The module whose parameters and buffers to cast to the desired dtype.

required

Raises:

Type Description
TypeError

If self has neither the _desired_input_dtype nor the _desired_dtype attribute.

Returns:

Type Description
_ModuleT

The cast module.
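The TypeError condition for convert_module suggests that the target dtype is read from one of two attributes. A minimal sketch of such a lookup follows; resolve_desired_dtype and the two toy classes are hypothetical, strings stand in for torch dtypes, and the order in which the attributes are checked is an assumption.

```python
def resolve_desired_dtype(precision):
    """Return the dtype stored on a precision-like object."""
    # Assumption: _desired_input_dtype is checked before _desired_dtype.
    for attr in ("_desired_input_dtype", "_desired_dtype"):
        if hasattr(precision, attr):
            return getattr(precision, attr)
    raise TypeError(
        f"{type(precision).__name__} has neither _desired_input_dtype "
        "nor _desired_dtype"
    )


class WithDesiredDtype:
    """Toy object exposing only _desired_dtype."""

    _desired_dtype = "float16"


class WithoutDtype:
    """Toy object exposing neither attribute."""


resolved = resolve_desired_dtype(WithDesiredDtype())

try:
    resolve_desired_dtype(WithoutDtype())
    raised_type_error = False
except TypeError:
    raised_type_error = True
```

An object with either attribute resolves to a dtype, and one with neither triggers the TypeError path.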