model
Modules:
| Name | Description | 
|---|---|
| noisy_model |  | 
| noisy_transformer_masking_model |  | 
| peft_noisy_transformer_masking_model |  | 
| truncated_module |  | 
Classes:
| Name | Description | 
|---|---|
| NoisyModel | Applies a BaseNoiseLayer to a model input Tensor or a submodule output Tensor. | 
| TruncatedModule | A module that wraps another module and interrupts its forward pass when a specified truncation point is reached. | 
NoisyModel
Bases: Module, Generic[ModuleT, NoiseLayerP, NoiseLayerT]
Applies a BaseNoiseLayer to a model input Tensor or a submodule output Tensor.
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
|                    | Callable[NoiseLayerP, NoiseLayerT] | The type of  | required | 
|                    | ModuleT | The model to apply the  | required | 
|                    | args | Positional arguments to  | () | 
|  | str \| None | The name of the  | None | 
|  | str \| None | The name of the  | None | 
|                    | kwargs | Keyword arguments to  | {} | 
Raises:
| Type | Description | 
|---|---|
| ValueError | If both  | 
| ValueError | If neither  | 
Methods:
| Name | Description | 
|---|---|
| forward | Call the base_model, applying the noise_layer to the target_parameter or target_layer output. | 
| reset_parameters | Reinitialize parameters and buffers. | 
Attributes:
| Name | Type | Description | 
|---|---|---|
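| target_layer | Module | The base_model submodule whose output Tensor to transform. | 
| target_parameter | str \| None | The name of the base_model input Tensor argument to transform when target_layer is None. | 
| target_parameter_index | int | The index of the base_model input Tensor argument to transform when target_layer is None. | 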
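The target_layer path (noising a submodule's output) can be pictured with a plain PyTorch forward hook. This is a minimal sketch, not NoisyModel's implementation: GaussianNoise and add_output_noise are hypothetical names, and a fixed-std additive Gaussian layer stands in for a real BaseNoiseLayer.

```python
import torch
from torch import nn


class GaussianNoise(nn.Module):
    """Hypothetical stand-in for a BaseNoiseLayer: adds N(0, std^2) noise."""

    def __init__(self, std: float = 0.1) -> None:
        super().__init__()
        self.std = std

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.std * torch.randn_like(x)


def add_output_noise(model: nn.Module, layer_name: str, noise: nn.Module):
    """Register a forward hook that replaces `layer_name`'s output with a noised copy."""
    target = model.get_submodule(layer_name)

    def hook(module, args, output):
        # Returning a non-None value from a forward hook replaces the module's output.
        return noise(output)

    return target.register_forward_hook(hook)


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
x = torch.randn(3, 4)
clean = model(x)

handle = add_output_noise(model, "0", GaussianNoise(std=1.0))
noisy = model(x)
handle.remove()  # once the hook is removed, the base model behaves as before

restored = model(x)
```

Because the noise is attached through a removable hook, the wrapped model's own parameters and forward code stay untouched.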
property target_layer: Module
The base_model submodule whose output Tensor to transform.
Raises:
| Type | Description | 
|---|---|
| ValueError | If  | 
property target_parameter: str | None
The name of the base_model input Tensor argument to transform when target_layer is None.
cached property target_parameter_index: int
The index of the base_model input Tensor argument to transform when target_layer is None.
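An index is useful because a caller may pass the target argument either by keyword or positionally. A minimal sketch of resolving a parameter name to its position with inspect.signature, assuming a hypothetical TinyEncoder model and hypothetical parameter_index and get_argument helpers (not the library's code):

```python
import inspect
from typing import Optional

import torch
from torch import nn


class TinyEncoder(nn.Module):
    """Hypothetical model whose forward() takes the tensor to transform by name."""

    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def forward(self, inputs_embeds: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        out = self.proj(inputs_embeds)
        if attention_mask is not None:
            out = out * attention_mask.unsqueeze(-1)
        return out


def parameter_index(module: nn.Module, name: str) -> int:
    """Position of `name` among module.forward's arguments (bound method, so `self` is excluded)."""
    names = list(inspect.signature(module.forward).parameters)
    if name not in names:
        raise ValueError(f"{name!r} is not an argument of {type(module).__name__}.forward")
    return names.index(name)


def get_argument(args: tuple, kwargs: dict, name: str, index: int):
    """Fetch an argument whether it was passed by keyword or positionally."""
    return kwargs[name] if name in kwargs else args[index]


encoder = TinyEncoder()
x = torch.randn(2, 3, 4)
mask = torch.ones(2, 3)
idx = parameter_index(encoder, "attention_mask")
# The same argument is found whether it was passed positionally or by keyword.
positional = get_argument((x, mask), {}, "attention_mask", idx)
keyword = get_argument((x,), {"attention_mask": mask}, "attention_mask", idx)
```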
forward
Call the base_model, applying the noise_layer to the target_parameter or target_layer output.
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
|                    | Any | Positional arguments to  | required | 
|  | Tensor \| None | An optional mask that selects the elements of the  | None | 
|                    | Any | Keyword arguments to  | required | 
Returns:
| Type | Description | 
|---|---|
| Any | The result of  | 
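The mask parameter above is only partially described, so its exact semantics are not stated here. Assuming True marks the elements that should receive noise, masked application can be sketched with torch.where; apply_masked_noise is a hypothetical helper, and the "noise" function is a deterministic placeholder so the result is easy to check.

```python
from typing import Callable, Optional

import torch


def apply_masked_noise(
    x: torch.Tensor,
    noise_fn: Callable[[torch.Tensor], torch.Tensor],
    noise_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Apply `noise_fn` to `x`; where `noise_mask` is False, keep the original values.

    Assumption for illustration: True in `noise_mask` marks elements that receive noise.
    """
    noisy = noise_fn(x)
    if noise_mask is None:
        return noisy
    # Append singleton dims so a (batch, seq) mask broadcasts over trailing features.
    while noise_mask.dim() < x.dim():
        noise_mask = noise_mask.unsqueeze(-1)
    return torch.where(noise_mask, noisy, x)


embeds = torch.zeros(1, 3, 4)  # (batch, sequence, features)
mask = torch.tensor([[True, False, True]])
out = apply_masked_noise(embeds, lambda t: t + 1.0, mask)
```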
reset_parameters
Reinitialize parameters and buffers.
This method is useful for initializing tensors created on the meta device.
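The meta-device workflow that reset_parameters supports looks like this with a plain nn.Linear. The sketch assumes NoisyModel.reset_parameters would be called at the same point (after to_empty); that usage is not shown in this reference.

```python
import torch
from torch import nn

# Build the module on the meta device: shapes and dtypes exist, but no memory is allocated.
layer = nn.Linear(16, 4, device="meta")

# Allocate real (uninitialized) storage on a device, then fill it with proper initial values.
layer = layer.to_empty(device="cpu")
layer.reset_parameters()  # nn.Linear's own initializer

x = torch.randn(2, 16)
y = layer(x)
```

Skipping reset_parameters after to_empty would leave the weights holding whatever bytes happened to be in the freshly allocated memory.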
TruncatedModule
Bases: Module, Generic[ModuleT]
A module that wraps another module and interrupts its forward pass when a specified truncation point is reached.
Truncation works by temporarily registering a hook on the truncation point. When the truncation point runs, the hook raises a
TruncationExecutionFinished exception containing the truncation point's output; the TruncatedModule forward catches
that exception and returns the captured output.
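The hook-and-exception pattern can be sketched in plain PyTorch. _TruncationFinished and run_until are hypothetical names; the library's TruncationExecutionFinished and TruncatedModule forward may differ in detail (for example, the documented forward raises HookNotCalledError rather than RuntimeError when the truncation point is never reached).

```python
import torch
from torch import nn


class _TruncationFinished(Exception):
    """Hypothetical stand-in for TruncationExecutionFinished: carries the captured output."""

    def __init__(self, output):
        super().__init__()
        self.output = output


def run_until(module: nn.Module, truncation_point: nn.Module, *args, **kwargs):
    """Run `module` and return `truncation_point`'s output, skipping everything after it."""

    def hook(submodule, hook_args, output):
        # Abort the rest of the forward pass by raising with the output attached.
        raise _TruncationFinished(output)

    handle = truncation_point.register_forward_hook(hook)
    try:
        module(*args, **kwargs)
    except _TruncationFinished as finished:
        return finished.output
    finally:
        handle.remove()  # the hook exists only for the duration of this call
    raise RuntimeError("truncation point was not reached")


model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2))
x = torch.randn(1, 10)
features = run_until(model, model[1], x)
```

Raising an exception is what lets the sketch skip the layers after the truncation point; simply returning a value from the hook would only replace the output and the remaining layers would still run.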
Examples:
Instantiating a TruncatedModule with a binary classification model and a truncation point:
>>> import torch
>>> model = torch.nn.Sequential(
...     torch.nn.Linear(10, 20),
...     torch.nn.ReLU(),
...     torch.nn.Linear(20, 30),
...     torch.nn.ReLU(),
...     torch.nn.Linear(30, 40),
...     torch.nn.ReLU(),
...     torch.nn.Linear(40, 2),
... )
>>> truncation_layer = model[1]
>>> truncated_model = TruncatedModule(model, truncation_layer)
Using the TruncatedModule to get the output of the truncation point:
>>> input = torch.randn(1, 10)
>>> output = truncated_model(input)
>>> # The output has the shape of the truncation point's output, not the full model's output
>>> assert output.shape == (1, 20)
The base model of the TruncatedModule is completely unaffected by the truncation:
>>> base_output = model(input)
>>> assert base_output.shape == (1, 2)  # Binary classification output shape
The base model is also accessible directly through the module attribute of the TruncatedModule:
>>> base_output = truncated_model.module(input)
>>> assert base_output.shape == (1, 2)  # Binary classification output shape
Added in version 0.59.0.
Methods:
| Name | Description | 
|---|---|
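| __init__ | Initialize the TruncatedModule with the provided module and truncation point. | 
| forward | Forward pass of the TruncatedModule that interrupts the forward pass when the truncation point is reached. | 
| lazy_register_truncation_hook | Create a prehook that will be added to the truncation point to interrupt the forward pass when the truncation point is reached. | 
| truncation_hook | Intercept the output of the truncation point and raise a TruncationExecutionFinished exception containing that output. | 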
__init__(module: ModuleT, truncation_point: Module) -> None
Initialize the TruncatedModule with the provided module and truncation point.
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
| module | ModuleT | The module to wrap. | required | 
| truncation_point | Module | The submodule of the provided module at which to interrupt the forward pass. | required | 
Raises:
| Type | Description | 
|---|---|
| ValueError | If the truncation point is not a submodule of the provided module. | 
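One way to perform the submodule check behind this ValueError is an identity test over module.modules(). This is an assumption about the approach, and is_submodule is a hypothetical helper. Note that modules() also yields the root module itself.

```python
from torch import nn


def is_submodule(module: nn.Module, candidate: nn.Module) -> bool:
    """True if `candidate` is `module` or one of its recursive children, compared by identity."""
    return any(m is candidate for m in module.modules())


model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
inside = is_submodule(model, model[1])
# A freshly built ReLU has the same configuration but is a different object.
outside = is_submodule(model, nn.ReLU())
```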
forward
Forward pass of the TruncatedModule: runs the wrapped module and interrupts it when the truncation point is reached.
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
|                    | Any | The positional arguments to pass to the wrapped module. | required | 
|                    | Any | The keyword arguments to pass to the wrapped module. | required | 
Returns:
| Type | Description | 
|---|---|
| Any | The output of the truncation point submodule. | 
Raises:
| Type | Description | 
|---|---|
| HookNotCalledError | If the truncation hook is not called, meaning the truncation point was not reached. | 
lazy_register_truncation_hook
Create a prehook that will be added to the truncation point to interrupt the forward pass when the truncation point is reached.
Returns:
| Type | Description | 
|---|---|
| _HandlerWrapper | A handler wrapper that contains the hook that was added to the truncation point. | 
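The internals of _HandlerWrapper are not documented here. The general "register now, always remove afterwards" lifecycle can be sketched with contextlib; temporary_forward_hook is a hypothetical helper, not the library's wrapper.

```python
from contextlib import contextmanager

import torch
from torch import nn


@contextmanager
def temporary_forward_hook(submodule: nn.Module, hook):
    """Register `hook` on `submodule` for the duration of a `with` block only."""
    handle = submodule.register_forward_hook(hook)
    try:
        yield handle
    finally:
        handle.remove()  # removed even if the forward pass raises


calls = []
layer = nn.Linear(3, 3)
x = torch.randn(1, 3)

with temporary_forward_hook(layer, lambda mod, args, out: calls.append(out.shape)):
    layer(x)  # hook fires and records the output shape
layer(x)  # hook is no longer registered
```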
staticmethod truncation_hook
Intercept the output of the truncation point and raise a TruncationExecutionFinished exception containing that output.
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
|                    | Module | The truncation point submodule. Unused. | required | 
|                    | Any | The arguments passed to the truncation point. Unused. | required | 
|                    | Tensor | The output of the truncation point. This is the output that will be returned by the  | required | 
Raises:
| Type | Description | 
|---|---|
| TruncationExecutionFinished | Always, in order to interrupt the wrapped model's  |