model
Modules:

| Name | Description |
|---|---|
| noisy_model | |
| noisy_transformer_masking_model | |
| peft_noisy_transformer_masking_model | |
| truncated_module | |
Classes:

| Name | Description |
|---|---|
| NoisyModel | Applies a BaseNoiseLayer to a model input Tensor or a submodule output Tensor. |
| TruncatedModule | A module that wraps another module and interrupts the forward pass when a specified truncation point is reached. |
NoisyModel
¶
Bases: Module, Generic[ModuleT, NoiseLayerP, NoiseLayerT]
Applies a BaseNoiseLayer to a model input Tensor or a submodule output Tensor.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| | Callable[NoiseLayerP, NoiseLayerT] | The type of noise layer to apply. | required |
| | ModuleT | The model to apply the noise layer to. | required |
| args | | Positional arguments to pass to the noise layer. | () |
| target_parameter | str \| None | The name of the base_model input Tensor argument to transform. | None |
| target_layer | str \| None | The name of the base_model submodule whose output Tensor to transform. | None |
| kwargs | | Keyword arguments to pass to the noise layer. | {} |
Raises:

| Type | Description |
|---|---|
| ValueError | If both target_parameter and target_layer are specified. |
| ValueError | If neither target_parameter nor target_layer is specified. |
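For intuition, the snippet below is a minimal plain-PyTorch sketch of the mechanism this class wraps: transforming a submodule's output with a noise layer via a forward hook. It is not the library's API; GaussianNoiseLayer and the hook wiring are illustrative stand-ins for a BaseNoiseLayer and for the internals of NoisyModel.

```python
import torch


# Hypothetical stand-in for a BaseNoiseLayer subclass; the real noise layer
# classes live elsewhere in this package.
class GaussianNoiseLayer(torch.nn.Module):
    def __init__(self, std: float = 0.1) -> None:
        super().__init__()
        self.std = std

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.std * torch.randn_like(x)


base_model = torch.nn.Sequential(
    torch.nn.Linear(10, 20),
    torch.nn.ReLU(),
    torch.nn.Linear(20, 2),
)
noise_layer = GaussianNoiseLayer(std=0.05)

# Rough equivalent of targeting a submodule (a "target_layer"): transform the
# output of base_model[0] before it flows into the rest of the network.
handle = base_model[0].register_forward_hook(
    lambda module, args, output: noise_layer(output)
)

output = base_model(torch.randn(1, 10))
handle.remove()
```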
Methods:

| Name | Description |
|---|---|
| distillation_context | Prepare the base model to facilitate distillation training by applying losses over the transformed and non-transformed activations. |
| forward | Call the base_model, applying the noise_layer to the target_parameter or target_layer output. |
| reset_parameters | Reinitialize parameters and buffers. |
Attributes:

| Name | Type | Description |
|---|---|---|
| target_layer | Module | The base_model submodule whose output Tensor to transform. |
| target_parameter | str \| None | The name of the base_model input Tensor argument to transform when target_layer is None. |
| target_parameter_index | int | The index of the base_model input Tensor argument to transform when target_layer is None. |
target_layer
property
¶
target_layer: Module
The base_model submodule whose output Tensor to transform.
Raises:

| Type | Description |
|---|---|
| ValueError | If … |
target_parameter
property
¶
target_parameter: str | None
The name of the base_model input Tensor argument to transform when target_layer is None.
target_parameter_index
cached
property
¶
target_parameter_index: int
The index of the base_model input Tensor argument to transform when target_layer is None.
distillation_context
¶
Prepare the base model to facilitate distillation training by applying losses over the transformed and non-transformed activations.
Note
This context manager assumes that the output of the base_model is a mutable mapping with a logits key.
Returns:

| Type | Description |
|---|---|
| contextlib.ExitStack | A context manager that detaches the hooks when exited. |
Added in version v2.6.0.
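The snippet below is a minimal sketch of the general pattern the return type suggests: hooks registered for the duration of a contextlib.ExitStack and detached on exit, with a distillation-style loss comparing a transformed and a non-transformed pass. It is not the library's implementation; the hook, loss, and model here are illustrative.

```python
import contextlib

import torch
import torch.nn.functional as F

model = torch.nn.Sequential(
    torch.nn.Linear(10, 20),
    torch.nn.ReLU(),
    torch.nn.Linear(20, 2),
)
captured: dict[str, torch.Tensor] = {}


def capture_logits(module, args, output):
    # The first (non-transformed) pass becomes the "teacher"; later passes are the "student".
    captured.setdefault("teacher", output.detach())
    captured["student"] = output


with contextlib.ExitStack() as stack:
    handle = model.register_forward_hook(capture_logits)
    stack.callback(handle.remove)  # the hook is detached when the stack exits

    inputs = torch.randn(4, 10)
    model(inputs)                                   # non-transformed pass
    model(inputs + 0.1 * torch.randn_like(inputs))  # transformed (noisy) pass
    distillation_loss = F.mse_loss(captured["student"], captured["teacher"])
# After the with-block, the hook has been removed and forward passes are unaffected.
```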
forward
¶
Call the base_model, applying the noise_layer to the target_parameter or target_layer output.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| | Any | Positional arguments to the base_model. | required |
| | Tensor \| None | An optional mask that selects the elements of the target Tensor to transform. | None |
| | Any | Keyword arguments to the base_model. | required |
Returns:

| Type | Description |
|---|---|
| Any | The result of calling the base_model. |
reset_parameters
¶
Reinitialize parameters and buffers.
This method is useful for initializing tensors created on the meta device.
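A brief sketch of the meta-device workflow this supports, using plain PyTorch modules for illustration: parameters are first created without storage, then materialized and reinitialized on a real device. A NoisyModel constructed on the meta device would presumably be materialized and reinitialized the same way.

```python
import torch

with torch.device("meta"):
    layer = torch.nn.Linear(10, 20)   # parameters allocated on the meta device, no storage

layer = layer.to_empty(device="cpu")  # materialize real, but uninitialized, storage
layer.reset_parameters()              # reinitialize parameters and buffers
```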
TruncatedModule
¶
Bases: Module, Generic[ModuleT]
A module that wraps another module and interrupts the forward pass when a specified truncation point is reached.
This truncation happens by temporarily adding a hook to the truncation point that raises a
TruncationExecutionFinished exception, which is then caught by
the TruncatedModule forward so that the output of the truncation point can be returned.
Examples:
Instantiating a TruncatedModule with a binary classification model and a truncation point:
>>> model = torch.nn.Sequential(
... torch.nn.Linear(10, 20),
... torch.nn.ReLU(),
... torch.nn.Linear(20, 30),
... torch.nn.ReLU(),
... torch.nn.Linear(30, 40),
... torch.nn.ReLU(),
... torch.nn.Linear(40, 2),
... )
>>> truncation_layer = model[1]
>>> truncated_model = TruncatedModule(model, truncation_layer)
Using the TruncatedModule to get the output of the truncation point:
>>> input = torch.randn(1, 10)
>>> output = truncated_model(input)
>>> # Note that the output has the shape of the truncation point's output, not the full model's
>>> assert output.shape == (1, 20)
The base model of the TruncatedModule is completely unaffected by the truncation:
>>> base_output = model(input)
>>> assert base_output.shape == (1, 2) # Binary classification output shape
The base model is also accessible directly through the module attribute of the TruncatedModule:
>>> base_output = truncated_model.module(input)
>>> assert base_output.shape == (1, 2) # Binary classification output shape
Added in version 0.59.0.
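As a rough sketch of the mechanism described above (a hook on the truncation point raises an exception that the wrapper catches), the following plain-PyTorch code reproduces the idea. It is not the library's implementation; _TruncationFinished and run_truncated are hypothetical names.

```python
import torch


class _TruncationFinished(Exception):
    """Illustrative stand-in for TruncationExecutionFinished; carries the intercepted output."""

    def __init__(self, output: torch.Tensor) -> None:
        super().__init__()
        self.output = output


def run_truncated(module: torch.nn.Module, truncation_point: torch.nn.Module, *args, **kwargs):
    def interrupt(mod, hook_args, output):
        # Raise as soon as the truncation point has produced its output.
        raise _TruncationFinished(output)

    handle = truncation_point.register_forward_hook(interrupt)
    try:
        module(*args, **kwargs)
        raise RuntimeError("truncation point was never reached")  # cf. HookNotCalledError
    except _TruncationFinished as finished:
        return finished.output
    finally:
        handle.remove()


model = torch.nn.Sequential(torch.nn.Linear(10, 20), torch.nn.ReLU(), torch.nn.Linear(20, 2))
hidden = run_truncated(model, model[1], torch.randn(1, 10))
assert hidden.shape == (1, 20)
```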
Methods:

| Name | Description |
|---|---|
| __init__ | Initialize the TruncatedModule with the provided module and truncation point. |
| forward | Forward pass of the TruncatedModule that interrupts the forward pass when the truncation point is reached. |
| lazy_register_truncation_hook | Create a prehook that will be added to the truncation point to interrupt the forward pass when the truncation point is reached. |
| truncation_hook | Intercept the output of the truncation point and raise a TruncationExecutionFinished exception containing that output. |
__init__
¶
__init__(module: ModuleT, truncation_point: Module) -> None
Initialize the TruncatedModule with the provided module and truncation point.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| module | ModuleT | The module to wrap. | required |
| truncation_point | Module | The submodule of the provided module at which to interrupt the forward pass. | required |
Raises:

| Type | Description |
|---|---|
| ValueError | If the truncation point is not a submodule of the provided module. |
forward
¶
Forward pass of the TruncatedModule that interrupts the forward pass when the truncation point is reached.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| | Any | The positional arguments to pass to the wrapped module. | required |
| | Any | The keyword arguments to pass to the wrapped module. | required |
Returns:

| Type | Description |
|---|---|
| Any | The output of the truncation point submodule. |
Raises:

| Type | Description |
|---|---|
| HookNotCalledError | If the truncation hook is not called, meaning the truncation point was not reached. |
lazy_register_truncation_hook
¶
Create a prehook that will be added to the truncation point to interrupt the forward pass when the truncation point is reached.
Returns:

| Type | Description |
|---|---|
| _HandlerWrapper | A handler wrapper that contains the hook that was added to the truncation point. |
truncation_hook
staticmethod
¶
Intercept the output of the truncation point and raise a TruncationExecutionFinished exception containing that output.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| | Module | The truncation point submodule. Unused. | required |
| | Any | The arguments passed to the truncation point. Unused. | required |
| | Tensor | The output of the truncation point. This is the output that will be returned by the TruncatedModule forward. | required |
Raises:

| Type | Description |
|---|---|
| TruncationExecutionFinished | Always, in order to interrupt the wrapped model's forward pass. |