module
Functions:
Name | Description |
---|---|
clone_module | Clone a module, copying its parameters, buffers, submodules, and hooks. |
eval_mode | Put the entire model into eval mode and disable gradient calculation; upon exiting, restore the original training mode and gradient calculation setting. |
freeze | Recursively freeze all parameters, enable eval mode, and set affine and track_running_stats to False for _BatchNorm submodules. |
infer_minimal_submodules | Infer the minimal set of submodule names required for a forward pass. |
set_module_attr | Set a submodule, buffer, or parameter by its dot-delimited name, e.g. 'features.0.conv'. |
temporarily_offload | Temporarily offload the module's parameters and buffers to the CPU. |
temporarily_remove_hooks | Temporarily remove all hooks from a module and its submodules when inside this context, and restore the hooks to their original state upon exiting. |
train_mode | Put the entire model into the desired training mode; upon exiting, restore the original training mode. |
clone_module
clone_module(
module: Module,
device: device | str | int | None = None,
detach: bool = True,
ignore_parameters: Collection[str] | None = None,
ignore_modules: Collection[str] | None = None,
) -> Module
Clone a module, copying its parameters, buffers, submodules, and hooks.
Note
When you call clone_module on a module that contains GPU tensors, the cloned tensors are placed on the same GPU by default. You can call clone_module(..., device='cpu') to avoid a surge in GPU memory usage, for example when loading a model checkpoint.
Caution
Specifying ignore_parameters or ignore_modules removes the specified parameters or submodules from the cloned module. This usually breaks the cloned module's functionality unless you also modify the cloned module's forward method in some way. Use these arguments with caution.
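To see why removing parameters usually breaks the clone, here is a rough plain-PyTorch approximation of the ignore_parameters result. The helper `clone_without_parameters` is hypothetical and is not how clone_module works: it deep-copies the whole module first (so it gives none of clone_module's memory savings), then replaces each named parameter with None, which matches the state the clone_module examples assert.

```python
import copy

import torch
from torch import nn


def clone_without_parameters(module: nn.Module, ignore_parameters: list[str]) -> nn.Module:
    """Hypothetical helper: deep-copy `module`, then drop the named parameters.

    This copies everything before dropping anything, so it offers no memory
    savings; it only reproduces the resulting module state.
    """
    clone = copy.deepcopy(module)
    for name in ignore_parameters:
        owner_name, _, param_name = name.rpartition(".")
        owner = clone.get_submodule(owner_name)  # "" returns `clone` itself
        setattr(owner, param_name, None)  # nn.Module allows replacing a parameter with None
    return clone


model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10))
clone = clone_without_parameters(model, ["0.weight", "0.bias"])

print(clone[0].weight is None, model[0].weight is None)  # True False

# The clone's unmodified forward no longer works, because nn.Linear needs a weight tensor.
try:
    clone(torch.rand(10))
    forward_failed = False
except TypeError:
    forward_failed = True
print(forward_failed)  # True
```

The original module is untouched; only the clone loses its parameters, and its forward fails until you change it.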
Warning
If a module has duplicate parameters or submodules (i.e. a single parameter or submodule that can be accessed by multiple names), specifying ignore_parameters or ignore_modules removes the parameter or submodule only under the specified name(s).
Warning
If a module has a hook that uses a closure with a reference to the original module, the cloned module's hook will still reference the original module. In this case, it is probably better to temporarily remove the hook, clone the module, and then re-attach the hook to the clone.
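The closure pitfall can be reproduced without clone_module. The sketch below uses copy.deepcopy purely as a stand-in for cloning (an assumption; clone_module's own copying strategy may differ): Python function objects, including their closures, are copied by reference, so the copy's hook still refers to the original module. It then applies the workaround from the warning. `make_hook` and `seen` are illustrative names.

```python
import copy

import torch
from torch import nn

seen = []  # one entry per hook call: is the captured module the one that ran?


def make_hook(owner: nn.Module):
    # The returned hook closes over `owner`, the module object passed in here.
    def hook(module, inputs, output):
        seen.append(module is owner)

    return hook


original = nn.Linear(4, 4)
handle = original.register_forward_hook(make_hook(original))

naive_clone = copy.deepcopy(original)  # the hook function object is shared, not copied
naive_clone(torch.rand(4))
print(seen)  # [False]: the hook ran for naive_clone, but its closure refers to `original`

# Workaround: remove the hook, clone, then attach a hook bound to each module.
handle.remove()
fixed_clone = copy.deepcopy(original)
original.register_forward_hook(make_hook(original))
fixed_clone.register_forward_hook(make_hook(fixed_clone))
fixed_clone(torch.rand(4))
print(seen)  # [False, True]
```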
Parameters:
Name | Type | Description | Default |
---|---|---|---|
module | Module | The module to clone. | required |
device | device \| str \| int \| None | The device to cast the cloned module to. If None, the cloned module will be cast to the same device as the original module. | None |
detach | bool | Whether to detach the cloned module's parameters from the computational graph. If False, the cloning operation can be backpropagated through. If True, the cloned module is disconnected from the computational graph, but a more memory-efficient strategy is used to clone the module, avoiding large memory surges on the source device. | True |
ignore_parameters | Collection[str] \| None | Names of parameters to ignore when cloning the module. These names should be relative to the module. | None |
ignore_modules | Collection[str] \| None | Names of submodules to ignore when cloning the module. | None |
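The detach flag mirrors a familiar tensor-level distinction. The sketch below is only an analogy at the level of a single parameter, not clone_module's implementation: tensor.clone() stays in the autograd graph, so gradients flow back to the source, while tensor.detach().clone() produces an independent copy.

```python
import torch
from torch import nn

weight = nn.Parameter(torch.randn(3))

# Differentiable copy (analogous to detach=False): gradients flow back to `weight`.
attached = weight.clone()
attached.sum().backward()
print(weight.grad)  # tensor([1., 1., 1.])

# Detached copy (analogous to detach=True): no connection to the autograd graph.
detached = weight.detach().clone()
print(detached.requires_grad, detached.grad_fn)  # False None
```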
Returns:
Type | Description |
---|---|
Module | The cloned module. |
Examples:
Cloning a module to the same device:
The cloned module is not the same object as the original module:
>>> cloned_module is module
False
>>> cloned_module.weight is module.weight
False
>>> cloned_module.bias is module.bias
False
The cloned module does, however, have parameters with the same values as the original module:
>>> torch.allclose(cloned_module.weight, module.weight)
True
>>> torch.allclose(cloned_module.bias, module.bias)
True
Cloning a module to a different device:
>>> cloned_module = clone_module(module, device="cpu")
>>> cloned_module.weight.device == torch.device("cpu")
True
Cloning a module while ignoring parameters:
>>> class ComplexModule(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.linear1 = nn.Linear(10, 10)
...         self.relu = nn.ReLU()
...         self.linear2 = nn.Linear(10, 10)
>>> module = ComplexModule()
>>> cloned_module_without_linear1_params = clone_module(
... module, ignore_parameters=["linear1.weight", "linear1.bias"]
... )
>>> assert cloned_module_without_linear1_params.linear1.weight is None
>>> assert cloned_module_without_linear1_params.linear1.bias is None
>>> cloned_module_without_linear1_modules = clone_module(
... module, ignore_modules=["linear1"]
... )
>>> assert not hasattr(cloned_module_without_linear1_modules, "linear1")
Added in version 0.64.0.
eval_mode
freeze
Recursively freeze all parameters, enable eval mode, and set affine and track_running_stats to False for _BatchNorm submodules.
Examples:
Freeze a model:
The model is now in eval mode:
>>> model[0].weight.requires_grad
False
>>> model[0].bias.requires_grad
False
>>> model[1].weight.requires_grad
False
>>> model[1].bias.requires_grad
False
The model's batch norm submodule is also put in eval mode, and its affine and track_running_stats attributes are set to False:
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| | Module | The model to freeze. | required |
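A minimal sketch of the three documented effects, written with plain PyTorch calls. `freeze_sketch` is a hypothetical name and this is not the library's implementation. `_BatchNorm` is a private PyTorch base class (the one the description above names), so importing it relies on PyTorch internals.

```python
import torch
from torch import nn
from torch.nn.modules.batchnorm import _BatchNorm  # private base class of BatchNorm1d/2d/3d


def freeze_sketch(model: nn.Module) -> None:
    """Hypothetical re-creation of the documented freeze behavior."""
    model.requires_grad_(False)  # recursively stop gradient tracking for every parameter
    model.eval()  # recursively set training=False
    for submodule in model.modules():
        if isinstance(submodule, _BatchNorm):
            submodule.affine = False
            submodule.track_running_stats = False


model = nn.Sequential(nn.Linear(10, 10), nn.BatchNorm1d(10))
freeze_sketch(model)
print(model.training, model[1].affine, model[1].track_running_stats)  # False False False
```

In this sketch, flipping the flags after construction does not delete the batch norm's existing weight, bias, or running-statistic tensors; they remain in place with requires_grad set to False.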
infer_minimal_submodules
infer_minimal_submodules(
module: Module,
*module_forward_args: Any,
**module_forward_kwargs: Any,
) -> list[str]
Infer the minimal set of submodule names required for a forward pass.
For most modules, this function simply returns the names of all of the module's submodules, because most modules call every submodule during a forward pass. However, some modules are designed to call only a subset of their submodules during a forward pass. One notable example is the TruncatedModule class, which truncates the forward pass of a module after a certain submodule has been called.
Note
This function internally calls the module's forward method, so it is not guaranteed to be free of side effects. However, the function itself does not modify the module or its state, or introduce any other side effects.
Note
This function assumes that the module has a static computational graph (i.e. the set of submodules that are called does not depend on the input). It returns only the names of the submodules required for the given inputs; any dynamic, input-dependent behavior is not captured.
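One way to approximate this inference, assuming the static-graph condition holds, is to record which submodules actually run during a single forward pass. The sketch below does that with forward pre-hooks; `called_submodules` and `EarlyExit` are hypothetical, and this is not necessarily how infer_minimal_submodules is implemented (for instance, it would also list container modules such as a nested nn.Sequential).

```python
import torch
from torch import nn


def called_submodules(module: nn.Module, *args, **kwargs) -> list[str]:
    """Hypothetical helper: names of submodules whose forward runs for these inputs."""
    called: list[str] = []
    handles = [
        # Bind `name` as a default argument so each hook records its own submodule name.
        submodule.register_forward_pre_hook(lambda mod, inputs, name=name: called.append(name))
        for name, submodule in module.named_modules()
        if name  # skip the root module, whose name is ""
    ]
    try:
        with torch.no_grad():
            module(*args, **kwargs)
    finally:
        for handle in handles:
            handle.remove()  # leave the module's hooks as we found them
    return list(dict.fromkeys(called))  # de-duplicate, keeping first-call order


class EarlyExit(nn.Module):
    """Toy module that owns two layers but only ever calls the first."""

    def __init__(self) -> None:
        super().__init__()
        self.first = nn.Linear(10, 10)
        self.unused = nn.Linear(10, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.first(x)


full = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10))
print(called_submodules(full, torch.rand(10)))  # ['0', '1', '2']
print(called_submodules(EarlyExit(), torch.rand(10)))  # ['first']
```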
Parameters:
Name | Type | Description | Default |
---|---|---|---|
module | Module | The module to infer the minimal submodules for. | required |
\*module_forward_args | Any | The positional arguments to pass to the module's forward method. | required |
\*\*module_forward_kwargs | Any | The keyword arguments to pass to the module's forward method. | required |
Returns:
Type | Description |
---|---|
list[str] | The minimal set of submodule names required for a forward pass. |
Examples:
Define a module that only requires a subset of its submodules to be called in order to perform a forward pass:
>>> from stainedglass_core import model as sg_model
>>> full_model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10))
>>> truncated_model = sg_model.TruncatedModule(
... full_model, truncation_point=full_model[1]
... )
>>> input = torch.rand(10)
The full model requires all of its submodules to be called in order to perform a forward pass:
The truncated model only requires the first two submodules to be called in order to perform a forward pass:
Added in version 0.70.0.
set_module_attr
Set a submodule, buffer, or parameter by its dot-delimited name, e.g. 'features.0.conv'.
Inspired by nn.Module.get_parameter.
Examples:
Set a submodule:
>>> model = torch.nn.Sequential(nn.Linear(10, 10), nn.BatchNorm1d(10))
>>> target_module = torch.nn.BatchNorm1d(10, affine=False)
>>> model[1] is target_module
False
>>> set_module_attr(model, "1", target_module)
>>> model[1] is target_module
True
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| | Module | The source module. | required |
| | str | The dot-delimited name of the submodule, buffer, or parameter to set. | required |
| | Module \| Tensor | The value to set. | required |
Raises:
Type | Description |
---|---|
AttributeError | If the target string references an invalid path. |
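A minimal sketch of dot-delimited assignment in the spirit of nn.Module.get_parameter: split off the last name component, resolve its parent with nn.Module.get_submodule, then call setattr. `set_attr_by_name` is a hypothetical stand-in, not the library's code. Note that nn.Module.__setattr__ only accepts an nn.Parameter (or None) when replacing a registered parameter, so in this sketch a plain tensor works for buffers but not for parameters.

```python
from __future__ import annotations

import torch
from torch import nn


def set_attr_by_name(module: nn.Module, target: str, value: nn.Module | torch.Tensor) -> None:
    """Hypothetical sketch: set a submodule, buffer, or parameter by dotted name."""
    parent_name, _, attr_name = target.rpartition(".")
    # get_submodule("") returns `module` itself and raises AttributeError on a bad path.
    parent = module.get_submodule(parent_name)
    if not hasattr(parent, attr_name):
        raise AttributeError(f"{type(parent).__name__} has no attribute `{attr_name}`")
    setattr(parent, attr_name, value)


model = nn.Sequential(nn.Linear(10, 10), nn.Sequential(nn.BatchNorm1d(10)))

new_linear = nn.Linear(10, 10)
set_attr_by_name(model, "0", new_linear)  # replace a submodule
set_attr_by_name(model, "1.0.running_mean", torch.ones(10))  # replace a buffer

print(model[0] is new_linear)  # True
print(model[1][0].running_mean.sum().item())  # 10.0
```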
temporarily_offload
temporarily_remove_hooks
Temporarily remove all hooks from a module and its submodules when inside this context, and restore the hooks to their original state upon exiting.
Note
Adding new hooks to the module (or its submodules), or removing existing ones using their RemovableHandle.remove() methods, has undefined behavior while inside this context. In particular, a hook added inside this context is only available while inside the context.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| | Module | The module to remove hooks from. | required |
Examples:
Temporarily remove hooks from a module:
>>> model = nn.Sequential(nn.Linear(10, 10))
>>> model.register_forward_hook(
... lambda module, input, output: print("Forward hook called")
... )
<...RemovableHandle object at ...>
>>> model(torch.rand(10))
Forward hook called
tensor(...)
>>> with temporarily_remove_hooks(model):
... model(torch.rand(10))
tensor(...)
>>> model(torch.rand(10))
Forward hook called
tensor(...)
Changed in version 0.60.0: Hooks with kwargs are stored in a different attribute, which only exists for PyTorch versions 2.0 and greater.
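A simplified sketch of such a context manager, assuming hooks live in the private per-module dictionaries `_forward_pre_hooks`, `_forward_hooks`, `_backward_pre_hooks`, and `_backward_hooks` (PyTorch internals that may change between versions). It empties every dictionary in place and restores the saved contents on exit, so existing RemovableHandle objects keep pointing at the right dictionary and hooks added inside the block are dropped. `hooks_removed` is a hypothetical name; the sketch ignores global hooks and is not the library's implementation.

```python
import contextlib
from collections.abc import Iterator

import torch
from torch import nn

# Private nn.Module attributes holding per-module hook dictionaries (an assumption
# about PyTorch internals; getattr(..., None) skips any that a version lacks).
_HOOK_DICT_NAMES = ("_forward_pre_hooks", "_forward_hooks", "_backward_pre_hooks", "_backward_hooks")


@contextlib.contextmanager
def hooks_removed(module: nn.Module) -> Iterator[nn.Module]:
    """Hypothetical sketch: disable module hooks inside the block, restore them after."""
    saved = []
    for submodule in module.modules():
        for attr in _HOOK_DICT_NAMES:
            hooks = getattr(submodule, attr, None)
            if hooks is not None:
                saved.append((hooks, dict(hooks)))  # snapshot the current contents
                hooks.clear()  # empty in place so RemovableHandles stay valid
    try:
        yield module
    finally:
        for hooks, snapshot in saved:
            hooks.clear()  # drop any hooks added inside the block
            hooks.update(snapshot)


calls = []
model = nn.Sequential(nn.Linear(10, 10))
model.register_forward_hook(lambda mod, inputs, output: calls.append("hook"))

model(torch.rand(10))
with hooks_removed(model):
    model(torch.rand(10))  # the hook does not run here
model(torch.rand(10))
print(len(calls))  # 2
```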
train_mode
Put the entire model into the desired training mode and upon exiting, restore the original training mode.
If mode is False, gradient calculation is disabled; otherwise, it is not affected.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| | Module | The module to recursively set the training mode of. | required |
| mode | bool | Whether to set training mode (True) or evaluation mode (False). | True |
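A minimal sketch of the documented train_mode behavior as a context manager. `training_mode_sketch` is hypothetical and not the library's code. As a design choice, it records each submodule's own training flag so that a model in a mixed state (for example, one submodule left in eval mode) is restored exactly, and it wraps the block in torch.no_grad() only when mode is False. Under the same assumptions, the eval_mode summary appears to correspond to calling it with mode=False.

```python
import contextlib
from collections.abc import Iterator

import torch
from torch import nn


@contextlib.contextmanager
def training_mode_sketch(model: nn.Module, mode: bool = True) -> Iterator[nn.Module]:
    """Hypothetical sketch: set `mode` recursively, restore every submodule on exit."""
    original = {submodule: submodule.training for submodule in model.modules()}
    grad_context = contextlib.nullcontext() if mode else torch.no_grad()
    try:
        model.train(mode)
        with grad_context:
            yield model
    finally:
        for submodule, was_training in original.items():
            submodule.training = was_training


model = nn.Sequential(nn.Linear(10, 10), nn.Dropout(0.5))
model[1].eval()  # mixed starting state: Linear in training mode, Dropout in eval mode

with training_mode_sketch(model, mode=False):
    inside = (model.training, model[1].training, torch.is_grad_enabled())

print(inside)  # (False, False, False)
print(model.training, model[0].training, model[1].training)  # True True False
```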