lightning
ReducedPrecisionFilter
Bases: Generic[_PrecisionT], Precision
Wraps any lightning.fabric.plugins.precision.Precision object, patching its convert_module method so that the parameters and buffers of certain submodules are not cast to reduced precision, namely those of normalization layers and stainedglass_core.noise_layer.BaseNoiseLayer. Keeping these in full precision avoids convergence issues, since some module types may not converge well in lower precision: https://discuss.pytorch.org/t/training-with-half-precision/11815/2.
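The type-based part of this filtering can be illustrated with a small, dependency-free sketch. It is not the library's implementation: the stand-in classes below are hypothetical placeholders for torch.nn.Linear, torch.nn.LayerNorm, and stainedglass_core.noise_layer.BaseNoiseLayer, and the helper partition_submodules is invented for illustration. It assumes excluded types are matched with subclass semantics, the way isinstance would match them.

```python
from __future__ import annotations


# Hypothetical stand-ins for real module classes; only their inheritance matters here.
class Module: ...
class Linear(Module): ...
class LayerNorm(Module): ...
class BaseNoiseLayer(Module): ...
class GaussianNoiseLayer(BaseNoiseLayer): ...


def partition_submodules(
    submodules: dict[str, type[Module]],
    full_precision_module_types: tuple[type[Module], ...],
) -> tuple[list[str], list[str]]:
    """Split qualified submodule names into (to_cast, keep_full_precision).

    A submodule stays in full precision if its class is, or subclasses, any excluded type.
    """
    to_cast: list[str] = []
    keep: list[str] = []
    for name, cls in submodules.items():
        # issubclass accepts a tuple, so subclasses of an excluded type are excluded too.
        if issubclass(cls, full_precision_module_types):
            keep.append(name)
        else:
            to_cast.append(name)
    return to_cast, keep


submodules = {
    "encoder.linear": Linear,
    "encoder.norm": LayerNorm,
    "noise": GaussianNoiseLayer,
    "head": Linear,
}
to_cast, keep = partition_submodules(submodules, (LayerNorm, BaseNoiseLayer))
print(to_cast)  # ['encoder.linear', 'head']
print(keep)     # ['encoder.norm', 'noise']
```

Matching by subclass means a concrete noise layer such as the hypothetical GaussianNoiseLayer is kept in full precision without being listed explicitly.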
__init__
__init__(_precision: _PrecisionT, full_precision_module_types: Iterable[type[Module]] | None = None, full_precision_names: Iterable[str] | None = None) -> None
Construct a ReducedPrecisionFilter object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
_precision | _PrecisionT | The | required |
full_precision_module_types | Iterable[type[Module]] \| None | Additional | None |
full_precision_names | Iterable[str] \| None | Regex patterns matching submodule, parameter, or buffer names to avoid casting to reduced-precision. See | None |
Raises:
Type | Description |
---|---|
TypeError | If |
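The full_precision_names parameter takes regex patterns. A minimal sketch of how such patterns could be applied to qualified names follows; it assumes re.search semantics (a pattern may match anywhere in the name), which the text above does not confirm, and the helpers compile_full_precision_names and keeps_full_precision_by_name are hypothetical.

```python
from __future__ import annotations

import re
from collections.abc import Iterable


def compile_full_precision_names(patterns: Iterable[str] | None) -> list[re.Pattern[str]]:
    """Compile the user-supplied regex patterns once; None means no name-based exclusions."""
    return [re.compile(p) for p in (patterns or ())]


def keeps_full_precision_by_name(name: str, compiled: list[re.Pattern[str]]) -> bool:
    """Return True if the qualified name matches any pattern (assumed re.search semantics)."""
    return any(p.search(name) for p in compiled)


compiled = compile_full_precision_names([r"\.bias$", r"^classifier\."])
names = ["encoder.weight", "encoder.bias", "classifier.weight", "decoder.classifier_head.weight"]
print([n for n in names if keeps_full_precision_by_name(n, compiled)])
# ['encoder.bias', 'classifier.weight']
```

Anchoring patterns with ^ and $ keeps them from matching unintended names, as with decoder.classifier_head.weight above.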
convert_module
convert_module(module: _ModuleT) -> _ModuleT
Cast the parameters and buffers of the given module to the desired dtype, avoiding certain submodules, parameters, and buffers.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
module | _ModuleT | The module whose parameters and buffers to cast to the desired dtype. | required |
Raises:
Type | Description |
---|---|
TypeError | If |
Returns:
Type | Description |
---|---|
_ModuleT | The cast module. |
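As a rough, dependency-free model of what convert_module produces, the sketch below represents a module as a mapping from qualified parameter or buffer names to dtype strings and applies both exclusion rules (owner type and name pattern). It is an illustration under assumptions, not the method itself: convert_tensors_sketch, the string dtypes, and the stand-in classes are hypothetical, and the real method operates on a torch.nn.Module and returns the cast module.

```python
from __future__ import annotations

import re
from collections.abc import Iterable, Mapping


# Hypothetical stand-ins for real module classes.
class Linear: ...
class LayerNorm: ...


def convert_tensors_sketch(
    tensors: Mapping[str, str],
    owner_types: Mapping[str, type],
    target_dtype: str,
    full_precision_module_types: tuple[type, ...],
    full_precision_names: Iterable[str] = (),
) -> dict[str, str]:
    """Return a new name -> dtype mapping in which only non-excluded tensors are cast.

    owner_types maps each submodule's qualified name to its class; the root module,
    if it owns tensors directly, is keyed by the empty string.
    """
    patterns = [re.compile(p) for p in full_precision_names]
    converted: dict[str, str] = {}
    for name, dtype in tensors.items():
        # "encoder.norm.weight" is owned by the submodule "encoder.norm".
        owner = owner_types[name.rpartition(".")[0]]
        excluded = issubclass(owner, full_precision_module_types) or any(
            p.search(name) for p in patterns
        )
        converted[name] = dtype if excluded else target_dtype
    return converted


owner_types = {"linear": Linear, "norm": LayerNorm, "head": Linear}
tensors = {
    "linear.weight": "float32",
    "norm.weight": "float32",
    "head.weight": "float32",
    "head.bias": "float32",
}
result = convert_tensors_sketch(tensors, owner_types, "bfloat16", (LayerNorm,), [r"^head\.bias$"])
print(result)
# {'linear.weight': 'bfloat16', 'norm.weight': 'float32', 'head.weight': 'bfloat16', 'head.bias': 'float32'}
```

Returning a new mapping keeps the sketch side-effect free; the real method, by contrast, returns the module itself after casting.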