lightning
ReducedPrecisionFilter
Bases: Generic[_PrecisionT], Precision
Wraps any lightning.fabric.plugins.precision.Precision object and patches its convert_module method. The patched method does not cast the parameters and buffers of certain submodules to reduced precision, namely normalization layers and stainedglass_core.noise_layer.BaseNoiseLayer, because these can cause convergence issues.
Some module types may not converge well in lower precision: https://discuss.pytorch.org/t/training-with-half-precision/11815/2.
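The wrapping pattern described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the library's implementation. ToyPrecision and ToyFilter are hypothetical stand-ins: a "module" is reduced to a dict mapping tensor names to dtype strings, and "protected" simply means the name contains "norm". The sketch uses composition with attribute forwarding. The real class, per its Bases line, also subclasses Precision.

```python
from typing import Generic, TypeVar


class ToyPrecision:
    """Hypothetical stand-in for a precision plugin.

    A "module" here is just a dict mapping tensor names to dtype strings.
    """

    precision = "bf16-mixed"  # hypothetical attribute, used to show delegation

    def convert_module(self, module: dict[str, str]) -> dict[str, str]:
        # Cast every tensor to the reduced-precision dtype.
        return {name: "bfloat16" for name in module}


_PrecisionT = TypeVar("_PrecisionT", bound=ToyPrecision)


class ToyFilter(Generic[_PrecisionT]):
    """Hypothetical wrapper that shields names containing "norm" from casting."""

    def __init__(self, precision: _PrecisionT) -> None:
        self._precision = precision

    def __getattr__(self, name: str):
        # Only reached when normal lookup fails: forward to the wrapped plugin.
        if name == "_precision":
            raise AttributeError(name)
        return getattr(self._precision, name)

    def convert_module(self, module: dict[str, str]) -> dict[str, str]:
        # Split off the protected entries, let the wrapped plugin cast the rest,
        # then merge the untouched entries back in.
        protected = {n: d for n, d in module.items() if "norm" in n}
        rest = {n: d for n, d in module.items() if n not in protected}
        return {**self._precision.convert_module(rest), **protected}
```

For example, ToyFilter(ToyPrecision()).convert_module({"fc.weight": "float32", "norm.weight": "float32"}) leaves norm.weight as float32 and casts fc.weight to bfloat16. Any other attribute, such as precision, is still answered by the wrapped plugin.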
__init__
__init__(
    _precision: _PrecisionT,
    full_precision_module_types: Iterable[type[Module]] | None = None,
    full_precision_names: Iterable[str] | None = None,
) -> None
Construct a ReducedPrecisionFilter object.
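Parameters:
Name | Type | Description | Default |
---|---|---|---|
_precision | _PrecisionT | The precision plugin to wrap. | required |
full_precision_module_types | Iterable[type[Module]] \| None | Additional module types whose parameters and buffers to avoid casting to reduced precision. | None |
full_precision_names | Iterable[str] \| None | Regex patterns matching submodule, parameter, or buffer names to avoid casting to reduced precision. See … | None |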
Raises:
Type | Description |
---|---|
TypeError | If … |
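The two filtering parameters can be illustrated with a small, self-contained predicate. This is a hypothetical sketch, not the library's code. The default set (normalization layers plus a noise layer) is modeled with toy classes, and full_precision_module_types is treated as extending that default set. The sketch assumes patterns are matched with re.fullmatch against the fully qualified name. The real implementation may use re.search or re.match instead, which would change which names a pattern such as "head" catches.

```python
from __future__ import annotations

import re
from collections.abc import Iterable


# Hypothetical stand-ins for module classes (not torch or stainedglass_core types).
class LayerNorm: ...
class NoiseLayer: ...
class Linear: ...
class Embedding: ...


# Assumed defaults: normalization layers and noise layers stay in full precision.
DEFAULT_FULL_PRECISION_TYPES: tuple[type, ...] = (LayerNorm, NoiseLayer)


def keep_full_precision(
    name: str,
    owner: object,
    extra_types: Iterable[type] | None = None,
    name_patterns: Iterable[str] | None = None,
) -> bool:
    """Return True if the named submodule/tensor should not be cast (hypothetical helper)."""
    # Extra types extend the defaults rather than replacing them.
    types = DEFAULT_FULL_PRECISION_TYPES + tuple(extra_types or ())
    if isinstance(owner, types):
        return True
    # Assumption: a pattern must match the whole qualified name.
    return any(re.fullmatch(p, name) for p in (name_patterns or ()))
```

With these assumptions, a Linear named "head.bias" is kept in full precision under the pattern r"head\..*" but not under r"head". That is exactly the case where the choice of regex function matters.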
convert_module
convert_module(module: _ModuleT) -> _ModuleT
Cast the parameters and buffers of the given module to the desired dtype, except for submodules of the full-precision module types and submodules, parameters, or buffers whose names match the full-precision patterns.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
module | _ModuleT | The module whose parameters and buffers to cast to the desired dtype. | required |
Raises:
Type | Description |
---|---|
TypeError | If … |
Returns:
Type | Description |
---|---|
_ModuleT | The converted module. |
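A recursive walk like the one convert_module performs can be sketched with a toy module tree. This is a hypothetical illustration only. ToyModule stands in for torch.nn.Module, dtypes are strings, and the cast mutates the tree in place and returns the same object. The sketch assumes that a protected module type shields its whole subtree and that name patterns are checked with re.fullmatch against dotted names (submodule names as well as tensor names). The actual method may differ on any of these points.

```python
from __future__ import annotations

import re
from dataclasses import dataclass, field


@dataclass
class ToyModule:
    """Hypothetical stand-in for torch.nn.Module."""

    kind: str  # class name, e.g. "Linear" or "LayerNorm"
    tensors: dict[str, str] = field(default_factory=dict)  # name -> dtype
    children: dict[str, ToyModule] = field(default_factory=dict)


def _matches(name: str, patterns: tuple[str, ...]) -> bool:
    return any(re.fullmatch(p, name) for p in patterns)


def convert_module(
    module: ToyModule,
    dtype: str,
    protected_kinds: frozenset[str],
    protected_names: tuple[str, ...] = (),
    prefix: str = "",
) -> ToyModule:
    """Cast tensors in place, skipping protected submodules and names."""
    if module.kind in protected_kinds:
        return module  # leave this submodule (and its children) untouched
    for name in module.tensors:
        if not _matches(prefix + name, protected_names):
            module.tensors[name] = dtype
    for child_name, child in module.children.items():
        full_name = prefix + child_name
        if _matches(full_name, protected_names):
            continue  # a submodule name matched: skip the whole subtree
        convert_module(child, dtype, protected_kinds, protected_names, full_name + ".")
    return module
```

Consider a model with children fc (Linear), norm (LayerNorm) and head (Linear). Calling convert_module(model, "bfloat16", frozenset({"LayerNorm"}), (r"head\.bias",)) casts fc entirely and leaves norm in float32. Within head, it casts only the weight.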