truncated_module

Classes:

Name | Description
TruncatedModule | A module that wraps another module and interrupts its forward pass when a specified truncation point is reached.
TruncationExecutionFinished | Special exception raised by a TruncatedModule.

TruncatedModule

Bases: Module, Generic[ModuleT]

A module that wraps another module and interrupts its forward pass when a specified truncation point is reached.

The truncation works by temporarily adding a hook to the truncation point. The hook raises a TruncationExecutionFinished exception, which the TruncatedModule's forward method catches. The output of the truncation point, carried by the exception, is then returned.
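The mechanism described above can be sketched with plain PyTorch hooks. This is a minimal illustration, not the library's implementation: `_StopForward` and `run_until` are hypothetical names, and the sketch assumes the hook is a regular forward hook registered with `register_forward_hook`.

```python
import torch
from torch import nn


class _StopForward(Exception):
    """Hypothetical stand-in for TruncationExecutionFinished."""

    def __init__(self, output):
        super().__init__("truncation point reached")
        self.output = output


def run_until(model: nn.Module, stop_at: nn.Module, *args, **kwargs):
    """Run `model`, returning the output of `stop_at` as soon as it is produced."""

    def hook(module, hook_args, output):
        # A forward hook runs right after `stop_at.forward`; raising here
        # aborts the rest of the model's forward pass.
        raise _StopForward(output)

    handle = stop_at.register_forward_hook(hook)
    try:
        model(*args, **kwargs)
    except _StopForward as stop:
        return stop.output
    finally:
        # Always remove the hook so the wrapped model behaves normally afterwards.
        handle.remove()
    raise RuntimeError("stop_at was never reached during the forward pass")


model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2))
x = torch.randn(1, 10)
hidden = run_until(model, model[1], x)  # output of the ReLU, shape (1, 20)
full = model(x)  # the hook is gone, so the full model runs, shape (1, 2)
```

Removing the hook in a `finally` block is what keeps the truncation temporary: the wrapped model is left untouched whether the forward pass was interrupted or failed for another reason.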

Examples:

Instantiating a TruncatedModule with a binary classification model and a truncation point:

>>> model = torch.nn.Sequential(
...     torch.nn.Linear(10, 20),
...     torch.nn.ReLU(),
...     torch.nn.Linear(20, 30),
...     torch.nn.ReLU(),
...     torch.nn.Linear(30, 40),
...     torch.nn.ReLU(),
...     torch.nn.Linear(40, 2),
... )
>>> truncation_layer = model[1]
>>> truncated_model = TruncatedModule(model, truncation_layer)

Using the TruncatedModule to get the output of the truncation point:

>>> input = torch.randn(1, 10)
>>> output = truncated_model(input)
>>> # Note that the output has the shape of the truncation point's output, not of the full model's
>>> assert output.shape == (1, 20)

The base model of the TruncatedModule is completely unaffected by the truncation:

>>> base_output = model(input)
>>> assert base_output.shape == (1, 2)  # Binary classification output shape

The base model is also accessible directly through the module attribute of the TruncatedModule:

>>> base_output = truncated_model.module(input)
>>> assert base_output.shape == (1, 2)  # Binary classification output shape

Added in version 0.59.0.

Methods:

Name | Description
__init__ | Initialize the TruncatedModule with the provided module and truncation point.
forward | Run the wrapped module's forward pass and interrupt it when the truncation point is reached.
lazy_register_truncation_hook | Create a prehook that will be added to the truncation point to interrupt the forward pass when the truncation point is reached.
truncation_hook | Intercept the output of the truncation point and raise a TruncationExecutionFinished exception containing that output.

__init__

__init__(module: ModuleT, truncation_point: Module) -> None

Initialize the TruncatedModule with the provided module and truncation point.

Parameters:

Name | Type | Description | Default
module | ModuleT | The module to wrap. | required
truncation_point | Module | The submodule of the provided module at which to interrupt the forward pass. | required

Raises:

Type | Description
ValueError | If the truncation point is not a submodule of the provided module.
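One way such a membership check could be written is shown below. It is only a sketch: `is_submodule` is a hypothetical helper, and the library's actual validation may differ (for example in whether the wrapped module itself counts as a valid truncation point).

```python
from torch import nn


def is_submodule(root: nn.Module, candidate: nn.Module) -> bool:
    """Return True if `candidate` is `root` itself or nested anywhere inside it."""
    # `modules()` yields `root` and every descendant. Comparing by identity
    # ensures that a separate layer of the same type does not match.
    return any(m is candidate for m in root.modules())


model = nn.Sequential(nn.Linear(10, 20), nn.ReLU())
inside = is_submodule(model, model[1])    # the ReLU that belongs to `model`
outside = is_submodule(model, nn.ReLU())  # a fresh ReLU that `model` does not own
```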

forward

forward(*args: Any, **kwargs: Any) -> Any

Run the wrapped module's forward pass and interrupt it when the truncation point is reached.

Parameters:

Name | Type | Description | Default
*args | Any | The positional arguments to pass to the wrapped module. | required
**kwargs | Any | The keyword arguments to pass to the wrapped module. | required

Returns:

Type | Description
Any | The output of the truncation point submodule.

Raises:

Type | Description
HookNotCalledError | If the truncation hook is not called, meaning the truncation point was not reached.
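A truncation point can be a valid submodule and still never run, for instance when the wrapped module's forward method skips it. The sketch below illustrates how that situation can be detected. `SkipsBranch`, `HookNeverFired`, and `output_at` are hypothetical names, and for brevity the hook here records the output instead of interrupting the forward pass.

```python
import torch
from torch import nn


class SkipsBranch(nn.Module):
    """Toy model that owns `unused` but never calls it in forward."""

    def __init__(self):
        super().__init__()
        self.used = nn.Linear(4, 3)
        self.unused = nn.Linear(4, 3)

    def forward(self, x):
        return self.used(x)


class HookNeverFired(RuntimeError):
    """Hypothetical stand-in for HookNotCalledError."""


def output_at(model: nn.Module, point: nn.Module, x: torch.Tensor):
    captured = []
    # `list.append` returns None, so this hook leaves the layer output unchanged.
    handle = point.register_forward_hook(lambda m, a, out: captured.append(out))
    try:
        model(x)
    finally:
        handle.remove()
    if not captured:
        raise HookNeverFired(f"{type(point).__name__} was not called during forward")
    return captured[0]


model = SkipsBranch()
x = torch.randn(2, 4)
reached = output_at(model, model.used, x)  # shape (2, 3)
```

Calling `output_at(model, model.unused, x)` raises `HookNeverFired`, because `unused` is registered on the model but never executed.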

lazy_register_truncation_hook

lazy_register_truncation_hook() -> _HandlerWrapper

Create a prehook that will be added to the truncation point to interrupt the forward pass when the truncation point is reached.

Returns:

Type | Description
_HandlerWrapper | A handler wrapper that contains the hook that was added to the truncation point.

truncation_hook staticmethod

truncation_hook(
    truncation_point: Module, args: Any, output: Tensor
) -> NoReturn

Intercept the output of the truncation point and raise a TruncationExecutionFinished exception containing that output.

Parameters:

Name | Type | Description | Default
truncation_point | Module | The truncation point submodule. Unused. | required
args | Any | The arguments passed to the truncation point. Unused. | required
output | Tensor | The output of the truncation point. This is the output that will be returned by the TruncatedModule. | required

Raises:

Type | Description
TruncationExecutionFinished | Always, in order to interrupt the wrapped model's forward method.

TruncationExecutionFinished

Bases: Exception

Special exception raised by a TruncatedModule to signal to itself that the truncation point has been reached. It is expected to be caught internally rather than propagated to the user.

Added in version 0.59.0.

Methods:

Name | Description
__init__ | Initialize the exception with a message and the output of the truncation point.

__init__

__init__(message: str, intercepted_output: object) -> None

Initialize the exception with a message and the output of the truncation point.

Parameters:

Name | Type | Description | Default
message | str | The message to display when the exception is raised. | required
intercepted_output | object | The output of the truncation point that caused the exception to be raised. | required