model
Modules:
Name | Description |
---|---|
noisy_model | |
noisy_transformer_masking_model | |
truncated_module | |
Classes:
Name | Description |
---|---|
NoisyModel | Applies a BaseNoiseLayer to a model input Tensor or a submodule output Tensor. |
TruncatedModule | A module that wraps another module and interrupts its forward pass when a specified truncation point is reached. |
NoisyModel
Bases: Module, Generic[ModuleT, NoiseLayerP, NoiseLayerT]
Applies a BaseNoiseLayer to a model input Tensor or a submodule output Tensor.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
| | Callable[NoiseLayerP, NoiseLayerT] | The type of | required |
| | ModuleT | The model to apply the | required |
| | args | Positional arguments to | () |
| | str \| None | The name of the | None |
| | str \| None | The name of the | None |
| | kwargs | Keyword arguments to | {} |
Raises:
Type | Description |
---|---|
ValueError | If both |
ValueError | If neither |
Methods:
Name | Description |
---|---|
forward | Call the base_model, applying the noise_layer to the target_parameter or target_layer output. |
reset_parameters | Reinitialize parameters and buffers. |
Attributes:
Name | Type | Description |
---|---|---|
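target_layer | Module | The base_model submodule whose output Tensor is transformed. |
target_parameter | str \| None | The name of the base_model input Tensor argument to transform when target_layer is None. |
target_parameter_index | int | The index of the base_model input Tensor argument to transform when target_layer is None. |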
target_layer (property)
target_layer: Module
The base_model submodule whose output Tensor is transformed.
Raises:
Type | Description |
---|---|
ValueError | If |
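When target_layer is set, the transformation applies to a submodule's output. PyTorch forward hooks (registered with register_forward_hook) can replace a module's output by returning a non-None value, and one way to picture the target_layer path is through that mechanism. The framework-free sketch below assumes that picture: ToyModule is a hypothetical stand-in for torch.nn.Module, and placeholder_noise_hook is a deterministic placeholder for a real BaseNoiseLayer. Neither is the library's implementation.

```python
from typing import Any, Callable, List, Tuple

# Mirrors PyTorch's forward-hook signature: hook(module, args, output) -> new output or None.
ForwardHook = Callable[["ToyModule", Tuple[Any, ...], Any], Any]


class ToyModule:
    """Hypothetical stand-in for torch.nn.Module that supports forward hooks."""

    def __init__(self, fn: Callable[..., Any]) -> None:
        self.fn = fn
        self.forward_hooks: List[ForwardHook] = []

    def register_forward_hook(self, hook: ForwardHook) -> None:
        self.forward_hooks.append(hook)

    def __call__(self, *args: Any) -> Any:
        output = self.fn(*args)
        for hook in self.forward_hooks:
            replacement = hook(self, args, output)
            if replacement is not None:  # a non-None return value replaces the output
                output = replacement
        return output


def placeholder_noise_hook(module: ToyModule, args: Tuple[Any, ...], output: List[int]) -> List[int]:
    # Placeholder for a real noise layer: deterministically add 1 to every element.
    return [value + 1 for value in output]


target_layer = ToyModule(lambda xs: [2 * x for x in xs])  # the submodule to perturb
head = ToyModule(sum)  # downstream computation that consumes its output

clean = head(target_layer([1, 2]))  # [2, 4] -> 6
target_layer.register_forward_hook(placeholder_noise_hook)
noisy = head(target_layer([1, 2]))  # [3, 5] -> 8
```

Downstream code never sees the unmodified output once the hook is registered, which is why this style of wrapping does not require editing the base model's forward method.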
target_parameter (property)
target_parameter: str | None
The name of the base_model input Tensor argument to transform when target_layer is None.
target_parameter_index (cached property)
target_parameter_index: int
The index of the base_model input Tensor argument to transform when target_layer is None.
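Mapping a parameter name to its positional index is a signature lookup, and a cached property computes it once per instance. The sketch below assumes the wrapped forward has an inspectable Python signature; ParameterTarget and the toy forward function are hypothetical and show one way such an index could be derived, not how the library computes target_parameter_index.

```python
import functools
import inspect
from typing import Any, Callable


class ParameterTarget:
    """Hypothetical holder that resolves a parameter name to its positional index once."""

    def __init__(self, forward: Callable[..., Any], target_parameter: str) -> None:
        self.forward = forward
        self.target_parameter = target_parameter

    @functools.cached_property
    def target_parameter_index(self) -> int:
        # Computed on first access, then stored in the instance __dict__.
        # For a bound method, inspect.signature already omits `self`.
        names = list(inspect.signature(self.forward).parameters)
        if self.target_parameter not in names:
            raise ValueError(f"no parameter named {self.target_parameter!r}")
        return names.index(self.target_parameter)


def forward(input_ids: Any, attention_mask: Any = None) -> Any:
    # Hypothetical model forward signature used only for illustration.
    return input_ids
```

For example, ParameterTarget(forward, "attention_mask").target_parameter_index evaluates to 1.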
forward
Call the base_model, applying the noise_layer to the target_parameter or target_layer output.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
| | Any | Positional arguments to | required |
| | Tensor \| None | An optional mask that selects the elements of the | None |
| | Any | Keyword arguments to | required |
Returns:
Type | Description |
---|---|
Any | The result of |
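For the target_parameter path, the argument to transform may be passed either positionally or as a keyword. One way to handle both cases is to bind the call with inspect.Signature.bind and replace the bound value before calling. The sketch assumes the wrapped callable has an inspectable Python signature. call_with_transformed_argument and the +1 "noise" are hypothetical placeholders, not the library's forward implementation, and the optional mask parameter is not modeled.

```python
import inspect
from typing import Any, Callable


def call_with_transformed_argument(
    func: Callable[..., Any],
    name: str,
    transform: Callable[[Any], Any],
    *args: Any,
    **kwargs: Any,
) -> Any:
    """Call ``func`` after applying ``transform`` to the argument named ``name``."""
    # Binding resolves whether the argument arrived positionally or by keyword.
    bound = inspect.signature(func).bind(*args, **kwargs)
    if name not in bound.arguments:
        raise ValueError(f"argument {name!r} was not supplied")
    bound.arguments[name] = transform(bound.arguments[name])
    # BoundArguments.args/.kwargs are recomputed from the updated mapping.
    return func(*bound.args, **bound.kwargs)


def model(x: list, scale: int = 1) -> list:
    # Hypothetical base model.
    return [v * scale for v in x]


def add_one(xs: list) -> list:
    # Deterministic placeholder for a noise layer.
    return [v + 1 for v in xs]
```

Both call_with_transformed_argument(model, "x", add_one, [1, 2]) and call_with_transformed_argument(model, "x", add_one, x=[1, 2]) perturb the same argument.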
reset_parameters
Reinitialize parameters and buffers.
This method is useful for initializing tensors created on the meta device.
TruncatedModule
Bases: Module, Generic[ModuleT]
A module that wraps another module and interrupts its forward pass when a specified truncation point is reached.
The truncation works by temporarily adding a hook to the truncation point that raises a TruncationExecutionFinished exception. The TruncatedModule forward pass catches this exception and returns the output of the truncation point.
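The control flow described above can be sketched without PyTorch. In this framework-free sketch, Layer, StopAtTruncation, and truncated_forward are hypothetical names standing in for a module with a forward hook, TruncationExecutionFinished, and the TruncatedModule forward pass. The hook is attached only for the duration of the call and detached in a finally block, so the base model keeps its normal behavior afterwards.

```python
from typing import Any, Callable, List, Optional


class StopAtTruncation(Exception):
    """Hypothetical exception that carries the intercepted output out of the forward pass."""

    def __init__(self, output: Any) -> None:
        super().__init__()
        self.output = output


class Layer:
    """Hypothetical stand-in for a torch.nn.Module with one optional forward hook."""

    def __init__(self, fn: Callable[[Any], Any]) -> None:
        self.fn = fn
        self.hook: Optional[Callable[["Layer", Any], None]] = None

    def __call__(self, x: Any) -> Any:
        output = self.fn(x)
        if self.hook is not None:
            self.hook(self, output)
        return output


def run_sequential(layers: List[Layer], x: Any) -> Any:
    for layer in layers:
        x = layer(x)
    return x


def truncated_forward(layers: List[Layer], point: Layer, x: Any) -> Any:
    if not any(layer is point for layer in layers):
        raise ValueError("truncation point is not a submodule of the model")

    def hook(layer: Layer, output: Any) -> None:
        raise StopAtTruncation(output)  # abort the rest of the forward pass

    point.hook = hook  # temporarily attach the hook
    try:
        run_sequential(layers, x)
    except StopAtTruncation as stop:
        return stop.output
    finally:
        point.hook = None  # always detach, so the base model is unaffected
    # Only reached if the forward pass finished without ever calling the hook.
    raise RuntimeError("truncation point was never reached")
```

Raising an exception is what stops the remaining layers from running; simply recording the output in the hook would still execute the whole model.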
Examples:
Instantiating a TruncatedModule with a binary classification model and a truncation point:
>>> model = torch.nn.Sequential(
... torch.nn.Linear(10, 20),
... torch.nn.ReLU(),
... torch.nn.Linear(20, 30),
... torch.nn.ReLU(),
... torch.nn.Linear(30, 40),
... torch.nn.ReLU(),
... torch.nn.Linear(40, 2),
... )
>>> truncation_layer = model[1]
>>> truncated_model = TruncatedModule(model, truncation_layer)
Using the TruncatedModule to get the output of the truncation point:
>>> input = torch.randn(1, 10)
>>> output = truncated_model(input)
>>> # Note that the output has the shape of the truncation point's output, not that of the full model
>>> assert output.shape == (1, 20)
The base model of the TruncatedModule is completely unaffected by the truncation:
>>> base_output = model(input)
>>> assert base_output.shape == (1, 2) # Binary classification output shape
The base model is also accessible directly through the module attribute of the TruncatedModule:
>>> base_output = truncated_model.module(input)
>>> assert base_output.shape == (1, 2) # Binary classification output shape
Added in version 0.59.0.
Methods:
Name | Description |
---|---|
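__init__ | Initialize the TruncatedModule with the provided module and truncation point. |
forward | Forward pass of the TruncatedModule that interrupts the forward pass when the truncation point is reached. |
lazy_register_truncation_hook | Create a prehook that will be added to the truncation point to interrupt the forward pass when the truncation point is reached. |
truncation_hook | Intercept the output of the truncation point and raise a TruncationExecutionFinished exception containing that output. |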
__init__
__init__(module: ModuleT, truncation_point: Module) -> None
Initialize the TruncatedModule with the provided module and truncation point.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
module | ModuleT | The module to wrap. | required |
truncation_point | Module | The submodule of the provided module at which to interrupt the forward pass. | required |
Raises:
Type | Description |
---|---|
ValueError | If the truncation point is not a submodule of the provided module. |
forward
Forward pass of the TruncatedModule that interrupts the forward pass when the truncation point is reached.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
| | Any | The positional arguments to pass to the wrapped module. | required |
| | Any | The keyword arguments to pass to the wrapped module. | required |
Returns:
Type | Description |
---|---|
Any | The output of the truncation point submodule. |
Raises:
Type | Description |
---|---|
HookNotCalledError | If the truncation hook is not called, meaning the truncation point was not reached. |
lazy_register_truncation_hook
Create a prehook that will be added to the truncation point to interrupt the forward pass when the truncation point is reached.
Returns:
Type | Description |
---|---|
_HandlerWrapper | A handler wrapper that contains the hook that was added to the truncation point. |
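PyTorch hook-registration methods such as register_forward_hook return a handle whose remove() method detaches the hook. The general idea of a scoped registration that always detaches, even when the forward pass raises, can be sketched with contextlib. HookRegistry and temporarily_registered are hypothetical names, and this sketch does not describe how _HandlerWrapper works internally.

```python
import contextlib
from typing import Any, Callable, Iterator, List

Hook = Callable[[Any], None]


class HookRegistry:
    """Hypothetical minimal registry whose register() returns a removal callback."""

    def __init__(self) -> None:
        self.hooks: List[Hook] = []

    def register(self, hook: Hook) -> Callable[[], None]:
        self.hooks.append(hook)
        return lambda: self.hooks.remove(hook)


@contextlib.contextmanager
def temporarily_registered(registry: HookRegistry, hook: Hook) -> Iterator[None]:
    # Register on entry and always remove on exit, including when an exception
    # (such as one raised by a truncation hook) propagates through the block.
    remove = registry.register(hook)
    try:
        yield
    finally:
        remove()
```

Deferring registration until the with block is entered keeps the wrapped module hook-free whenever it is called outside the truncated forward pass.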
truncation_hook (staticmethod)
Intercept the output of the truncation point and raise a TruncationExecutionFinished exception containing that output.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
| | Module | The truncation point submodule. Unused. | required |
| | Any | The arguments passed to the truncation point. Unused. | required |
| | Tensor | The output of the truncation point. This is the output that will be returned by the TruncatedModule forward pass. | required |
Raises:
Type | Description |
---|---|
TruncationExecutionFinished | Always, in order to interrupt the wrapped model's forward pass. |