truncated_module
Classes:
| Name | Description | 
|---|---|
| TruncatedModule | A module that wraps another module and interrupts its forward pass when a specified truncation point is reached. | 
| TruncationExecutionFinished | Special exception raised by a TruncatedModule to signal that the truncation point has been reached. | 

TruncatedModule
Bases: Module, Generic[ModuleT]
A module that wraps another module and interrupts its forward pass when a specified truncation point is reached.
The truncation works by temporarily adding a hook to the truncation point. The hook raises a TruncationExecutionFinished exception that carries the truncation point's output; the TruncatedModule forward pass catches this exception and returns that output.
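The hook-and-catch mechanism described above can be reproduced in a few lines of plain PyTorch. This is a minimal sketch, not the library's implementation: _StopForward and forward_until are hypothetical names, _StopForward stands in for TruncationExecutionFinished, and a plain RuntimeError stands in for HookNotCalledError. It assumes a standard forward hook registered with register_forward_hook, which PyTorch calls as hook(module, args, output).

```python
import torch
from torch import nn


class _StopForward(Exception):
    """Hypothetical stand-in for TruncationExecutionFinished: carries the intercepted output."""

    def __init__(self, output):
        super().__init__("truncation point reached")
        self.output = output


def forward_until(model: nn.Module, point: nn.Module, *args, **kwargs):
    """Hypothetical helper: run `model`, but stop as soon as `point` has produced its output."""

    def hook(module, inputs, output):
        # Forward hooks receive (module, args, output); raising here aborts model.forward.
        raise _StopForward(output)

    handle = point.register_forward_hook(hook)
    try:
        model(*args, **kwargs)
    except _StopForward as stop:
        return stop.output
    finally:
        # Always detach the hook so the wrapped model is left unchanged afterwards.
        handle.remove()
    # Stand-in for HookNotCalledError: the forward pass finished without reaching `point`.
    raise RuntimeError("truncation point was not reached")


model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2))
x = torch.randn(1, 10)
out = forward_until(model, model[1], x)
print(out.shape)  # torch.Size([1, 20])
print(model(x).shape)  # torch.Size([1, 2]); the hook is gone, so the full model runs
```

Removing the hook in a finally clause is what keeps the base model unaffected, whichever way the forward pass ends.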
Examples:
Instantiating a TruncatedModule with a binary classification model and a truncation point:
>>> import torch
>>> model = torch.nn.Sequential(
...     torch.nn.Linear(10, 20),
...     torch.nn.ReLU(),
...     torch.nn.Linear(20, 30),
...     torch.nn.ReLU(),
...     torch.nn.Linear(30, 40),
...     torch.nn.ReLU(),
...     torch.nn.Linear(40, 2),
... )
>>> truncation_layer = model[1]
>>> truncated_model = TruncatedModule(model, truncation_layer)
Using the TruncatedModule to get the output of the truncation point:
>>> input = torch.randn(1, 10)
>>> output = truncated_model(input)
>>> # Note that the output has the shape of the truncation point's output, not that of the full model
>>> assert output.shape == (1, 20)
The base model of the TruncatedModule is completely unaffected by the truncation:
>>> base_output = model(input)
>>> assert base_output.shape == (1, 2)  # Binary classification output shape
The base model is also accessible directly through the module attribute of the TruncatedModule:
>>> base_output = truncated_model.module(input)
>>> assert base_output.shape == (1, 2)  # Binary classification output shape
Added in version 0.59.0.
Methods:
| Name | Description | 
|---|---|
| __init__ | Initialize the TruncatedModule with the provided module and truncation point. | 
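| forward | Forward pass of the TruncatedModule, which interrupts the wrapped module's forward pass when the truncation point is reached. | 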
| lazy_register_truncation_hook | Create a forward hook and add it to the truncation point so that the forward pass is interrupted when the truncation point is reached. | 
| truncation_hook | Intercept the output of the truncation point and raise a TruncationExecutionFinished exception containing that output. | 
__init__(module: ModuleT, truncation_point: Module) -> None
Initialize the TruncatedModule with the provided module and truncation point.
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
| module | ModuleT | The module to wrap. | required | 
| truncation_point | Module | The submodule of the provided module at which to interrupt the forward pass. | required | 
Raises:
| Type | Description | 
|---|---|
| ValueError | If the truncation point is not a submodule of the provided module. | 
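A submodule check of the kind that raises this ValueError can be sketched with nn.Module.modules(), which yields the module itself and every nested submodule recursively. The helper name check_truncation_point is hypothetical, and comparing by object identity is an assumption; the library's exact check is not documented here.

```python
from torch import nn


def check_truncation_point(module: nn.Module, truncation_point: nn.Module) -> None:
    """Hypothetical validation: raise ValueError unless `truncation_point` lives inside `module`."""
    # modules() walks the whole tree (including `module` itself); compare by identity so that
    # an identical-looking but separate layer object is rejected.
    if not any(m is truncation_point for m in module.modules()):
        raise ValueError("truncation_point is not a submodule of module")


model = nn.Sequential(nn.Linear(10, 20), nn.Sequential(nn.ReLU(), nn.Linear(20, 2)))
check_truncation_point(model, model[1][0])  # nested submodule: accepted

try:
    check_truncation_point(model, nn.ReLU())  # a fresh ReLU, not the one inside `model`
except ValueError as err:
    print(err)  # truncation_point is not a submodule of module
```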
forward
Forward pass of the TruncatedModule, which interrupts the wrapped module's forward pass when the truncation point is reached.
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
|                    | Any | The positional arguments to pass to the wrapped module. | required | 
|                    | Any | The keyword arguments to pass to the wrapped module. | required | 
Returns:
| Type | Description | 
|---|---|
| Any | The output of the truncation point submodule. | 
Raises:
| Type | Description | 
|---|---|
| HookNotCalledError | If the truncation hook is not called, meaning the truncation point was not reached. | 
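A truncation point can pass the submodule check and still never run, for example a layer that is registered in __init__ but skipped in forward. The plain PyTorch sketch below (the toy class Branchy is hypothetical) shows a forward hook on such a layer never firing, which is the situation HookNotCalledError reports.

```python
import torch
from torch import nn


class Branchy(nn.Module):
    """Hypothetical toy model: `aux` is a registered submodule that forward() never calls."""

    def __init__(self):
        super().__init__()
        self.main = nn.Linear(4, 2)
        self.aux = nn.Linear(4, 2)  # registered, so it is a valid submodule...

    def forward(self, x):
        return self.main(x)  # ...but it is never executed


model = Branchy()
calls = []
# The hook only records that it ran; returning None leaves the layer output unchanged.
handle = model.aux.register_forward_hook(lambda mod, args, out: calls.append(out))
model(torch.randn(1, 4))
handle.remove()
print(len(calls))  # 0: the hook on `aux` never fired
```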
lazy_register_truncation_hook
Create a forward hook and add it to the truncation point so that the forward pass is interrupted when the truncation point is reached.
Returns:
| Type | Description | 
|---|---|
| _HandlerWrapper | A handler wrapper that contains the hook that was added to the truncation point. | 
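The internals of _HandlerWrapper are not documented on this page. As an assumption about the general pattern only, a "register late, always remove" hook lifecycle can be expressed with a context manager around PyTorch's RemovableHandle; temporary_forward_hook is a hypothetical helper, not part of the library.

```python
from contextlib import contextmanager

import torch
from torch import nn


@contextmanager
def temporary_forward_hook(module: nn.Module, hook):
    """Hypothetical helper: register `hook` on entry and always remove it on exit."""
    handle = module.register_forward_hook(hook)
    try:
        yield handle
    finally:
        handle.remove()


layer = nn.Linear(3, 3)
seen = []
with temporary_forward_hook(layer, lambda mod, args, out: seen.append(out.shape)):
    layer(torch.randn(2, 3))  # hook fires once inside the block
layer(torch.randn(2, 3))  # hook already removed; nothing is recorded
print(seen)  # [torch.Size([2, 3])]
```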
truncation_hook (staticmethod)
Intercept the output of the truncation point and raise a TruncationExecutionFinished exception containing that output.
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
|                    | Module | The truncation point submodule. Unused. | required | 
|                    | Any | The arguments passed to the truncation point. Unused. | required | 
|                    | Tensor | The output of the truncation point. This is the output that will be returned by the TruncatedModule forward pass. | required | 
Raises:
| Type | Description | 
|---|---|
| TruncationExecutionFinished | Always, in order to interrupt the wrapped model's forward pass. | 

TruncationExecutionFinished
Bases: Exception
Special exception raised by the truncation hook of a TruncatedModule to signal to the TruncatedModule itself that the truncation point has been reached. It is expected to be caught internally rather than propagated to the user.
Added in version 0.59.0.
Methods:
| Name | Description | 
|---|---|
| __init__ | Initialize the exception with a message and the output of the truncation point. | 
__init__(message: str, intercepted_output: object) -> None
Initialize the exception with a message and the output of the truncation point.
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
| message | str | The message to display when the exception is raised. | required | 
| intercepted_output | object | The output of the truncation point that caused the exception to be raised. | required |
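An exception with the documented constructor signature can be sketched as below. The class name TruncationFinishedSketch is hypothetical, and storing the output on an attribute named intercepted_output is an assumption; the page documents only the constructor parameters, not the attribute.

```python
class TruncationFinishedSketch(Exception):
    """Hypothetical mirror of the documented __init__(message, intercepted_output)."""

    def __init__(self, message: str, intercepted_output: object) -> None:
        super().__init__(message)
        # Assumed attribute: keep the truncation point's output so the catcher can return it.
        self.intercepted_output = intercepted_output


try:
    raise TruncationFinishedSketch("truncation point reached", [1.0, 2.0])
except TruncationFinishedSketch as exc:
    print(str(exc))  # truncation point reached
    print(exc.intercepted_output)  # [1.0, 2.0]
```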