module

Functions:

| Name | Description |
| --- | --- |
| `clone_module` | Clone a module, copying its parameters, buffers, submodules, and hooks. |
| `eval_mode` | Put the entire model into eval mode and disable gradient calculation; upon exiting, restore the original training mode and re-enable gradient calculation. |
| `freeze` | Recursively freeze all parameters, enable eval mode, and set `affine` and `track_running_stats` to `False` for `_BatchNorm` submodules. |
| `infer_minimal_submodules` | Infer the minimal set of submodule names required for a forward pass. |
| `set_module_attr` | Set a submodule, buffer, or parameter by its dot-delimited name, e.g. `'features.0.conv'`. |
| `temporarily_offload` | Temporarily offload the module's parameters and buffers to the CPU. |
| `temporarily_remove_hooks` | Temporarily remove all hooks from a module and its submodules inside this context, and restore them to their original state upon exiting. |
| `tie_parameters_and_buffers` | Tie the parameters and buffers of a target_module to those of a source_module. |
| `to_empty` | Move the parameters and buffers to the specified device without copying storage if they are not already on that device. |
| `train_mode` | Put the entire model into the desired training mode and, upon exiting, restore the original training mode. |

ParameterMapBuilder

Builds a mapping of the names of the parameters in a source module to the names of the parameters in a target module to tie.

Use this class to create a mapping of parameter names from the source module to the target module, which is useful when you want to tie weights (for example, with tie_parameters_and_buffers).

Examples:

>>> class SourceNet(nn.Module):
...     def __init__(self, input_size, hidden_size, output_size):
...         super(SourceNet, self).__init__()
...         self.source_fc1 = nn.Linear(input_size, hidden_size)
...         self.relu = nn.ReLU()  # ReLU activation function
...         self.source_fc2 = nn.Linear(hidden_size, output_size)
>>> source_module = SourceNet(10, 1, 2)
>>> for name, _ in source_module.named_parameters():
...     print(name)
source_fc1.weight
source_fc1.bias
source_fc2.weight
source_fc2.bias
>>> class TargetNet(nn.Module):
...     def __init__(self, input_size, hidden_size, output_size):
...         super(TargetNet, self).__init__()
...         self.target_fc1 = nn.Linear(input_size, hidden_size)
...         self.relu = nn.ReLU()
...         self.target_fc2 = nn.Linear(hidden_size, output_size)
>>> target_module = TargetNet(10, 1, 2)
>>> for name, _ in target_module.named_parameters():
...     print(name)
target_fc1.weight
target_fc1.bias
target_fc2.weight
target_fc2.bias
>>> param_mapper = ParameterMapBuilder(
...     include_patterns=["*weight"],
...     exclude_patterns=["*bias"],
...     include_parameters=True,
...     include_buffers=False,
...     target_key_modification=lambda x: x.replace("source", "target"),
... )
>>> param_map = param_mapper(source_module, target_module)
>>> param_mapper = ParameterMapBuilder(
...     include_patterns=["*fc1*", "*fc2*"],
...     exclude_patterns=["*weight"],
...     include_parameters=True,
...     include_buffers=False,
...     target_key_modification=lambda x: x.replace("source", "target"),
... )
>>> param_map = param_mapper(source_module, target_module)

Added in version v0.134.0.

Methods:

| Name | Description |
| --- | --- |
| `__call__` | Build the mapping of the names of the parameters in a source module to the names of the parameters in a target module. |
| `__init__` | Initialize the ParameterMapBuilder. |

__call__

__call__(
    source_module: Module, target_module: Module
) -> dict[str, str]

Build the mapping of the names of the parameters in a source module to the names of the parameters in a target module.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `source_module` | `Module` | The module to tie the weights of. | *required* |
| `target_module` | `Module` | The module to tie the weights to. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `dict[str, str]` | A dictionary mapping the names of the parameters in `source_module` to the names of the parameters in `target_module` to tie. |

Raises:

| Type | Description |
| --- | --- |
| `AttributeError` | If the target module does not have the specified parameter name. |

__init__

Initialize the ParameterMapBuilder.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `include_patterns` | `Sequence[str] \| None` | List of regex patterns to include. If None, an empty list is used. For example: `["weight", "bias"]`. | *required* |
| `exclude_patterns` | `Sequence[str] \| None` | List of patterns to exclude. If None, an empty list is used. | *required* |
| `target_key_modification` | `Callable[[str], str] \| None` | Function to modify target keys with respect to the `source_module` parameter names. | `None` |
| `include_parameters` | `bool` | Whether to include parameters. | `True` |
| `include_buffers` | `bool` | Whether to include buffers. | `True` |
| `validate_target_keys` | `bool` | Whether to validate target keys after `target_key_modification` is applied. | `True` |

clone_module

clone_module(
    module: ModuleT,
    device: device | str | int | None = None,
    detach: bool = True,
    ignore_parameters: Collection[str] | None = None,
    ignore_modules: Collection[str] | None = None,
) -> ModuleT

Clone a module, copying its parameters, buffers, submodules, and hooks.

Note

When you call clone_module on a module which contains GPU tensors, those tensors will be loaded to GPU by default. You can call clone_module(..., device='cpu') to avoid a GPU RAM surge when loading a model checkpoint.

Caution

Specifying ignore_parameters or ignore_modules will remove the specified parameters or submodules from the cloned module. This will usually break the cloned module's functionality unless you modify the cloned module's forward in some way. Use these arguments with caution.

Warning

If a module has duplicate parameters/submodules (i.e. a single parameter/submodule that can be accessed by multiple names) when specifying ignore_parameters or ignore_modules, the parameter or submodule will be removed only under the specified name(s).

Warning

If a module has a hook that uses a closure with a reference to the original module, the cloned module's hook will still reference the original module. In this case, it is probably better to temporarily remove the hook, clone the module, and then re-attach the hook to the clone.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `module` | `ModuleT` | The module to clone. | *required* |
| `device` | `device \| str \| int \| None` | The device to cast the cloned module to. If None, the cloned module will be cast to the same device as the original module. | `None` |
| `detach` | `bool` | Whether to detach the cloned module's parameters from the computational graph. If False, this operation can be backpropagated through. If True, the cloned module will be disconnected from the computational graph, but a more memory-efficient strategy will be used to clone the module, avoiding large memory surges on the source device. | `True` |
| `ignore_parameters` | `Collection[str] \| None` | A list of parameter names to ignore when cloning the module. These parameter names should be relative to the module. | `None` |
| `ignore_modules` | `Collection[str] \| None` | A list of submodule names to ignore when cloning the module. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `ModuleT` | The cloned module. |

Examples:

Cloning a module to the same device:

>>> module = nn.Linear(10, 10)
>>> cloned_module = clone_module(module)

The cloned module is not the same object as the original module:

>>> cloned_module is module
False
>>> cloned_module.weight is module.weight
False
>>> cloned_module.bias is module.bias
False

The cloned module does, however, have parameters with the same values as the original module:

>>> torch.allclose(cloned_module.weight, module.weight)
True
>>> torch.allclose(cloned_module.bias, module.bias)
True

Cloning a module to a different device:

>>> cloned_module = clone_module(module, device="cpu")
>>> cloned_module.weight.device == torch.device("cpu")
True

Cloning a module while ignoring parameters:

>>> class ComplexModule(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.linear1 = nn.Linear(10, 10)
...         self.relu = nn.ReLU()
...         self.linear2 = nn.Linear(10, 10)
>>> module = ComplexModule()
>>> cloned_module_without_linear1_params = clone_module(
...     module, ignore_parameters=["linear1.weight", "linear1.bias"]
... )
>>> assert cloned_module_without_linear1_params.linear1.weight is None
>>> assert cloned_module_without_linear1_params.linear1.bias is None
>>>
>>> cloned_module_without_linear1_modules = clone_module(
...     module, ignore_modules=["linear1"]
... )
>>> assert not hasattr(cloned_module_without_linear1_modules, "linear1")

Added in version 0.64.0.

eval_mode

eval_mode(module: Module) -> Generator[None]

Put the entire model into eval mode and disable gradient calculation; upon exiting, restore the original training mode and re-enable gradient calculation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `module` | `Module` | The module to recursively set the training mode of. | *required* |
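
Examples:

The following is a minimal usage sketch, assuming torch and torch.nn (imported as nn) are available as in the other examples on this page, and relying only on the documented behavior of restoring the training mode and re-enabling gradient calculation on exit:

>>> model = nn.Linear(10, 10)
>>> with eval_mode(model):
...     # inside the context: eval mode, gradient calculation disabled
...     print(model.training, torch.is_grad_enabled())
False False
>>> # after the context: original training mode restored, gradients re-enabled
>>> print(model.training, torch.is_grad_enabled())
True True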

freeze

freeze(model: Module) -> None

Recursively freeze all parameters, enable eval mode, and set `affine` and `track_running_stats` to `False` for `_BatchNorm` submodules.

Examples:

Freeze a model:

>>> model = nn.Sequential(nn.Linear(10, 10), nn.BatchNorm1d(10))
>>> freeze(model)

The model's parameters no longer require gradients:

>>> model[0].weight.requires_grad
False
>>> model[0].bias.requires_grad
False
>>> model[1].weight.requires_grad
False
>>> model[1].bias.requires_grad
False

The model's batch norm submodule is also put in eval mode, and its affine and track_running_stats attributes are set to False:

>>> model[1].affine
False
>>> model[1].track_running_stats
False

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `Module` | The model to freeze. | *required* |

infer_minimal_submodules

infer_minimal_submodules(
    module: Module,
    *module_forward_args: Any,
    **module_forward_kwargs: Any,
) -> list[str]

Infer the minimal set of submodule names required for a forward pass.

For most modules, this function will just return a list of all of the names of the submodules of the given module. This is because most modules will require all of their submodules to be called in order to perform a forward pass. However, some modules may be designed to only require a subset of their submodules to be called in order to perform a forward pass. One notable example is the TruncatedModule class, which is used to truncate the forward pass of a module after a certain submodule has been called.

Note

This internally calls the module's forward method, so it is not guaranteed to be side-effect free. However, this function does not itself directly modify the module, its state, or cause any side effects.

Note

This function assumes that the module has a static computational graph (i.e. which submodules are called does not depend on the input). This function returns only the names of the submodules required for the given inputs. Any dynamic graph behavior will not be captured.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `module` | `Module` | The module to infer the minimal submodules for. | *required* |
| `*module_forward_args` | `Any` | The positional arguments to pass to the module's forward method. | *required* |
| `**module_forward_kwargs` | `Any` | The keyword arguments to pass to the module's forward method. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `list[str]` | The minimal set of submodule names required for a forward pass. |

Examples:

Define a module that only requires a subset of its submodules to be called in order to perform a forward pass:

>>> from stainedglass_core import model as sg_model
>>> full_model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10))
>>> truncated_model = sg_model.TruncatedModule(
...     full_model, truncation_point=full_model[1]
... )
>>> input = torch.rand(10)

The full model requires all of its submodules to be called in order to perform a forward pass:

>>> infer_minimal_submodules(full_model, input)
['0', '1', '2']

The truncated model only requires the first two submodules to be called in order to perform a forward pass:

>>> infer_minimal_submodules(truncated_model, input)
['module.0', 'module.1']

Added in version 0.70.0.

set_module_attr

set_module_attr(
    module: Module, target: str, value: Module | Tensor
) -> None

Set a submodule, buffer, or parameter by its dot-delimited name, e.g. 'features.0.conv'.

Inspired by nn.Module.get_parameter.

Examples:

Set a submodule:

>>> model = torch.nn.Sequential(nn.Linear(10, 10), nn.BatchNorm1d(10))
>>> target_module = torch.nn.BatchNorm1d(10, affine=False)
>>> model[1] is target_module
False
>>> set_module_attr(model, "1", target_module)
>>> model[1] is target_module
True
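
Setting a parameter by its dot-delimited name follows the same pattern. The sketch below is illustrative and assumes that assigning a new nn.Parameter through set_module_attr replaces the existing attribute on the owning submodule:

>>> new_weight = nn.Parameter(torch.zeros(10, 10))
>>> set_module_attr(model, "0.weight", new_weight)
>>> model[0].weight is new_weight
True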

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `module` | `Module` | The source module. | *required* |
| `target` | `str` | The dot-delimited name of the submodule, buffer, or parameter to set. | *required* |
| `value` | `Module \| Tensor` | The value to set. | *required* |

Raises:

| Type | Description |
| --- | --- |
| `AttributeError` | If the target string references an invalid path. |

temporarily_offload

temporarily_offload(module: Module) -> Generator[None]

Temporarily offload the module's parameters and buffers to the CPU.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `module` | `Module` | The module to offload. | *required* |
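
Examples:

A sketch of the intended usage, assuming a CUDA device is available and that the original device placement is restored upon exiting the context:

>>> model = nn.Linear(10, 10).to("cuda")
>>> with temporarily_offload(model):
...     # inside the context, the parameters live on the CPU
...     print(model.weight.device)
cpu
>>> # on exit, the parameters are moved back to their original device
>>> print(model.weight.device)
cuda:0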

temporarily_remove_hooks

temporarily_remove_hooks(module: Module) -> Generator[None]

Temporarily remove all hooks from a module and its submodules when inside this context, and restore the hooks to their original state upon exiting.

Note

Adding new hooks to the module (or its submodules), or removing existing ones using their RemovableHandle.remove() methods, has undefined behavior while inside this context. In particular, a newly added hook will only be accessible during this context.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `module` | `Module` | The module to remove hooks from. | *required* |

Examples:

Temporarily remove hooks from a module:

>>> model = nn.Sequential(nn.Linear(10, 10))
>>> model.register_forward_hook(
...     lambda module, input, output: print("Forward hook called")
... )
<...RemovableHandle object at ...>
>>> model(torch.rand(10))
Forward hook called
tensor(...)
>>> with temporarily_remove_hooks(model):
...     model(torch.rand(10))
tensor(...)
>>> model(torch.rand(10))
Forward hook called
tensor(...)

Changed in version 0.60.0: Hooks with kwargs are stored in a different attribute which only exists for PyTorch versions 2.0 and greater.

tie_parameters_and_buffers

tie_parameters_and_buffers(
    source_module: Module,
    target_module: Module,
    source_to_target_map: Mapping[str, str] | None = None,
) -> None

Tie the parameters and buffers of a target_module to those of a source_module.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `source_module` | `Module` | The module to tie the weights of. | *required* |
| `target_module` | `Module` | The module to tie the weights to. | *required* |
| `source_to_target_map` | `Mapping[str, str] \| None` | A dictionary mapping the names of the parameters and buffers in `source_module` to the names of the parameters in `target_module` to tie. If None, all parameters and buffers in `source_module` are tied to the identically named parameters in `target_module`. | `None` |

Raises:

| Type | Description |
| --- | --- |
| `AttributeError` | If the target module does not have the specified target parameter name in the `source_to_target_map`. |
| `ValueError` | If the shape of a source parameter does not match the shape of its corresponding target parameter in the `source_to_target_map`. |
| `ValueError` | If a source parameter is on a meta device. |
| `ValueError` | If a target parameter is on a meta device. |

Examples:

Tie the weights of two modules:

>>> source_module = nn.Linear(10, 10)
>>> target_module = nn.Linear(10, 10)
>>> tie_parameters_and_buffers(source_module, target_module)
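
When the parameter names differ between the two modules, an explicit source_to_target_map can be passed (for example, one built by ParameterMapBuilder). The sketch below reuses the SourceNet and TargetNet classes from the ParameterMapBuilder examples above and writes the mapping out by hand for illustration:

>>> source_module = SourceNet(10, 1, 2)
>>> target_module = TargetNet(10, 1, 2)
>>> source_to_target_map = {
...     "source_fc1.weight": "target_fc1.weight",
...     "source_fc1.bias": "target_fc1.bias",
...     "source_fc2.weight": "target_fc2.weight",
...     "source_fc2.bias": "target_fc2.bias",
... }
>>> tie_parameters_and_buffers(source_module, target_module, source_to_target_map)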

Added in version v0.134.0.

to_empty

to_empty(
    module: ModuleT,
    *,
    device: device | str | int | None,
    recurse: bool = True,
) -> ModuleT

Move the parameters and buffers to the specified device without copying storage if they are not already on that device.

See: https://github.com/pytorch/pytorch/pull/148926.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `module` | `ModuleT` | The module whose parameters and buffers to (maybe) move. | *required* |
| `device` | `device \| str \| int \| None` | The desired device of the parameters and buffers in the module. If None, the default device is used. | *required* |
| `recurse` | `bool` | Whether parameters and buffers of submodules should be recursively moved to the specified device. | `True` |

Returns:

| Type | Description |
| --- | --- |
| `ModuleT` | The (maybe) moved module. |
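
Examples:

A sketch of a typical use case, assuming PyTorch 2.0 or later (for the torch.device context manager) and that the module is first created on the meta device and then materialized on the CPU with uninitialized storage:

>>> with torch.device("meta"):
...     model = nn.Linear(10, 10)
>>> model = to_empty(model, device="cpu")
>>> print(model.weight.device)
cpu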

train_mode

train_mode(
    module: Module, mode: bool = True
) -> Generator[None]

Put the entire model into the desired training mode and, upon exiting, restore the original training mode.

If mode is False, gradient calculation is disabled; otherwise it is not affected.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `module` | `Module` | The module to recursively set the training mode of. | *required* |
| `mode` | `bool` | Whether to set training mode (True) or evaluation mode (False). | `True` |
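
Examples:

A minimal sketch of the intended usage, assuming torch.nn (imported as nn) is available; a module in eval mode is temporarily put back into training mode and restored on exit:

>>> model = nn.Linear(10, 10)
>>> _ = model.eval()
>>> with train_mode(model, mode=True):
...     print(model.training)
True
>>> print(model.training)
False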