optim

Classes:

Freeze: Freezes submodules and parameters using regex patterns.
ParamGroupBuilder: Configures optimizer parameter group construction.
Unfreeze: Unfreezes submodules and parameters using regex patterns.

Freeze dataclass

Freezes submodules and parameters using regex patterns.

Use this class when it is more concise to enumerate the submodules and parameters you don't want to train.

Examples:

Freeze a single layer:

>>> from torch import nn
>>> module = nn.Sequential(
...     nn.Linear(10, 5), nn.BatchNorm1d(5), nn.Sigmoid(), nn.Linear(5, 2)
... )
>>> freeze = Freeze(["3"])
>>> freeze(module)
>>> [name for name, param in module.named_parameters() if not param.requires_grad]
['3.weight', '3.bias']
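Under the hood, freezing amounts to clearing `requires_grad` on the matched parameters. A plain-PyTorch sketch of what a call like `Freeze(["3"])` effectively does (illustrative only, not the library's implementation; the prefix match stands in for the pattern matching):

```python
import torch.nn as nn

module = nn.Sequential(
    nn.Linear(10, 5), nn.BatchNorm1d(5), nn.Sigmoid(), nn.Linear(5, 2)
)

# Freeze every parameter belonging to submodule "3" by clearing requires_grad,
# so the optimizer and autograd skip them during training.
for name, param in module.named_parameters():
    if name.startswith("3."):
        param.requires_grad = False

frozen = [name for name, param in module.named_parameters() if not param.requires_grad]
print(frozen)  # ['3.weight', '3.bias']
```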

Methods:

__call__: Freeze submodules or parameters whose names match any pattern in patterns.

Attributes:

patterns (list[str]): Patterns matching the names of submodules or parameters to freeze; e.g. 'base_model', 'noise_layer.*', '*.bias'.

patterns class-attribute instance-attribute

patterns: list[str] = field(default_factory=list)

Patterns matching the names of submodules or parameters to freeze; e.g. 'base_model', 'noise_layer.*', '*.bias'.

__call__

__call__(module: Module) -> None

Freeze submodules or parameters whose names match any pattern in patterns.

Parameters:

module (Module, required): The Module to partially freeze.

ParamGroupBuilder dataclass

Configures optimizer parameter group construction.

Parameters:

param_groups (dict[str, dict[str, Any]], default dict()): A mapping of regex patterns matching submodules or parameters to optimizer parameter group keyword arguments.
freeze (Freeze | Unfreeze, default Freeze()): Configuration for freezing submodules and parameters.

Examples:

>>> import torch.optim
>>> from torch import nn
>>> model = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
>>> param_group_builder = ParamGroupBuilder(
...     param_groups={
...         "*.weight": {"lr": 1e-3},
...     },
...     freeze=Unfreeze(["0"]),
... )
>>> optimizer = torch.optim.AdamW(param_group_builder(model), lr=1e-4)
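The result is a standard list of `torch.optim` parameter groups: per-group keyword arguments override the optimizer's defaults. A hand-built plain-PyTorch equivalent of the groups above (the name-suffix filtering stands in for the builder's pattern matching) shows the mechanism:

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))

# Weights get their own lr=1e-3; biases fall back to the default lr=1e-4.
weights = [p for n, p in model.named_parameters() if n.endswith(".weight")]
biases = [p for n, p in model.named_parameters() if n.endswith(".bias")]

optimizer = torch.optim.AdamW(
    [{"params": weights, "lr": 1e-3}, {"params": biases}],
    lr=1e-4,
)
print([g["lr"] for g in optimizer.param_groups])  # [0.001, 0.0001]
```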

Methods:

__call__: Apply the Freeze configuration and build optimizer parameter groups using regex.

__call__

__call__(module: Module) -> list[dict[str, Any]]

Apply the Freeze configuration and build optimizer parameter groups using regex.

Parameters:

module (Module, required): The module to match submodules and parameters against.

Raises:

ValueError: If a pattern in param_groups does not match any submodules or parameters of module.
ValueError: If more than one pattern in param_groups matches one parameter of module.

Returns:

list[dict[str, Any]]: A list of optimizer parameter groups.

Unfreeze dataclass

Unfreezes submodules and parameters using regex patterns.

Use this class when it is more concise to enumerate the submodules and parameters you want to train.

Examples:

Only train the bias:

>>> from torch import nn
>>> module = nn.Sequential(
...     nn.Linear(10, 5), nn.BatchNorm1d(5), nn.Sigmoid(), nn.Linear(5, 2)
... )
>>> unfreeze = Unfreeze(["*.bias"])
>>> unfreeze(module)
>>> [name for name, param in module.named_parameters() if param.requires_grad]
['0.bias', '1.bias', '3.bias']
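Unfreeze inverts the enumeration: everything is frozen except what matches. A plain-PyTorch sketch of that semantics (illustrative only, not the library's implementation; the suffix match stands in for the '*.bias' pattern):

```python
import torch.nn as nn

module = nn.Sequential(
    nn.Linear(10, 5), nn.BatchNorm1d(5), nn.Sigmoid(), nn.Linear(5, 2)
)

# Freeze everything first, then re-enable only parameters named "*.bias".
for param in module.parameters():
    param.requires_grad = False
for name, param in module.named_parameters():
    if name.endswith(".bias"):
        param.requires_grad = True

trainable = [name for name, param in module.named_parameters() if param.requires_grad]
print(trainable)  # ['0.bias', '1.bias', '3.bias']
```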

Methods:

__call__: Unfreeze submodules or parameters whose names match any pattern in patterns.

Attributes:

patterns (list[str]): Patterns matching the names of submodules or parameters to unfreeze; e.g. 'base_model', 'noise_layer.*', '*.bias'.

patterns class-attribute instance-attribute

patterns: list[str] = field(default_factory=list)

Patterns matching the names of submodules or parameters to unfreeze; e.g. 'base_model', 'noise_layer.*', '*.bias'.

__call__

__call__(module: Module) -> None

Unfreeze submodules or parameters whose names match any pattern in patterns.

Parameters:

module (Module, required): The Module to partially unfreeze.