model

Modules:

noisy_model

noisy_transformer_masking_model

truncated_module

Classes:

NoisyModel: Applies a BaseNoiseLayer to a model input Tensor or a submodule output Tensor.

TruncatedModule: A module that wraps another module and interrupts its forward pass when a specified truncation point is reached.

NoisyModel

Bases: Module, Generic[ModuleT, NoiseLayerP, NoiseLayerT]

Applies a BaseNoiseLayer to a model input Tensor or a submodule output Tensor.
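For the target_layer case, one common PyTorch technique for transforming a submodule's output is a forward hook: a hook that returns a non-None value replaces the module's output. The sketch below assumes that technique; it is not NoisyModel's actual implementation. `ConstantShift` is a hypothetical, deterministic stand-in for a BaseNoiseLayer, chosen so the effect is easy to check.

```python
import torch


class ConstantShift(torch.nn.Module):
    """Hypothetical stand-in for a BaseNoiseLayer: adds a fixed offset."""

    def __init__(self, shift: float) -> None:
        super().__init__()
        self.shift = shift

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.shift


model = torch.nn.Sequential(torch.nn.Linear(4, 3), torch.nn.ReLU(), torch.nn.Linear(3, 2))
noise_layer = ConstantShift(1.0)


def replace_output(module, args, output):
    # Returning a value from a forward hook replaces the submodule's output,
    # so everything downstream of model[1] sees the transformed tensor.
    return noise_layer(output)


x = torch.randn(5, 4)
clean = model(x)
handle = model[1].register_forward_hook(replace_output)  # target the ReLU
noisy = model(x)
handle.remove()  # the base model behaves normally again
```

Because the hook is removed after use, `model(x)` returns the clean output once more.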

Parameters:

noise_layer_class (Callable[NoiseLayerP, NoiseLayerT], required): The type of BaseNoiseLayer to apply.

base_model (ModuleT, required): The model to apply the BaseNoiseLayer to.

*args (NoiseLayerP.args, default ()): Positional arguments passed to noise_layer_class.

target_layer (str | None, default None): The name of the base_model submodule (e.g. 'features.0.conv.1.2') whose output Tensor is transformed. If provided, target_parameter must be None.

target_parameter (str | None, default None): The name of the base_model input Tensor argument to transform. If provided, target_layer must be None.

**kwargs (NoiseLayerP.kwargs, default {}): Keyword arguments passed to noise_layer_class.

Raises:

ValueError: If both target_layer and target_parameter are None.

ValueError: If both target_layer and target_parameter are provided (neither is None).
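Together, the two ValueError conditions require exactly one of target_layer and target_parameter. Here is a minimal sketch of that check. The helper name `check_exactly_one_target` is hypothetical and not part of the documented API.

```python
from __future__ import annotations


def check_exactly_one_target(target_layer: str | None, target_parameter: str | None) -> None:
    """Hypothetical helper: raise ValueError unless exactly one target is given."""
    if target_layer is None and target_parameter is None:
        raise ValueError("Provide one of target_layer or target_parameter.")
    if target_layer is not None and target_parameter is not None:
        raise ValueError("Provide only one of target_layer or target_parameter.")


check_exactly_one_target("features.0", None)  # OK: transform a submodule output
check_exactly_one_target(None, "input_ids")  # OK: transform an input argument
```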

Methods:

forward: Call the base_model, applying the noise_layer to the target_parameter or target_layer output.

reset_parameters: Reinitialize parameters and buffers.

Attributes:

target_layer (Module): The base_model submodule whose output Tensor is transformed.

target_parameter (str | None): The name of the base_model input Tensor argument to transform when target_layer is None.

target_parameter_index (int): The index of the base_model input Tensor argument to transform when target_layer is None.

target_layer property

target_layer: Module

The base_model submodule whose output Tensor is transformed.

Raises:

ValueError: If _target_layer cannot be found as a submodule of base_model.
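The docs describe resolving a dotted name such as 'features.0.conv.1.2' to a submodule and raising ValueError on failure. PyTorch's `Module.get_submodule` performs this kind of dotted lookup. The sketch below assumes that approach and converts its AttributeError into the documented ValueError. `resolve_target_layer` is a hypothetical name.

```python
import torch


def resolve_target_layer(base_model: torch.nn.Module, name: str) -> torch.nn.Module:
    """Hypothetical helper: look up a submodule by dotted name, e.g. 'features.0'."""
    try:
        # get_submodule walks the dotted path one attribute at a time.
        return base_model.get_submodule(name)
    except AttributeError as exc:
        raise ValueError(f"{name!r} is not a submodule of base_model") from exc


model = torch.nn.Sequential(
    torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU()),
    torch.nn.Linear(8, 2),
)
layer = resolve_target_layer(model, "0.1")  # the inner ReLU
```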

target_parameter property

target_parameter: str | None

The name of the base_model input Tensor argument to transform when target_layer is None.

target_parameter_index cached property

target_parameter_index: int

The index of the base_model input Tensor argument to transform when target_layer is None.
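With an index as well as a name, the target argument can be located whether it is passed positionally or by keyword. The sketch below assumes the index is derived from the signature of base_model's forward. The helpers `parameter_index` and `get_argument` and the example `forward` signature are all hypothetical.

```python
import inspect
from typing import Any, Callable


def parameter_index(fn: Callable[..., Any], name: str) -> int:
    """Hypothetical helper: position of `name` in fn's parameter list."""
    names = list(inspect.signature(fn).parameters)
    if name not in names:
        raise ValueError(f"{name!r} is not a parameter of {fn.__name__}")
    return names.index(name)


def get_argument(args: tuple, kwargs: dict, index: int, name: str) -> Any:
    """Hypothetical helper: fetch an argument passed positionally or by keyword."""
    return args[index] if index < len(args) else kwargs[name]


def forward(input_ids, attention_mask=None, labels=None):  # example signature
    return input_ids


idx = parameter_index(forward, "attention_mask")
```

When applied to a bound method such as `model.forward`, `inspect.signature` omits `self`, so the index matches the positions used at the call site.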

forward

forward(
    *args: Any,
    noise_mask: Tensor | None = None,
    **kwargs: Any,
) -> Any

Call the base_model, applying the noise_layer to the target_parameter or target_layer output.

Parameters:

*args (Any, default ()): Positional arguments passed to base_model.

noise_mask (Tensor | None, default None): An optional mask that selects the elements of the target_parameter or target_layer output to transform. Where the mask is False, the original values of the target are used. If None, the entire target is transformed.

**kwargs (Any, default {}): Keyword arguments passed to base_model.

Returns:

Any: The result of base_model with the noise_layer applied to the target_parameter or target_layer output.
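The noise_mask semantics (transformed values where the mask is True, original values where it is False, everything transformed when the mask is None) correspond to an elementwise select. This sketch assumes `torch.where` with a boolean mask that broadcasts to the target. Where NoisyModel actually applies the mask is not specified here, and `blend_with_mask` is a hypothetical name.

```python
from __future__ import annotations

import torch


def blend_with_mask(
    original: torch.Tensor,
    transformed: torch.Tensor,
    noise_mask: torch.Tensor | None = None,
) -> torch.Tensor:
    """Hypothetical helper: keep transformed values only where noise_mask is True."""
    if noise_mask is None:
        return transformed  # no mask: the whole target is transformed
    return torch.where(noise_mask, transformed, original)


original = torch.zeros(2, 3)
transformed = original + 1.0  # stands in for the noise layer's output
mask = torch.tensor([[True, False, True], [False, False, True]])
out = blend_with_mask(original, transformed, mask)
```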

reset_parameters

reset_parameters() -> None

Reinitialize parameters and buffers.

This method is useful for initializing tensors created on the meta device.
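The meta-device workflow that reset_parameters supports is shown here on a plain `torch.nn.Linear`, not on a NoisyModel. Tensors on the meta device have shapes but no storage. `to_empty` allocates uninitialized storage on a real device, and `reset_parameters` then fills it with initial values. Whether NoisyModel.reset_parameters also reinitializes base_model is not stated here.

```python
import torch

# Parameters created on the meta device have shapes but no data.
layer = torch.nn.Linear(4, 2, device="meta")
assert layer.weight.is_meta

# to_empty() allocates uninitialized storage on a real device;
# reset_parameters() then writes proper initial values into it.
layer = layer.to_empty(device="cpu")
layer.reset_parameters()
```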

TruncatedModule

Bases: Module, Generic[ModuleT]

A module that wraps another module and interrupts its forward pass when a specified truncation point is reached.

Truncation works by temporarily adding a hook to the truncation point. When the truncation point runs, the hook raises a TruncationExecutionFinished exception that carries the truncation point's output. The TruncatedModule forward catches this exception and returns that output.
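The mechanism can be sketched in plain PyTorch. This is an illustration under stated assumptions, not the library's code. It assumes a forward hook, because that hook type receives the truncation point's output. The hypothetical `_Truncated` stands in for TruncationExecutionFinished, `run_until` is a hypothetical name, and the final RuntimeError plays the role of HookNotCalledError.

```python
from typing import Any, NoReturn

import torch


class _Truncated(Exception):
    """Hypothetical stand-in for TruncationExecutionFinished: carries the output."""

    def __init__(self, output: Any) -> None:
        super().__init__()
        self.output = output


def run_until(module: torch.nn.Module, point: torch.nn.Module, *args: Any, **kwargs: Any) -> Any:
    """Hypothetical helper: run `module`, stopping as soon as `point` produces output."""

    def hook(_module: torch.nn.Module, _args: Any, output: Any) -> NoReturn:
        raise _Truncated(output)  # abort the rest of the forward pass

    handle = point.register_forward_hook(hook)
    try:
        module(*args, **kwargs)
    except _Truncated as stop:
        return stop.output
    finally:
        handle.remove()  # the hook only exists for this one call
    raise RuntimeError("truncation point was not reached")


model = torch.nn.Sequential(torch.nn.Linear(10, 20), torch.nn.ReLU(), torch.nn.Linear(20, 2))
out = run_until(model, model[1], torch.randn(1, 10))  # layers after model[1] never run
```

Removing the handle in `finally` leaves the wrapped model's hooks unchanged after every call. This matches the documented behavior that the base model is unaffected by truncation.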

Examples:

Instantiating a TruncatedModule with a binary classification model, truncated at its first ReLU:

>>> import torch
>>> model = torch.nn.Sequential(
...     torch.nn.Linear(10, 20),
...     torch.nn.ReLU(),
...     torch.nn.Linear(20, 30),
...     torch.nn.ReLU(),
...     torch.nn.Linear(30, 40),
...     torch.nn.ReLU(),
...     torch.nn.Linear(40, 2),
... )
>>> truncation_layer = model[1]
>>> truncated_model = TruncatedModule(model, truncation_layer)

Using the TruncatedModule to get the output of the truncation point:

>>> input = torch.randn(1, 10)
>>> output = truncated_model(input)
>>> # The output has the shape of the truncation point's output, not of the full model's output
>>> assert output.shape == (1, 20)

The base model of the TruncatedModule is completely unaffected by the truncation:

>>> base_output = model(input)
>>> assert base_output.shape == (1, 2)  # Binary classification output shape

The base model is also accessible directly through the module attribute of the TruncatedModule:

>>> base_output = truncated_model.module(input)
>>> assert base_output.shape == (1, 2)  # Binary classification output shape

Added in version 0.59.0.

Methods:

__init__: Initialize the TruncatedModule with the provided module and truncation point.

forward: Run the wrapped module, interrupting its forward pass when the truncation point is reached.

lazy_register_truncation_hook: Create a forward hook that will be added to the truncation point to interrupt the forward pass when the truncation point is reached.

truncation_hook: Intercept the output of the truncation point and raise a TruncationExecutionFinished exception containing that output.

__init__

__init__(module: ModuleT, truncation_point: Module) -> None

Initialize the TruncatedModule with the provided module and truncation point.

Parameters:

module (ModuleT, required): The module to wrap.

truncation_point (Module, required): The submodule of the provided module at which to interrupt the forward pass.

Raises:

ValueError: If the truncation point is not a submodule of the provided module.
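One possible form of this validation is an identity search over `module.modules()`. The sketch below is an assumption about how it might look, and `is_submodule` is a hypothetical name. Note that `module.modules()` also yields the module itself, and the docs do not say whether TruncatedModule accepts the root module as a truncation point.

```python
import torch


def is_submodule(module: torch.nn.Module, candidate: torch.nn.Module) -> bool:
    """Hypothetical helper: True if candidate is module or nested inside it.

    Uses identity (`is`), so a structurally identical but separate layer
    does not count.
    """
    return any(m is candidate for m in module.modules())


model = torch.nn.Sequential(torch.nn.Linear(10, 20), torch.nn.ReLU())
```

A fresh `torch.nn.ReLU()` is not a submodule of `model`, even though `model[1]` is also a ReLU.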

forward

forward(*args: Any, **kwargs: Any) -> Any

Run the wrapped module, interrupting its forward pass when the truncation point is reached.

Parameters:

*args (Any, default ()): The positional arguments to pass to the wrapped module.

**kwargs (Any, default {}): The keyword arguments to pass to the wrapped module.

Returns:

Any: The output of the truncation point submodule.

Raises:

HookNotCalledError: If the truncation hook is not called, meaning the truncation point was not reached during the forward pass.

lazy_register_truncation_hook

lazy_register_truncation_hook() -> _HandlerWrapper

Create a forward hook that will be added to the truncation point to interrupt the forward pass when the truncation point is reached.

Returns:

_HandlerWrapper: A handler wrapper that contains the hook that was added to the truncation point.
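_HandlerWrapper is not documented further here. A common way to guarantee that a temporary hook is removed is to wrap the `RemovableHandle` returned by `register_forward_hook` in a context manager. The sketch below assumes that pattern and does not reproduce _HandlerWrapper itself. `temporary_forward_hook` is a hypothetical name.

```python
from contextlib import contextmanager
from typing import Any, Callable, Iterator

import torch
from torch.utils.hooks import RemovableHandle


@contextmanager
def temporary_forward_hook(
    module: torch.nn.Module, hook: Callable[..., Any]
) -> Iterator[RemovableHandle]:
    """Hypothetical helper: register a forward hook only for a with-block."""
    handle = module.register_forward_hook(hook)
    try:
        yield handle
    finally:
        handle.remove()  # runs even if the forward pass raised


seen = []
layer = torch.nn.ReLU()
with temporary_forward_hook(layer, lambda m, args, out: seen.append(out.shape)):
    layer(torch.randn(2, 3))  # hook fires and records the output shape
layer(torch.randn(2, 3))  # hook already removed, nothing recorded
```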

truncation_hook staticmethod

truncation_hook(
    truncation_point: Module, args: Any, output: Tensor
) -> NoReturn

Intercept the output of the truncation point and raise a TruncationExecutionFinished exception containing that output.

Parameters:

truncation_point (Module, required): The truncation point submodule. Unused.

args (Any, required): The arguments passed to the truncation point. Unused.

output (Tensor, required): The output of the truncation point. This is the output that the TruncatedModule returns.

Raises:

TruncationExecutionFinished: Always, in order to interrupt the wrapped model's forward method.