lightning
Modules:
Name | Description |
---|---|
precision | |
Classes:
Name | Description |
---|---|
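ReducedPrecisionFilter | Wraps any `lightning.fabric.plugins.precision.Precision` object, patching its `convert_module` to avoid casting certain submodules' parameters and buffers to reduced precision. |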
ReducedPrecisionFilter
Bases: Generic[_PrecisionT], Precision, Precision
Wraps any `lightning.fabric.plugins.precision.Precision` object, patching its `convert_module` method so that the parameters and buffers of certain submodules, namely normalization layers and `stainedglass_core.noise_layer.BaseNoiseLayer`, are not cast to reduced precision, which can cause convergence issues.
Some module classes may not converge well when trained in lower precision: https://discuss.pytorch.org/t/training-with-half-precision/11815/2.
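The wrapping pattern described above can be illustrated with a small, self-contained sketch. It is hypothetical and not the Stained Glass implementation: `ToyPrecision` and `FilteringWrapper` are made-up names, and a "module" is modeled as a plain dict mapping submodule names to dtype strings so that the example runs without PyTorch or Lightning. The wrapper forwards every attribute to the wrapped plugin and overrides only `convert_module`, handing the wrapped plugin just the entries that are allowed to be cast.

```python
from typing import Any, Generic, TypeVar


class ToyPrecision:
    """Hypothetical stand-in for a precision plugin (not Lightning's class)."""

    def __init__(self, dtype: str) -> None:
        self.dtype = dtype

    def convert_module(self, module: dict[str, str]) -> dict[str, str]:
        # Cast every "submodule" it is given to the reduced dtype.
        return {name: self.dtype for name in module}

    def describe(self) -> str:
        return f"ToyPrecision({self.dtype})"


P = TypeVar("P", bound=ToyPrecision)


class FilteringWrapper(Generic[P]):
    """Hypothetical wrapper that patches only convert_module."""

    def __init__(self, inner: P, keep_full: set[str]) -> None:
        self._inner = inner
        self._keep_full = keep_full

    def __getattr__(self, name: str) -> Any:
        # Called only for attributes the wrapper itself lacks, so everything
        # except convert_module is forwarded to the wrapped plugin.
        return getattr(self._inner, name)

    def convert_module(self, module: dict[str, str]) -> dict[str, str]:
        # Give the wrapped plugin only the entries that may be cast...
        to_cast = {n: d for n, d in module.items() if n not in self._keep_full}
        converted = self._inner.convert_module(to_cast)
        # ...and pass protected entries through untouched, preserving order.
        return {n: converted.get(n, d) for n, d in module.items()}


plugin = FilteringWrapper(ToyPrecision("bf16"), keep_full={"norm"})
print(plugin.convert_module({"linear": "fp32", "norm": "fp32"}))
# {'linear': 'bf16', 'norm': 'fp32'}
print(plugin.describe())  # ToyPrecision(bf16)
```

Delegating through `__getattr__` is one way to wrap "any" plugin without knowing its concrete type; whether the documented class uses this mechanism or another is not stated here.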
Methods:
Name | Description |
---|---|
`__init__` | Construct a `ReducedPrecisionFilter` object. |
`convert_module` | Cast the parameters and buffers of the given module to the desired dtype, avoiding certain submodules, parameters, and buffers. |
__init__
__init__(
    _precision: _PrecisionT,
    full_precision_module_types: Iterable[type[Module]] | None = None,
    full_precision_names: Iterable[str] | None = None,
) -> None
Construct a `ReducedPrecisionFilter` object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
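`_precision` | `_PrecisionT` | The `lightning.fabric.plugins.precision.Precision` object to wrap. | required |
`full_precision_module_types` | `Iterable[type[Module]] \| None` | Additional submodule types whose parameters and buffers to avoid casting to reduced precision. | `None` |
`full_precision_names` | `Iterable[str] \| None` | Regex patterns matching submodule, parameter, or buffer names to avoid casting to reduced precision. See | `None` |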
Raises:
Type | Description |
---|---|
TypeError | If |
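To illustrate what `full_precision_names` patterns might select, the sketch below applies regular expressions to dotted names of the kind returned by `torch.nn.Module.named_parameters()`. The matching rule is an assumption: it uses `re.search`, and the real class may instead anchor patterns with `re.match` or `re.fullmatch`. The helper `select_full_precision` is a hypothetical name.

```python
import re
from collections.abc import Iterable


def select_full_precision(names: Iterable[str], patterns: Iterable[str]) -> list[str]:
    """Return the names that at least one pattern matches (via re.search)."""
    compiled = [re.compile(p) for p in patterns]
    return [n for n in names if any(c.search(n) for c in compiled)]


# Dotted names in the style of torch.nn.Module.named_parameters().
names = [
    "encoder.layers.0.attn.q_proj.weight",
    "encoder.layers.0.norm.weight",
    "encoder.layers.0.norm.bias",
    "head.weight",
]
print(select_full_precision(names, [r"\.norm\.", r"^head\."]))
# ['encoder.layers.0.norm.weight', 'encoder.layers.0.norm.bias', 'head.weight']
```

Because `re.search` matches anywhere in the string, escaping dots and anchoring with `^` keeps a pattern from catching unrelated names; a bare `norm` would also match `normalizer_proj.weight`, while `\.norm\.` does not.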
convert_module
convert_module(module: _ModuleT) -> _ModuleT
Cast the parameters and buffers of the given module to the desired dtype, avoiding certain submodules, parameters, and buffers.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`module` | `_ModuleT` | The module whose parameters and buffers are to be cast to the desired dtype. | required |
Raises:
Type | Description |
---|---|
TypeError | If |
Returns:
Type | Description |
---|---|
`_ModuleT` | The cast module. |
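The behavior that `convert_module` describes can be approximated directly in PyTorch. This is a minimal sketch under stated assumptions, not the library's code: `selective_cast` and `NORM_TYPES` are hypothetical names, the listed normalization types are a guess at what counts as a "normalization layer", `stainedglass_core.noise_layer.BaseNoiseLayer` is omitted so the example needs only `torch`, and names are matched with `re.search`. A protected submodule is skipped together with all of its children, and only floating-point tensors are cast, so integer buffers such as `BatchNorm1d.num_batches_tracked` are never touched.

```python
import re
from collections.abc import Iterable

import torch
from torch import nn

# Assumption: which layer types count as "normalization layers".
NORM_TYPES: tuple[type[nn.Module], ...] = (
    nn.BatchNorm1d,
    nn.BatchNorm2d,
    nn.BatchNorm3d,
    nn.GroupNorm,
    nn.LayerNorm,
)


def selective_cast(
    module: nn.Module,
    dtype: torch.dtype,
    full_precision_types: tuple[type[nn.Module], ...] = NORM_TYPES,
    full_precision_names: Iterable[str] = (),
) -> nn.Module:
    """Cast floating-point parameters and buffers to ``dtype`` in place,
    skipping protected submodule types and regex-matched names."""
    patterns = [re.compile(p) for p in full_precision_names]

    def protected(name: str) -> bool:
        return any(p.search(name) for p in patterns)

    def qualify(prefix: str, name: str) -> str:
        return f"{prefix}.{name}" if prefix else name

    def visit(mod: nn.Module, mod_name: str) -> None:
        # A protected submodule keeps itself and all of its children untouched.
        if isinstance(mod, full_precision_types) or (mod_name and protected(mod_name)):
            return
        for name, param in mod.named_parameters(recurse=False):
            if param.is_floating_point() and not protected(qualify(mod_name, name)):
                param.data = param.data.to(dtype)
        # list() so that reassigning buffers cannot disturb the iteration.
        for name, buf in list(mod.named_buffers(recurse=False)):
            if buf.is_floating_point() and not protected(qualify(mod_name, name)):
                setattr(mod, name, buf.to(dtype))
        for child_name, child in mod.named_children():
            visit(child, qualify(mod_name, child_name))

    visit(module, "")
    return module


model = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4), nn.Linear(4, 2))
selective_cast(model, torch.bfloat16, full_precision_names=[r"^2\.bias$"])
print(model[0].weight.dtype, model[1].weight.dtype, model[2].bias.dtype)
# torch.bfloat16 torch.float32 torch.float32
```

Walking the module tree manually, rather than calling `module.to(dtype)`, is what makes it possible to skip whole subtrees by type and individual tensors by name.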