truncated_module
Classes:
Name | Description |
---|---|
TruncatedModule | A module that wraps another module and interrupts the forward pass when a specified truncation point is reached. |
TruncationExecutionFinished | Special exception raised by a TruncatedModule to signal that the truncation point has been reached. |
TruncatedModule
Bases: Module, Generic[ModuleT]
A module that wraps another module and interrupts the forward pass when a specified truncation point is reached.
Truncation works by temporarily registering a hook on the truncation point that raises a TruncationExecutionFinished exception carrying the truncation point's output. TruncatedModule.forward catches this exception and returns the intercepted output.
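The hook-and-exception mechanism described above can be reproduced with plain PyTorch forward hooks. The following is a minimal sketch, not the library's implementation: `TruncationSignal` and `run_until` are hypothetical names, and it assumes the truncation hook is an ordinary `register_forward_hook` hook that is removed once the pass ends.

```python
import torch
from torch import nn


class TruncationSignal(Exception):
    """Hypothetical stand-in for TruncationExecutionFinished; carries the intercepted output."""

    def __init__(self, output: torch.Tensor) -> None:
        super().__init__("truncation point reached")
        self.output = output


def run_until(model: nn.Module, point: nn.Module, *args, **kwargs):
    """Hypothetical helper: run `model` and return the output of `point` as soon as it is produced."""

    def hook(module: nn.Module, inputs: tuple, output: torch.Tensor) -> None:
        # Forward hooks run right after `point` computes its output; raising aborts the rest of the pass.
        raise TruncationSignal(output)

    handle = point.register_forward_hook(hook)
    try:
        model(*args, **kwargs)
    except TruncationSignal as signal:
        return signal.output
    finally:
        # Always detach the hook so the wrapped model behaves normally afterwards.
        handle.remove()
    # Reached only if the forward pass completed without ever calling `point`.
    raise RuntimeError("truncation point was never reached")


model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2))
x = torch.randn(1, 10)
print(run_until(model, model[1], x).shape)  # torch.Size([1, 20])
print(model(x).shape)  # torch.Size([1, 2])
```

The `finally` clause is the important design choice: the hook is removed whether the pass was truncated, completed normally, or failed, so the base model is left unmodified.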
Examples:
Instantiating a TruncatedModule with a binary classification model and a truncation point:
>>> import torch
>>> model = torch.nn.Sequential(
...     torch.nn.Linear(10, 20),
...     torch.nn.ReLU(),
...     torch.nn.Linear(20, 30),
...     torch.nn.ReLU(),
...     torch.nn.Linear(30, 40),
...     torch.nn.ReLU(),
...     torch.nn.Linear(40, 2),
... )
>>> truncation_layer = model[1]  # the ReLU right after the first Linear(10, 20)
>>> truncated_model = TruncatedModule(model, truncation_layer)
Using the TruncatedModule to get the output of the truncation point:
>>> input = torch.randn(1, 10)
>>> output = truncated_model(input)
>>> # The output has the shape of the truncation point's output, not of the full model's output
>>> assert output.shape == (1, 20)
The base model of the TruncatedModule is completely unaffected by the truncation:
>>> base_output = model(input)
>>> assert base_output.shape == (1, 2) # Binary classification output shape
The base model is also accessible directly through the module attribute of the TruncatedModule:
>>> base_output = truncated_model.module(input)
>>> assert base_output.shape == (1, 2) # Binary classification output shape
Added in version 0.59.0.
Methods:
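Name | Description |
---|---|
__init__ | Initialize the TruncatedModule with the provided module and truncation point. |
forward | Forward pass of the TruncatedModule that interrupts the forward pass when the truncation point is reached. |
lazy_register_truncation_hook | Create a forward hook that will be added to the truncation point to interrupt the forward pass when the truncation point is reached. |
truncation_hook | Intercept the output of the truncation point and raise a TruncationExecutionFinished exception containing that output. |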
__init__
__init__(module: ModuleT, truncation_point: Module) -> None
Initialize the TruncatedModule with the provided module and truncation point.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
module | ModuleT | The module to wrap. | required |
truncation_point | Module | The submodule of the provided module at which to interrupt the forward pass. | required |
Raises:
Type | Description |
---|---|
ValueError | If the truncation point is not a submodule of the provided module. |
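The ValueError condition amounts to a membership test over the wrapped module's descendants. A minimal sketch of such a check using the standard `nn.Module.named_modules()` traversal; `is_submodule` and `check_truncation_point` are hypothetical helpers, and treating the wrapped module itself as *not* a valid truncation point is an assumption:

```python
from torch import nn


def is_submodule(module: nn.Module, candidate: nn.Module) -> bool:
    """Hypothetical helper: True if `candidate` is a strict descendant of `module`."""
    # named_modules() yields ("", module) first; skipping the empty name excludes `module` itself.
    return any(m is candidate for name, m in module.named_modules() if name)


def check_truncation_point(module: nn.Module, truncation_point: nn.Module) -> None:
    """Hypothetical validation mirroring the documented ValueError."""
    if not is_submodule(module, truncation_point):
        raise ValueError("truncation_point must be a submodule of module")


model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
print(is_submodule(model, model[1]))   # True
print(is_submodule(model, nn.ReLU()))  # False: an equal-looking but separate instance
```

The check uses identity (`is`), so a freshly constructed layer of the same type is rejected even if it is configured identically.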
forward
Forward pass of the TruncatedModule that interrupts the forward pass when the truncation point is reached.
Parameters:
Type | Description | Default |
---|---|---|
Any | The positional arguments to pass to the wrapped module. | required |
Any | The keyword arguments to pass to the wrapped module. | required |
Returns:
Type | Description |
---|---|
Any | The output of the truncation point submodule. |
Raises:
Type | Description |
---|---|
HookNotCalledError | If the truncation hook is not called, meaning the truncation point was not reached. |
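A truncation point is only "reached" if the wrapped module's forward actually calls it. A submodule that is registered but skipped, for example by a conditional branch, never triggers its forward hooks, which is the situation HookNotCalledError reports. The sketch below shows this with plain PyTorch hooks; `MaybeSkip` and `hook_fired` are hypothetical, and it assumes the truncation hook behaves like a standard forward hook:

```python
import torch
from torch import nn


class MaybeSkip(nn.Module):
    """Hypothetical model whose `extra` submodule only runs when `use_extra=True`."""

    def __init__(self) -> None:
        super().__init__()
        self.base = nn.Linear(4, 4)
        self.extra = nn.Linear(4, 4)

    def forward(self, x: torch.Tensor, use_extra: bool = False) -> torch.Tensor:
        x = self.base(x)
        return self.extra(x) if use_extra else x


def hook_fired(model: MaybeSkip, use_extra: bool) -> bool:
    """Report whether a forward hook on `model.extra` runs during one forward pass."""
    calls = []
    # list.append returns None, so the hook records the output without replacing it.
    handle = model.extra.register_forward_hook(lambda mod, args, out: calls.append(out))
    try:
        model(torch.randn(1, 4), use_extra=use_extra)
    finally:
        handle.remove()
    return bool(calls)


model = MaybeSkip()
print(hook_fired(model, use_extra=True), hook_fired(model, use_extra=False))  # True False
```

Wrapping `MaybeSkip` with `extra` as the truncation point would therefore only succeed for inputs that take the `use_extra=True` branch.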
lazy_register_truncation_hook
Create a forward hook that will be added to the truncation point to interrupt the forward pass when the truncation point is reached.
Returns:
Type | Description |
---|---|
_HandlerWrapper | A handler wrapper that contains the hook that was added to the truncation point. |
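The internals of `_HandlerWrapper` are not documented here. One common way to give a hook a scoped lifetime, so that it is "temporarily" attached as the class description says, is a context manager around the `RemovableHandle` returned by `register_forward_hook`. This is a minimal sketch under that assumption; `temporary_forward_hook` is a hypothetical name, not part of the library:

```python
from contextlib import contextmanager
from typing import Callable, Iterator

import torch
from torch import nn
from torch.utils.hooks import RemovableHandle


@contextmanager
def temporary_forward_hook(module: nn.Module, hook: Callable) -> Iterator[RemovableHandle]:
    """Register `hook` on `module` for the duration of a `with` block, then remove it."""
    handle = module.register_forward_hook(hook)
    try:
        yield handle
    finally:
        # Runs even if the body raised, so the hook never outlives the block.
        handle.remove()


layer = nn.ReLU()
seen = []
with temporary_forward_hook(layer, lambda m, args, out: seen.append(out.shape)):
    layer(torch.randn(2, 3))
layer(torch.randn(2, 3))  # hook already removed, so this call is not recorded
print(len(seen))  # 1
```

Because removal happens in `finally`, an exception raised from inside the hook (as a truncation hook does on purpose) still leaves the module clean.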
truncation_hook (staticmethod)
Intercept the output of the truncation point and raise a TruncationExecutionFinished exception containing that output.
Parameters:
Type | Description | Default |
---|---|---|
Module | The truncation point submodule. Unused. | required |
Any | The arguments passed to the truncation point. Unused. | required |
Tensor | The output of the truncation point. This is the output that will be returned by the TruncatedModule. | required |
Raises:
Type | Description |
---|---|
TruncationExecutionFinished | Always, in order to interrupt the wrapped model's forward pass. |
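Raising from a hook with the (module, args, output) signature is what stops the remaining layers from running. The sketch below makes that visible by recording, with forward pre-hooks, which layers of an `nn.Sequential` actually execute. `Stop` and `stop_hook` are hypothetical stand-ins for TruncationExecutionFinished and truncation_hook, not the library's code:

```python
import torch
from torch import nn


class Stop(Exception):
    """Hypothetical signal carrying the intercepted output."""

    def __init__(self, output: torch.Tensor) -> None:
        super().__init__("stop")
        self.output = output


def stop_hook(module: nn.Module, args: tuple, output: torch.Tensor) -> None:
    # Same (module, args, output) shape as truncation_hook: ignore the first two, raise with the output.
    raise Stop(output)


calls = []
model = nn.Sequential(nn.Linear(3, 5), nn.Tanh(), nn.Linear(5, 1))
# Forward pre-hooks fire just before each layer runs; record the layer index.
for i, layer in enumerate(model):
    layer.register_forward_pre_hook(lambda mod, args, i=i: calls.append(i))

handle = model[1].register_forward_hook(stop_hook)
captured = None
try:
    model(torch.randn(2, 3))
except Stop as signal:
    captured = signal.output
finally:
    handle.remove()

print(calls)           # [0, 1]
print(captured.shape)  # torch.Size([2, 5])
```

The final `Linear(5, 1)` never runs, and `captured` holds the `Tanh` output, which is exactly what a truncated forward pass returns.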
TruncationExecutionFinished
Bases: Exception
Special exception raised by a TruncatedModule to signal to itself that the truncation point has been reached. It is meant to be caught internally by TruncatedModule.forward, not propagated to the user.
Added in version 0.59.0.
Methods:
Name | Description |
---|---|
__init__ | Initialize the exception with a message and the output of the truncation point. |
__init__
__init__(message: str, intercepted_output: object) -> None
Initialize the exception with a message and the output of the truncation point.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
message | str | The message to display when the exception is raised. | required |
intercepted_output | object | The output of the truncation point that caused the exception to be raised. | required |
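An exception with the documented `(message, intercepted_output)` constructor can carry an arbitrary payload out of deeply nested code. A minimal sketch, assuming the output is stored on an attribute named after the parameter; `InterceptedOutputSignal` and `first_stage` are hypothetical names, not the library's class:

```python
class InterceptedOutputSignal(Exception):
    """Hypothetical exception mirroring the documented (message, intercepted_output) constructor."""

    def __init__(self, message: str, intercepted_output: object) -> None:
        super().__init__(message)  # str(exc) returns the message
        self.intercepted_output = intercepted_output  # attribute name is an assumption


def first_stage(x: int) -> int:
    # Pretend computation that stops early by raising with its partial result.
    raise InterceptedOutputSignal("truncation point reached", x * 2)


try:
    first_stage(21)
except InterceptedOutputSignal as exc:
    message = str(exc)
    result = exc.intercepted_output

print(message, result)  # truncation point reached 42
```

Typing `intercepted_output` as `object` rather than `Tensor` lets the same exception carry whatever the truncation point returns, such as a tuple or a dict of tensors.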