
module

Functions:

- `clone_module`: Clone a module, copying its parameters, buffers, submodules, and hooks.
- `eval_mode`: Put the entire model into eval mode and disable gradient calculation; upon exiting, restore the original training mode and re-enable gradient calculation.
- `freeze`: Recursively freeze all parameters, enable eval mode, and set `affine` and `track_running_stats` to False for `_BatchNorm` submodules.
- `infer_minimal_submodules`: Infer the minimal set of submodule names required for a forward pass.
- `set_module_attr`: Set a submodule, buffer, or parameter by its dot-delimited name, e.g. 'features.0.conv'.
- `temporarily_offload`: Temporarily offload the module's parameters and buffers to the CPU.
- `temporarily_remove_hooks`: Temporarily remove all hooks from a module and its submodules when inside this context, and restore the hooks to their original state upon exiting.
- `train_mode`: Put the entire model into the desired training mode; upon exiting, restore the original training mode.

clone_module

clone_module(
    module: Module,
    device: device | str | int | None = None,
    detach: bool = True,
    ignore_parameters: Collection[str] | None = None,
    ignore_modules: Collection[str] | None = None,
) -> Module

Clone a module, copying its parameters, buffers, submodules, and hooks.

Note

When you call clone_module on a module which contains GPU tensors, those tensors will be loaded to GPU by default. You can call clone_module(..., device='cpu') to avoid a GPU RAM surge when loading a model checkpoint.

Caution

Specifying ignore_parameters or ignore_modules will remove the specified parameters or submodules from the cloned module. This will usually break the cloned module's functionality unless you modify the cloned module's forward in some way. Use these arguments with caution.

Warning

If a module has duplicate parameters/submodules (i.e. a single parameter or submodule accessible under multiple names), specifying ignore_parameters or ignore_modules removes that parameter or submodule only under the specified name(s).

Warning

If a module has a hook that uses a closure with a reference to the original module, the cloned module's hook will still reference the original module. In this case, it is probably better to temporarily remove the hook, clone the module, and then re-attach the hook to the clone.
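
For instance, a minimal sketch of that workaround (`make_hook` is a hypothetical factory for a hook that closes over its owning module):

>>> module = nn.Linear(10, 10)
>>> def make_hook(owner):
...     def hook(mod, args, output):
...         return output + owner.bias  # closes over `owner`
...     return hook
>>> handle = module.register_forward_hook(make_hook(module))
>>> handle.remove()  # detach the closure-bearing hook first
>>> clone = clone_module(module)
>>> _ = clone.register_forward_hook(make_hook(clone))  # re-attach against the clone
>>> _ = module.register_forward_hook(make_hook(module))  # and restore the original's hook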

Parameters:

- `module` (Module, required): The module to clone.
- `device` (device | str | int | None, default None): The device to cast the cloned module to. If None, the cloned module will be cast to the same device as the original module.
- `detach` (bool, default True): Whether to detach the cloned module's parameters from the computational graph. If False, this operation can be backpropagated through. If True, the cloned module will be disconnected from the computational graph, but a more memory-efficient cloning strategy will be used, avoiding large memory surges on the source device.
- `ignore_parameters` (Collection[str] | None, default None): A list of parameter names to ignore when cloning the module. These parameter names should be relative to the module.
- `ignore_modules` (Collection[str] | None, default None): A list of submodule names to ignore when cloning the module.

Returns:

- Module: The cloned module.

Examples:

Cloning a module to the same device:

>>> module = nn.Linear(10, 10)
>>> cloned_module = clone_module(module)

The cloned module is not the same object as the original module:

>>> cloned_module is module
False
>>> cloned_module.weight is module.weight
False
>>> cloned_module.bias is module.bias
False

The cloned module does, however, have parameters with the same values as the original module:

>>> torch.allclose(cloned_module.weight, module.weight)
True
>>> torch.allclose(cloned_module.bias, module.bias)
True

Cloning a module to a different device:

>>> cloned_module = clone_module(module, device="cpu")
>>> cloned_module.weight.device == torch.device("cpu")
True

Cloning a module while ignoring parameters:

>>> class ComplexModule(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.linear1 = nn.Linear(10, 10)
...         self.relu = nn.ReLU()
...         self.linear2 = nn.Linear(10, 10)
>>> module = ComplexModule()
>>> cloned_module_without_linear1_params = clone_module(
...     module, ignore_parameters=["linear1.weight", "linear1.bias"]
... )
>>> assert cloned_module_without_linear1_params.linear1.weight is None
>>> assert cloned_module_without_linear1_params.linear1.bias is None
>>>
>>> cloned_module_without_linear1_modules = clone_module(
...     module, ignore_modules=["linear1"]
... )
>>> assert not hasattr(cloned_module_without_linear1_modules, "linear1")
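
Cloning without detaching (a sketch assuming, per the detach description above, that gradients computed through the clone flow back to the original module's parameters):

>>> module = nn.Linear(10, 10)
>>> connected_clone = clone_module(module, detach=False)
>>> connected_clone(torch.rand(10)).sum().backward()
>>> module.weight.grad is not None
True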

Added in version 0.64.0.

eval_mode

eval_mode(module: Module) -> Generator[None]

Put the entire model into eval mode and disable gradient calculation; upon exiting, restore the original training mode and re-enable gradient calculation.

Parameters:

- `module` (Module, required): The module to recursively set the training mode of.
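
Examples:

A minimal usage sketch, relying on the context-manager behavior described above:

>>> model = nn.Linear(10, 10).train()
>>> with eval_mode(model):
...     model.training, torch.is_grad_enabled()
(False, False)
>>> model.training, torch.is_grad_enabled()
(True, True)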

freeze

freeze(model: Module) -> None

Recursively freeze all parameters, enable eval mode, and set `affine` and `track_running_stats` to False for `_BatchNorm` submodules.

Examples:

Freeze a model:

>>> model = nn.Sequential(nn.Linear(10, 10), nn.BatchNorm1d(10))
>>> freeze(model)

The model's parameters no longer require gradients:

>>> model[0].weight.requires_grad
False
>>> model[0].bias.requires_grad
False
>>> model[1].weight.requires_grad
False
>>> model[1].bias.requires_grad
False

The model's batch norm submodule is also put in eval mode, and its affine and track_running_stats attributes are set to False:

>>> model[1].affine
False
>>> model[1].track_running_stats
False

Parameters:

- `model` (Module, required): The model to freeze.

infer_minimal_submodules

infer_minimal_submodules(
    module: Module,
    *module_forward_args: Any,
    **module_forward_kwargs: Any,
) -> list[str]

Infer the minimal set of submodule names required for a forward pass.

For most modules, this function will just return a list of all of the names of the submodules of the given module. This is because most modules will require all of their submodules to be called in order to perform a forward pass. However, some modules may be designed to only require a subset of their submodules to be called in order to perform a forward pass. One notable example is the TruncatedModule class, which is used to truncate the forward pass of a module after a certain submodule has been called.

Note

This internally calls the module's forward method, so it is not guaranteed to be side-effect free. However, this function does not itself directly modify the module, its state, or cause any side effects.

Note

This function assumes that the module has a static computational graph (i.e. which submodules are called does not depend on the input). This function returns only the names of the submodules required for the given inputs. Any dynamic graph behavior will not be captured.

Parameters:

- `module` (Module, required): The module to infer the minimal submodules for.
- `*module_forward_args` (Any): The positional arguments to pass to the module's forward method.
- `**module_forward_kwargs` (Any): The keyword arguments to pass to the module's forward method.

Returns:

- list[str]: The minimal set of submodule names required for a forward pass.

Examples:

Define a module that only requires a subset of its submodules to be called in order to perform a forward pass:

>>> from stainedglass_core import model as sg_model
>>> full_model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10))
>>> truncated_model = sg_model.TruncatedModule(
...     full_model, truncation_point=full_model[1]
... )
>>> input = torch.rand(10)

The full model requires all of its submodules to be called in order to perform a forward pass:

>>> infer_minimal_submodules(full_model, input)
['0', '1', '2']

The truncated model only requires the first two submodules to be called in order to perform a forward pass:

>>> infer_minimal_submodules(truncated_model, input)
['module.0', 'module.1']

Added in version 0.70.0.

set_module_attr

set_module_attr(
    module: Module, target: str, value: Module | Tensor
) -> None

Set a submodule, buffer, or parameter by its dot-delimited name, e.g. 'features.0.conv'.

Inspired by nn.Module.get_parameter.

Examples:

Set a submodule:

>>> model = torch.nn.Sequential(nn.Linear(10, 10), nn.BatchNorm1d(10))
>>> target_module = torch.nn.BatchNorm1d(10, affine=False)
>>> model[1] is target_module
False
>>> set_module_attr(model, "1", target_module)
>>> model[1] is target_module
True
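
Set a parameter by its dot-delimited name (a brief sketch reusing the model above, following the documented ability to set parameters):

>>> new_weight = nn.Parameter(torch.zeros(10, 10))
>>> set_module_attr(model, "0.weight", new_weight)
>>> model[0].weight is new_weight
True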

Parameters:

- `module` (Module, required): The source module.
- `target` (str, required): The dot-delimited name of the submodule, buffer, or parameter to set.
- `value` (Module | Tensor, required): The value to set.

Raises:

- AttributeError: If the target string references an invalid path.

temporarily_offload

temporarily_offload(module: Module) -> Generator[None]

Temporarily offload the module's parameters and buffers to the CPU.

Parameters:

- `module` (Module, required): The module to offload.
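
Examples:

A minimal usage sketch (guarded so it also runs on a CPU-only machine, where the offload is a no-op):

>>> module = nn.Linear(10, 10)
>>> if torch.cuda.is_available():
...     module = module.cuda()
>>> with temporarily_offload(module):
...     module.weight.device.type
'cpu'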

temporarily_remove_hooks

temporarily_remove_hooks(module: Module) -> Generator[None]

Temporarily remove all hooks from a module and its submodules when inside this context, and restore the hooks to their original state upon exiting.

Note

Adding new hooks to the module (or its submodules), or removing existing ones via their RemovableHandle.remove() methods, has undefined behavior while inside this context. In particular, a newly added hook will only be accessible during this context.

Parameters:

- `module` (Module, required): The module to remove hooks from.

Examples:

Temporarily remove hooks from a module:

>>> model = nn.Sequential(nn.Linear(10, 10))
>>> model.register_forward_hook(
...     lambda module, input, output: print("Forward hook called")
... )
<...RemovableHandle object at ...>
>>> model(torch.rand(10))
Forward hook called
tensor(...)
>>> with temporarily_remove_hooks(model):
...     model(torch.rand(10))
tensor(...)
>>> model(torch.rand(10))
Forward hook called
tensor(...)

Changed in version 0.60.0: Hooks with kwargs are stored in a different attribute that only exists for PyTorch versions 2.0 and greater.

train_mode

train_mode(
    module: Module, mode: bool = True
) -> Generator[None]

Put the entire model into the desired training mode; upon exiting, restore the original training mode.

If mode is False, gradient calculation is disabled; otherwise it is not affected.

Parameters:

- `module` (Module, required): The module to recursively set the training mode of.
- `mode` (bool, default True): Whether to set training mode (True) or evaluation mode (False).
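
Examples:

A minimal usage sketch, relying on the behavior described above:

>>> model = nn.Linear(10, 10).eval()
>>> with train_mode(model):
...     model.training
True
>>> model.training
False

Passing mode=False additionally disables gradient calculation while inside the context:

>>> with train_mode(model, mode=False):
...     model.training, torch.is_grad_enabled()
(False, False)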