torch

Modules:

- dtypes
- hooks: Utility functions for hook argument and output manipulation.
- init
- knn
- module
- normalize
- optim
- projection
- sequence
- tensor
- tensorboard

Freeze dataclass

Freezes submodules and parameters using regex patterns.

Use this class when it is more concise to enumerate the submodules and parameters you don't want to train.

Examples:

Freeze a single layer:

>>> module = nn.Sequential(
...     nn.Linear(10, 5), nn.BatchNorm1d(5), nn.Sigmoid(), nn.Linear(5, 2)
... )
>>> freeze = Freeze(["3"])
>>> freeze(module)
>>> [name for name, param in module.named_parameters() if not param.requires_grad]
['3.weight', '3.bias']

Methods:

- __call__: Freeze submodules or parameters whose names match any pattern in patterns.

Attributes:

- patterns (list[str]): Patterns matching the names of submodules or parameters to freeze; e.g. 'base_model', 'noise_layer.*', '*.bias'.

patterns (class attribute, instance attribute)

patterns: list[str] = field(default_factory=list)

Patterns matching the names of submodules or parameters to freeze; e.g. 'base_model', 'noise_layer.*', '*.bias'.

__call__

__call__(module: Module) -> None

Freeze submodules or parameters whose names match any pattern in patterns.

Parameters:

- module (Module): The Module to partially freeze. [required]

ParamGroupBuilder dataclass

Configures optimizer parameter group construction.

Parameters:

- param_groups (dict[str, dict[str, Any]]): A mapping of regex patterns matching submodules or parameters to optimizer parameter group keyword arguments. [default: dict()]
- freeze (Freeze | Unfreeze): Configuration for freezing submodules and parameters. [default: Freeze()]

Examples:

>>> import torch.optim
>>> model = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
>>> param_group_builder = ParamGroupBuilder(
...     param_groups={
...         "*.weight": {"lr": 1e-3},
...     },
...     freeze=Unfreeze(["0"]),
... )
>>> optimizer = torch.optim.AdamW(param_group_builder(model), lr=1e-4)

Methods:

- __call__: Apply the Freeze configuration and build optimizer parameter groups using regex.

__call__

__call__(module: Module) -> list[dict[str, Any]]

Apply the Freeze configuration and build optimizer parameter groups using regex.

Parameters:

- module (Module): The module to match submodules and parameters against. [required]

Raises:

- ValueError: If a pattern in param_groups does not match any submodules or parameters of module.
- ValueError: If more than one pattern in param_groups matches the same parameter of module.

Returns:

- list[dict[str, Any]]: A list of optimizer parameter groups.

Unfreeze dataclass

Unfreezes submodules and parameters using regex patterns.

Use this class when it is more concise to enumerate the submodules and parameters you want to train.

Examples:

Only train the bias:

>>> module = nn.Sequential(
...     nn.Linear(10, 5), nn.BatchNorm1d(5), nn.Sigmoid(), nn.Linear(5, 2)
... )
>>> unfreeze = Unfreeze(["*.bias"])
>>> unfreeze(module)
>>> [name for name, param in module.named_parameters() if param.requires_grad]
['0.bias', '1.bias', '3.bias']

Methods:

- __call__: Unfreeze submodules or parameters whose names match any pattern in patterns.

Attributes:

- patterns (list[str]): Patterns matching the names of submodules or parameters to unfreeze; e.g. 'base_model', 'noise_layer.*', '*.bias'.

patterns (class attribute, instance attribute)

patterns: list[str] = field(default_factory=list)

Patterns matching the names of submodules or parameters to unfreeze; e.g. 'base_model', 'noise_layer.*', '*.bias'.

__call__

__call__(module: Module) -> None

Unfreeze submodules or parameters whose names match any pattern in patterns.

Parameters:

- module (Module): The Module to partially unfreeze. [required]

batched_knn

batched_knn(
    embedding_index: Tensor,
    query: Tensor,
    k: int,
    p: int,
    max_batch_size: int | None = None,
    max_sequence_length: int | None = None,
    max_num_embeddings: int | None = None,
) -> torch.Tensor

Find the k-nearest neighbors of a query tensor in an embedding index, processing at most a (max_batch_size, max_sequence_length) section of the query at a time.

Smaller values of max_batch_size, max_sequence_length, and max_num_embeddings require less memory to store the intermediate distance calculations but have longer runtimes.

Parameters:

- embedding_index (Tensor): A tensor of shape (n_embeddings, embedding_dim). [required]
- query (Tensor): A tensor of shape (batch_size, sequence_length, embedding_dim) or (sequence_length, embedding_dim). [required]
- k (int): The number of nearest neighbors to find. [required]
- p (int): The p-norm to use for the distance calculation. [required]
- max_batch_size (int | None): The maximum number of batch elements over which to calculate distances. [default: None]
- max_sequence_length (int | None): The maximum number of sequence positions over which to calculate distances. [default: None]
- max_num_embeddings (int | None): The maximum number of embeddings over which to calculate distances. The results from each split are recursively merged together. [default: None]

Returns:

- torch.Tensor: A tensor of shape (batch_size, sequence_length, k) or (sequence_length, k) containing the indices of the k-nearest neighbors of each query.

Raises:

- ValueError: If the input tensors are of dtype bfloat16.
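
Examples:

A minimal sketch of typical usage; the inputs here are random, so only the shape of the result is meaningful:

>>> embedding_index = torch.rand(100, 8)
>>> query = torch.rand(2, 5, 8)
>>> neighbors = batched_knn(embedding_index, query, k=3, p=2, max_num_embeddings=25)
>>> neighbors.shape
torch.Size([2, 5, 3])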

batchwise_min_max_normalize

batchwise_min_max_normalize(tensor: Tensor) -> torch.Tensor

Normalize a tensor by its min and max values within each batch element. This is equivalent to using min_max_normalize in a loop over the batch dimension.

Examples:

>>> tensor = torch.tensor([[0, 1, 2], [3, 4, 5]], dtype=torch.float32)
>>> batchwise_min_max_normalize(tensor)
tensor([[0.0000, 0.5000, 1.0000],
        [0.0000, 0.5000, 1.0000]])

Parameters:

- tensor (Tensor): The tensor to normalize. [required]

Returns:

- torch.Tensor: The tensor, normalized batchwise.

Note

If the min and max values of a batch element are equal, NaN values will be returned for that batch element.

build_attention_mask

build_attention_mask(
    sequences: list[Tensor] | tuple[Tensor, ...],
    padding_side: Literal["left", "right"] | str = "right",
    pad_to_multiple_of: int | None = None,
    max_length: int | None = None,
) -> torch.Tensor

Build an attention mask for the given sequences.

Parameters:

- sequences (list[Tensor] | tuple[Tensor, ...]): A list or tuple of 1D tensors of shape (sequence_length,). [required]
- padding_side (Literal['left', 'right'] | str): The side on which to pad the inputs. Either 'left' or 'right'. [default: 'right']
- pad_to_multiple_of (int | None): The multiple to pad the sequence length to. [default: None]
- max_length (int | None): The maximum length after which sequences will be truncated. [default: None]

Raises:

- ValueError: If padding_side is not one of 'left' or 'right'.

Returns:

- torch.Tensor: A 2D boolean attention mask tensor of shape (batch_size, max_seq_length).
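
Examples:

A minimal sketch, assuming the documented boolean mask and right-padding default; True marks real (non-padding) positions:

>>> sequences = [torch.ones(3), torch.ones(5)]
>>> build_attention_mask(sequences)
tensor([[ True,  True,  True, False, False],
        [ True,  True,  True,  True,  True]])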

Added in version 0.84.0.

calculate_equivalent_alpha_and_scaling_factor

calculate_equivalent_alpha_and_scaling_factor(
    noise_loss_grad: Tensor, model_loss_grad: Tensor
) -> tuple[torch.Tensor, torch.Tensor]

Calculate the equivalent alpha and scaling factor (compared to the composite loss) for an alphaless gradient.

This gives a closed-form solution for what alpha would be within the alphaless framework, which can be useful when analyzing noise layer training.

Parameters:

- noise_loss_grad (Tensor): The gradient of the noise loss with respect to the trained parameters. [required]
- model_loss_grad (Tensor): The gradient of the model loss with respect to the trained parameters. [required]

Returns:

- tuple[torch.Tensor, torch.Tensor]: A tuple of the equivalent alpha and scaling factor.
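
Examples:

A minimal calling sketch; the returned values depend on the gradients and the alphaless framework, so they are not shown:

>>> noise_loss_grad = torch.tensor([1.0, 0.0])
>>> model_loss_grad = torch.tensor([0.0, 1.0])
>>> alpha, scaling_factor = calculate_equivalent_alpha_and_scaling_factor(
...     noise_loss_grad, model_loss_grad
... )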

cast_to_device

cast_to_device(
    value: T,
    *,
    device: str | int | device | None = None,
    dtype: dtype | None = None,
) -> T

Make a deep copy of value, casting all tensors to the given device and dtype.

Adapted from: https://github.com/pytorch/pytorch/blob/49444c3e546bf240bed24a101e747422d1f8a0ee/torch/optim/optimizer.py#L209C1-L225C29.

Parameters:

- value (T): The value to recursively copy and cast. [required]
- device (str | int | device | None): The device to cast tensors to. [default: None]
- dtype (dtype | None): The dtype to cast tensors to. Only applied to floating point tensors. [default: None]

Returns:

- T: The copied and cast value.
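
Examples:

A minimal sketch, assuming the documented behavior that dtype is applied only to floating point tensors:

>>> value = {"weights": torch.zeros(2), "step": torch.tensor(3)}
>>> cast = cast_to_device(value, device="cpu", dtype=torch.float64)
>>> cast["weights"].dtype
torch.float64
>>> cast["step"].dtype
torch.int64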

clone_module

clone_module(
    module: Module,
    device: device | str | int | None = None,
    detach: bool = True,
    ignore_parameters: Collection[str] | None = None,
    ignore_modules: Collection[str] | None = None,
) -> Module

Clone a module, copying its parameters, buffers, submodules, and hooks.

Note

When you call clone_module on a module that contains GPU tensors, those tensors will be loaded onto the GPU by default. You can call clone_module(..., device='cpu') to avoid a GPU RAM surge when loading a model checkpoint.

Caution

Specifying ignore_parameters or ignore_modules will remove the specified parameters or submodules from the cloned module. This will usually break the cloned module's functionality unless you modify the cloned module's forward in some way. Use these arguments with caution.

Warning

If a module has duplicate parameters/submodules (i.e. a single parameter/submodule that can be accessed by multiple names) when specifying ignore_parameters or ignore_modules, the parameter or submodule will be removed only under the specified name(s).

Warning

If a module has a hook that uses a closure with a reference to the original module, the cloned module's hook will still reference the original module. In this case, it is probably better to temporarily remove the hook, clone the module, and then re-attach the hook to the clone.

Parameters:

- module (Module): The module to clone. [required]
- device (device | str | int | None): The device to cast the cloned module to. If None, the cloned module will be cast to the same device as the original module. [default: None]
- detach (bool): Whether to detach the cloned module's parameters from the computational graph. If False, this operation can be backpropagated through. If True, the cloned module will be disconnected from the computational graph, but a more memory-efficient cloning strategy will be used, avoiding large memory surges on the source device. [default: True]
- ignore_parameters (Collection[str] | None): A list of parameter names to ignore when cloning the module. These parameter names should be relative to the module. [default: None]
- ignore_modules (Collection[str] | None): A list of submodule names to ignore when cloning the module. [default: None]

Returns:

- Module: The cloned module.

Examples:

Cloning a module to the same device:

>>> module = nn.Linear(10, 10)
>>> cloned_module = clone_module(module)

The cloned module is not the same object as the original module:

>>> cloned_module is module
False
>>> cloned_module.weight is module.weight
False
>>> cloned_module.bias is module.bias
False

The cloned module does, however, have parameters with the same values as the original module:

>>> torch.allclose(cloned_module.weight, module.weight)
True
>>> torch.allclose(cloned_module.bias, module.bias)
True

Cloning a module to a different device:

>>> cloned_module = clone_module(module, device="cpu")
>>> cloned_module.weight.device == torch.device("cpu")
True

Cloning a module while ignoring parameters:

>>> class ComplexModule(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.linear1 = nn.Linear(10, 10)
...         self.relu = nn.ReLU()
...         self.linear2 = nn.Linear(10, 10)
>>> module = ComplexModule()
>>> cloned_module_without_linear1_params = clone_module(
...     module, ignore_parameters=["linear1.weight", "linear1.bias"]
... )
>>> assert cloned_module_without_linear1_params.linear1.weight is None
>>> assert cloned_module_without_linear1_params.linear1.bias is None
>>>
>>> cloned_module_without_linear1_modules = clone_module(
...     module, ignore_modules=["linear1"]
... )
>>> assert not hasattr(cloned_module_without_linear1_modules, "linear1")

Added in version 0.64.0.

collect_devices

collect_devices(
    value: Tensor
    | dict[Any, Any]
    | UserDict[Any, Any]
    | Iterable[Any]
    | Any,
) -> set[torch.device]

Collect all devices in the given value.

Parameters:

- value (Tensor | dict[Any, Any] | UserDict[Any, Any] | Iterable[Any] | Any): The value to recursively collect devices from. [required]

Returns:

- set[torch.device]: The set of all devices in the given value.
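
Examples:

A minimal sketch using CPU tensors nested in a dict and a list:

>>> collect_devices({"a": torch.zeros(1), "b": [torch.ones(2)]})
{device(type='cpu')}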

collect_floating_point_dtypes

collect_floating_point_dtypes(
    value: Tensor
    | dict[Any, Any]
    | UserDict[Any, Any]
    | Iterable[Any]
    | Any,
) -> set[torch.dtype]

Collect all floating point dtypes in the given value.

Parameters:

- value (Tensor | dict[Any, Any] | UserDict[Any, Any] | Iterable[Any] | Any): The value to recursively collect floating point dtypes from. [required]

Returns:

- set[torch.dtype]: The set of all floating point dtypes in the given value.
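
Examples:

A minimal sketch; the integer tensor is ignored because only floating point dtypes are collected:

>>> collect_floating_point_dtypes([torch.zeros(1), torch.tensor([1, 2])])
{torch.float32}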

eval_mode

eval_mode(module: Module) -> Generator[None]

Put the entire module into eval mode and disable gradient calculation; upon exiting, restore the original training mode and re-enable gradient calculation.

Parameters:

- module (Module): The module to recursively set the training mode of. [required]
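
Examples:

A minimal sketch of the documented enter/exit behavior:

>>> module = nn.Linear(2, 2).train()
>>> with eval_mode(module):
...     module.training, torch.is_grad_enabled()
(False, False)
>>> module.training, torch.is_grad_enabled()
(True, True)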

freeze

freeze(model: Module) -> None

Recursively freeze all parameters, enable eval mode, and set affine and track_running_stats to False for _BatchNorm submodules.

Examples:

Freeze a model:

>>> model = nn.Sequential(nn.Linear(10, 10), nn.BatchNorm1d(10))
>>> freeze(model)

The model's parameters no longer require gradients:

>>> model[0].weight.requires_grad
False
>>> model[0].bias.requires_grad
False
>>> model[1].weight.requires_grad
False
>>> model[1].bias.requires_grad
False

The model's batch norm submodule is also put in eval mode, and its affine and track_running_stats attributes are set to False:

>>> model[1].affine
False
>>> model[1].track_running_stats
False

Parameters:

- model (Module): The model to freeze. [required]

hash_tensor_data

hash_tensor_data(tensor: Tensor) -> int

Compute the hash of the tensor's data represented as a string.

Note

Since 0 and -0 have different byte representations, they will produce different hash values.

Parameters:

- tensor (Tensor): The tensor whose data to hash. [required]

Returns:

- int: The hash of the tensor's data.
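
Examples:

A minimal sketch: tensors with equal data hash equally, while 0 and -0 differ per the note above:

>>> tensor = torch.tensor([1.0, 2.0])
>>> hash_tensor_data(tensor) == hash_tensor_data(tensor.clone())
True
>>> hash_tensor_data(torch.tensor(0.0)) == hash_tensor_data(torch.tensor(-0.0))
False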

Changed in version 0.76.0: Moved hash_tensor to its own utility and renamed for clarity that the hash operates on the tensor data.

infer_minimal_submodules

infer_minimal_submodules(
    module: Module,
    *module_forward_args: Any,
    **module_forward_kwargs: Any,
) -> list[str]

Infer the minimal set of submodule names required for a forward pass.

For most modules, this function will just return a list of all of the names of the submodules of the given module. This is because most modules will require all of their submodules to be called in order to perform a forward pass. However, some modules may be designed to only require a subset of their submodules to be called in order to perform a forward pass. One notable example is the TruncatedModule class, which is used to truncate the forward pass of a module after a certain submodule has been called.

Note

This internally calls the module's forward method, so it is not guaranteed to be side-effect free. However, this function does not itself directly modify the module or its state, nor cause any additional side effects.

Note

This function assumes that the module has a static computational graph (i.e. which submodules are called does not depend on the input). This function returns only the names of the submodules required for the given inputs. Any dynamic graph behavior will not be captured.

Parameters:

- module (Module): The module to infer the minimal submodules for. [required]
- *module_forward_args (Any): The positional arguments to pass to the module's forward method. [required]
- **module_forward_kwargs (Any): The keyword arguments to pass to the module's forward method. [required]

Returns:

- list[str]: The minimal set of submodule names required for a forward pass.

Examples:

Define a module that only requires a subset of its submodules to be called in order to perform a forward pass:

>>> from stainedglass_core import model as sg_model
>>> full_model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10))
>>> truncated_model = sg_model.TruncatedModule(
...     full_model, truncation_point=full_model[1]
... )
>>> input = torch.rand(10)

The full model requires all of its submodules to be called in order to perform a forward pass:

>>> infer_minimal_submodules(full_model, input)
['0', '1', '2']

The truncated model only requires the first two submodules to be called in order to perform a forward pass:

>>> infer_minimal_submodules(truncated_model, input)
['module.0', 'module.1']

Added in version 0.70.0.

min_max_normalize

min_max_normalize(tensor: Tensor) -> torch.Tensor

Normalize the given tensor by subtracting its minimum value and dividing by the range between its minimum and maximum values. This is equivalent to torchvision.utils.make_grid(normalize=True). If the range is zero, the tensor is clamped to [0, 1].

Examples:

>>> tensor = torch.tensor([[0, 1, 2], [3, 4, 5]], dtype=torch.float32)
>>> min_max_normalize(tensor)
tensor([[0.0000, 0.2000, 0.4000],
        [0.6000, 0.8000, 1.0000]])

Parameters:

- tensor (Tensor): The tensor to normalize. [required]

Returns:

- torch.Tensor: The normalized tensor.

order_tensors_by_norm

order_tensors_by_norm(
    tensor_a: Tensor, tensor_b: Tensor, **kwargs: Any
) -> tuple[torch.Tensor, torch.Tensor]

Return the supplied tensors as a tuple, sorted by their norm.

By default, the vector 2-norm is used; other norms can be selected by supplying kwargs.

Note

If the norms are equal, the order of the supplied tensors is preserved, i.e. the sort is stable.

Parameters:

- tensor_a (Tensor): A tensor to compare the norm of. [required]
- tensor_b (Tensor): A tensor to compare the norm of. [required]
- **kwargs (Any): Variable length keyword arguments to supply to the underlying torch.linalg.vector_norm call. Useful for using norms other than the default vector 2-norm. [required]

Returns:

- tuple[torch.Tensor, torch.Tensor]: A tuple of the original tensors, where the first element has norm less than or equal to the second element.
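
Examples:

A minimal sketch: the tensor with the smaller 2-norm is returned first, regardless of argument order:

>>> small = torch.tensor([1.0, 0.0])
>>> large = torch.tensor([3.0, 4.0])
>>> first, second = order_tensors_by_norm(large, small)
>>> first is small and second is large
True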

pad_sequence

pad_sequence(
    sequences: list[Tensor] | tuple[Tensor, ...],
    padding_value: float = 0.0,
    padding_side: Literal["left", "right"] | str = "right",
    pad_to_multiple_of: int | None = None,
    truncation_side: Literal["left", "right"]
    | str = "right",
    max_length: int | None = None,
) -> torch.Tensor

Pack a sequence of variable length tensors into a single tensor, padding them to the same length on the specified side.

Parameters:

- sequences (list[Tensor] | tuple[Tensor, ...]): A list or tuple of 1D tensors of shape (sequence_length,). [required]
- padding_value (float): The value to use for padding. [default: 0.0]
- padding_side (Literal['left', 'right'] | str): The side on which to pad the sequences. Either 'left' or 'right'. [default: 'right']
- pad_to_multiple_of (int | None): The multiple to pad the sequence length to. [default: None]
- truncation_side (Literal['left', 'right'] | str): The side from which to truncate the sequences if max_length is specified. Either 'left' or 'right'. [default: 'right']
- max_length (int | None): The maximum length after which sequences will be truncated. [default: None]

Raises:

- ValueError: If padding_side is not one of 'left' or 'right'.
- ValueError: If truncation_side is not one of 'left' or 'right'.
- ValueError: If max_length is not a multiple of pad_to_multiple_of.

Returns:

- torch.Tensor: A 2D tensor of shape (len(sequences), max(len(seq) for seq in sequences)), where the sequences are truncated to max_length if provided, padded to a multiple of pad_to_multiple_of if provided, with padding_value on padding_side.
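
Examples:

A minimal sketch, assuming the documented right-padding default:

>>> pad_sequence([torch.tensor([1, 2, 3]), torch.tensor([4])], padding_value=0)
tensor([[1, 2, 3],
        [4, 0, 0]])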

Added in version 0.84.0.

project_out_smaller_tensor

project_out_smaller_tensor(tensor_a: Tensor, tensor_b: Tensor) -> torch.Tensor

Determine the larger (by norm) of the two supplied tensors, then return the larger tensor with the component of the smaller tensor removed.

This function is symmetric.

Parameters:

- tensor_a (Tensor): A tensor to compare and project. [required]
- tensor_b (Tensor): A tensor to compare and project. [required]

Returns:

- torch.Tensor: The larger of the supplied tensors with the component of the smaller tensor removed.
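
Examples:

A minimal sketch: the tensors below are orthogonal, so the larger tensor is returned unchanged:

>>> smaller = torch.tensor([0.0, 1.0])
>>> larger = torch.tensor([2.0, 0.0])
>>> project_out_smaller_tensor(smaller, larger)
tensor([2., 0.])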

scaled_kaiming_uniform_

scaled_kaiming_uniform_(
    t: Tensor, initialization_scale: float
) -> None

Initialize a tensor with a Kaiming uniform distribution scaled by initialization_scale.

Parameters:

- t (Tensor): The tensor to initialize. [required]
- initialization_scale (float): The amount to scale the initialization by. [required]
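
Examples:

A minimal calling sketch; the initialization is in-place and random, so no output is shown:

>>> t = torch.empty(5, 5)
>>> scaled_kaiming_uniform_(t, initialization_scale=0.5)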

set_module_attr

set_module_attr(
    module: Module, target: str, value: Module | Tensor
) -> None

Set a submodule, buffer, or parameter by its dot-delimited name, e.g. 'features.0.conv'.

Inspired by nn.Module.get_parameter.

Examples:

Set a submodule:

>>> model = torch.nn.Sequential(nn.Linear(10, 10), nn.BatchNorm1d(10))
>>> target_module = torch.nn.BatchNorm1d(10, affine=False)
>>> model[1] is target_module
False
>>> set_module_attr(model, "1", target_module)
>>> model[1] is target_module
True

Parameters:

- module (Module): The source module. [required]
- target (str): The dot-delimited name of the submodule, buffer, or parameter to set. [required]
- value (Module | Tensor): The value to set. [required]

Raises:

- AttributeError: If the target string references an invalid path.

temporarily_offload

temporarily_offload(module: Module) -> Generator[None]

Temporarily offload the module's parameters and buffers to the CPU.

Parameters:

- module (Module): The module to offload. [required]
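
Examples:

A minimal sketch: inside the context, the module's parameters live on the CPU (trivially so here, since the module already starts on the CPU):

>>> module = nn.Linear(2, 2)
>>> with temporarily_offload(module):
...     str(module.weight.device)
'cpu'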

temporarily_remove_hooks

temporarily_remove_hooks(module: Module) -> Generator[None]

Temporarily remove all hooks from a module and its submodules while inside this context, and restore the hooks to their original state upon exiting.

Note

Adding new hooks to the module (or its submodules), or removing existing ones (using their RemovableHandle.remove() methods), has undefined behavior while inside this context. In particular, a newly added hook will only be accessible during this context.

Parameters:

- module (Module): The module to remove hooks from. [required]

Examples:

Temporarily remove hooks from a module:

>>> model = nn.Sequential(nn.Linear(10, 10))
>>> model.register_forward_hook(
...     lambda module, input, output: print("Forward hook called")
... )
<...RemovableHandle object at ...>
>>> model(torch.rand(10))
Forward hook called
tensor(...)
>>> with temporarily_remove_hooks(model):
...     model(torch.rand(10))
tensor(...)
>>> model(torch.rand(10))
Forward hook called
tensor(...)

Changed in version 0.60.0: Hooks with kwargs are stored in a different attribute which only exists for pytorch versions 2.0 and greater.

train_mode

train_mode(
    module: Module, mode: bool = True
) -> Generator[None]

Put the entire module into the desired training mode; upon exiting, restore the original training mode.

If mode is False, gradient calculation is disabled; otherwise it is not affected.

Parameters:

- module (Module): The module to recursively set the training mode of. [required]
- mode (bool): Whether to set training mode (True) or evaluation mode (False). [default: True]
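
Examples:

A minimal sketch of the documented enter/exit behavior:

>>> module = nn.Linear(2, 2).eval()
>>> with train_mode(module):
...     module.training
True
>>> module.training
False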

vector_projection

vector_projection(tensor_a: Tensor, tensor_b: Tensor) -> torch.Tensor

Compute the (vector) projection of tensor_a onto tensor_b.

\(\text{proj}_{b}a := (a\cdot\hat{b})\,\hat{b}\)

Parameters:

- tensor_a (Tensor): The tensor to project. [required]
- tensor_b (Tensor): The tensor to project onto. [required]

Returns:

- torch.Tensor: tensor_a projected onto tensor_b.
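
Examples:

A minimal sketch consistent with the formula above: tensor_b points along the x-axis, so the projection keeps only the x-component of tensor_a:

>>> tensor_a = torch.tensor([1.0, 1.0])
>>> tensor_b = torch.tensor([2.0, 0.0])
>>> vector_projection(tensor_a, tensor_b)
tensor([1., 0.])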