model
Initialization for core model modules.
Modules:
| Name | Description |
|---|---|
| noisy_encoder_transformer_masking_model | Module containing the |
| noisy_model | Module for creating noisy models by integrating a noise layer. |
| noisy_transformer_masking_model | Module for a noisy transformer model with masking capabilities. |
| peft_noisy_encoder_transformer_masking_model | Module for PEFT-enabled noisy transformer models with masking capabilities. |
| peft_noisy_transformer_masking_model | Module for PEFT-enabled noisy transformer models with masking capabilities. |
| truncated_module | Module for a truncated module that interrupts forward passes. |
Classes:
| Name | Description |
|---|---|
| NoisyModel | Applies a BaseNoiseLayer to a model input Tensor or a submodule output Tensor. |
| TruncatedModule | A module that wraps another module and interrupts the forward pass when a specified truncation point is reached. |
NoisyModel
Bases: Module, ABC, Generic[ModuleT, NoiseLayerP, NoiseLayerT_co]
Applies a BaseNoiseLayer to a model input Tensor or a submodule output Tensor.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
|
Callable[NoiseLayerP, NoiseLayerT_co]
|
The type of |
required |
|
ModuleT
|
The model to apply the |
required |
|
args
|
Positional arguments to |
()
|
|
str | None
|
The name of the |
None
|
|
str | None
|
The name of the |
None
|
|
kwargs
|
Keyword arguments to |
{}
|
Raises:
| Type | Description |
|---|---|
ValueError
|
If both |
ValueError
|
If neither |
Methods:
| Name | Description |
|---|---|
| __getstate__ | Serialize the model to a dictionary. |
| __setstate__ | Deserialize the model from a dictionary. |
| deserialize_base_model | Deserialize the base model from its serialized representation. |
| distillation_context | Prepare the base model to facilitate distillation training by applying losses over the transformed and non-transformed activations. |
| forward | Call the base_model, applying the noise_layer to the target_parameter or target_layer output. |
| reset_parameters | Reinitialize parameters and buffers. |
| serialize_base_model | Serialize the base model. |
| serialize_init_passthrough_kwargs | Serialize the keyword arguments needed to re-initialize the NoisyModel from its serialized base model and noise layer. |
Attributes:
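| Name | Type | Description |
|---|---|---|
| target_layer | Module | The base_model submodule whose output Tensor to transform. |
| target_parameter | str \| None | The name of the base_model input Tensor argument to transform when target_layer is None. |
| target_parameter_index | int | The index of the base_model input Tensor argument to transform when target_layer is None. |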
target_layer
property
target_layer: Module
The base_model submodule whose output Tensor to transform.
Raises:
| Type | Description |
|---|---|
ValueError
|
If |
target_parameter
property
target_parameter: str | None
The name of the base_model input Tensor argument to transform when target_layer is None.
target_parameter_index
cached
property
target_parameter_index: int
The index of the base_model input Tensor argument to transform when target_layer is None.
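How the index is computed is not described here. As a rough illustration only, a parameter name can be mapped to a positional index with `inspect.signature`. The names `parameter_index` and `toy_forward` below are hypothetical and are not part of this API.

```python
import inspect


def parameter_index(func, name: str) -> int:
    """Return the positional index of the parameter called `name` in `func`."""
    parameters = list(inspect.signature(func).parameters)
    if name not in parameters:
        raise ValueError(f"{name!r} is not a parameter of {func.__name__}")
    return parameters.index(name)


def toy_forward(input_ids, attention_mask=None, inputs_embeds=None):
    """Hypothetical stand-in for a base model's forward signature."""
    return input_ids


# Look up where the named argument sits in the positional order.
index = parameter_index(toy_forward, "inputs_embeds")
```

An index like this would let a wrapper find the target Tensor whether the caller passes it positionally or by keyword.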
__getstate__
Serialize the model to a dictionary.
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | A dictionary containing the model's state, including the base model, noise layer, and state dict. |
Added in version v3.18.0. Add support for serializing NoisyModel instances.
__setstate__
__setstate__(
state: Mapping[str, Any],
trust_remote_code: bool = False,
third_party_model_path: (
str | PathLike[str] | None
) = None,
) -> None
Deserialize the model from a dictionary.
Warning
The state_dict key is considered optional. If it is missing or incomplete, the missing parameters will be initialized on the meta device. Making the key optional allows the NoisyModel parameters to be restored as part of a larger model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | Mapping[str, Any] | A dictionary containing the model's state, including the base model, noise layer, and possibly state dict. | required |
| trust_remote_code | bool | Whether to trust remote code when loading from the Hugging Face Hub. | False |
| third_party_model_path | str \| PathLike[str] \| None | The path or Hugging Face reference to a third-party model to load. This is useful when loading SGTs whose internal structure depends on transformer models that are not importable directly through transformers, but are present on the Hugging Face Hub. | None |
Added in version v3.18.0. Add support for serializing NoisyModel instances.
deserialize_base_model
classmethod
deserialize_base_model(
state: Mapping[str, Any],
trust_remote_code: bool = False,
third_party_model_path: (
str | PathLike[str] | None
) = None,
) -> nn.Module
Deserialize the base model from its serialized representation.
This is used by __setstate__ to reconstruct the base model from its serialized representation.
In general, this does not load the state dict, since the noisy model itself will handle loading the state dict. However, the deserialization may need to involve loading model weights from a pretrained checkpoint, e.g. when the base model is a Hugging Face Transformers model and the serialization includes a Hugging Face config.
The default implementation of deserialize_base_model is to import the class from the base_model_type_str field in the serialization, then call __setstate__ on an instance of that class with the base_state field in the serialization. Subclasses can override this method to customize how the base model is deserialized, e.g. by using the Hugging Face from_pretrained method to load a model from a config.
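The default behavior described above (import a class from a type string, then call `__setstate__` on an instance) can be sketched roughly as follows. This is not the library's code: the helper `load_from_type_string` is hypothetical, the dotted `module.ClassName` format of `base_model_type_str` is an assumption, and `random.Random` stands in for a base model only because it is a standard-library class that defines both `__getstate__` and `__setstate__`.

```python
import importlib
import random
from typing import Any, Mapping


def load_from_type_string(state: Mapping[str, Any]) -> Any:
    """Import the class named in `base_model_type_str` and restore `base_state` into it."""
    # Assumption: the type string is a dotted path of the form "module.ClassName".
    module_name, _, class_name = state["base_model_type_str"].rpartition(".")
    cls = getattr(importlib.import_module(module_name), class_name)
    instance = cls.__new__(cls)  # create the object without running __init__
    instance.__setstate__(state["base_state"])
    return instance


# random.Random is only a stand-in; a real base model would be an nn.Module.
original = random.Random(0)
serialized = {
    "base_model_type_str": "random.Random",
    "base_state": original.__getstate__(),
}
restored = load_from_type_string(serialized)
```

Because the class is resolved from a string, a check that the result has the expected type (as the Raises section of this method suggests) is a sensible guard before using it.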
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | Mapping[str, Any] | The serialized representation of the base model. | required |
| trust_remote_code | bool | Whether to allow executing remote code when deserializing third-party models. This should only be set to | False |
| third_party_model_path | str \| PathLike[str] \| None | An optional path to a local directory containing the files needed to deserialize a third-party model, which may include custom code. This is used when deserializing third-party models that require custom code, and should only be used with models from trusted sources. | None |
Returns:
| Type | Description |
|---|---|
| nn.Module | The deserialized base model. |
Raises:
| Type | Description |
|---|---|
TypeError
|
If the deserialized base model is not an instance of |
distillation_context
Prepare the base model to facilitate distillation training by applying losses over the transformed and non-transformed activations.
Note
This context manager assumes that the output of the base_model is a mutable mapping with a logits key.
Returns:
| Type | Description |
|---|---|
| contextlib.ExitStack | A context manager that detaches the hooks when exited. |
Added in version v2.6.0.
forward
Call the base_model, applying the noise_layer to the target_parameter or target_layer output.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
|
Any
|
Positional arguments to |
required |
|
Tensor | None
|
An optional mask that selects the elements of the |
None
|
|
Any
|
Keyword arguments to |
required |
Returns:
| Type | Description |
|---|---|
Any
|
The result of |
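To illustrate the idea of applying a noise layer to a submodule's output during the base model's forward pass, here is a minimal sketch in plain PyTorch using a forward hook. It is not how NoisyModel is implemented (that is not shown here); `AdditiveNoise` is a hypothetical noise layer and the small `nn.Sequential` is a stand-in base model.

```python
import torch
from torch import nn


class AdditiveNoise(nn.Module):
    """Hypothetical noise layer: adds scaled Gaussian noise to its input."""

    def __init__(self, scale: float = 0.1) -> None:
        super().__init__()
        self.scale = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.scale * torch.randn_like(x)


base_model = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 2))
target_layer = base_model[0]
noise_layer = AdditiveNoise(scale=0.1)

inputs = torch.randn(3, 4)
clean_output = base_model(inputs)

# A forward hook that returns a value replaces the submodule's output.
handle = target_layer.register_forward_hook(
    lambda module, args, output: noise_layer(output)
)
try:
    noisy_output = base_model(inputs)
finally:
    handle.remove()  # leave the base model unmodified afterwards
```

Removing the hook in a `finally` block keeps the base model usable on its own even if the forward pass raises.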
reset_parameters
Reinitialize parameters and buffers.
This method is useful for initializing tensors created on the meta device.
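As a small illustration of why reinitialization matters for meta-device tensors, the sketch below uses `nn.Linear.reset_parameters` from PyTorch itself, not `NoisyModel.reset_parameters`, whose internals are not shown here. The layer sizes are arbitrary.

```python
import torch
from torch import nn

# Parameters created on the meta device have shapes and dtypes but no storage.
layer = nn.Linear(4, 2, device="meta")
was_meta = layer.weight.is_meta

# Allocate uninitialized storage on a real device, then fill it with fresh values.
layer = layer.to_empty(device="cpu")
layer.reset_parameters()
```

After `to_empty`, the parameter memory is uninitialized, so it must be either reinitialized as above or overwritten by loading a state dict before the layer is used.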
serialize_base_model
Serialize the base model.
This is used by __getstate__ to get a JSON-serializable representation of the base model. In general, this does not have to include a copy of the state dict, since the noisy model stores the base model state dict within its own state dict.
The default implementation of serialize_base_model removes any hooks, then calls base_model.__getstate__(). This works for some models, but others need a more JSON-serializable representation of the base model.
Subclasses can override this method to customize how the base model is serialized, e.g. by returning a Hugging Face config instead of the instance dictionary.
More common than subclassing NoisyModel, however, is to wrap the base model in a custom class that implements its own JSON-serializable/loadable __getstate__/__setstate__ methods, and then pass that custom wrapper instance in as the base model.
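The wrapper approach could look roughly like the sketch below. `JsonFriendlyMLP` is a hypothetical class invented for illustration, and the choice to serialize only constructor arguments is an assumption. The weights are carried separately in a state dict, in line with the note above that the state dict need not be part of the serialized base model.

```python
import json

import torch
from torch import nn


class JsonFriendlyMLP(nn.Module):
    """Hypothetical wrapper whose state is a small JSON-serializable config."""

    def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None:
        super().__init__()
        self.config = {
            "in_features": in_features,
            "hidden_features": hidden_features,
            "out_features": out_features,
        }
        self.net = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, out_features),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

    def __getstate__(self) -> dict:
        # Only the constructor arguments are returned; weights travel in a state dict.
        return dict(self.config)

    def __setstate__(self, state: dict) -> None:
        # Rebuild the architecture from the config; weights are loaded separately.
        self.__init__(**state)


model = JsonFriendlyMLP(4, 8, 2)
payload = json.dumps(model.__getstate__())  # round-trips through JSON

restored = JsonFriendlyMLP.__new__(JsonFriendlyMLP)
restored.__setstate__(json.loads(payload))
restored.load_state_dict(model.state_dict())
```

The trade-off in this sketch is that `__getstate__` no longer captures weights, so anything that pickles the wrapper directly must also save and restore the state dict.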
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | A serialized representation of the base model. |
Raises:
| Type | Description |
|---|---|
NotImplementedError
|
If the base model's |
ValueError
|
If the base model's state contains non-JSON-serializable values, which cannot be serialized as part of the
|
Added in version v3.18.0. Add support for serializing NoisyModel instances.
serialize_init_passthrough_kwargs
Serialize the keyword arguments needed to re-initialize the NoisyModel from its serialized base model and noise layer.
This is used by __getstate__ to get a JSON-serializable representation of those keyword arguments.
Returns:
| Type | Description |
|---|---|
dict[str, Any]
|
A dictionary containing the keyword arguments needed to re-initialize the |
dict[str, Any]
|
form. |
Added in version v3.18.0. Add support for serializing NoisyModel instances.
TruncatedModule
Bases: Module, Generic[ModuleT]
A module that wraps another module and interrupts the forward pass when a specified truncation point is reached.
Truncation works by temporarily adding a hook to the truncation point that raises a TruncationExecutionFinished exception. The TruncatedModule forward catches this exception and returns the output of the truncation point.
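The hook-and-exception mechanism can be sketched in plain PyTorch as follows. This is a simplified illustration, not the TruncatedModule implementation: `StopForward`, `raise_with_output`, and `run_until` are hypothetical names, with `StopForward` standing in for TruncationExecutionFinished.

```python
import torch
from torch import nn


class StopForward(Exception):
    """Hypothetical stand-in for TruncationExecutionFinished."""

    def __init__(self, output):
        super().__init__()
        self.output = output


def raise_with_output(module, args, output):
    """Forward hook that aborts the forward pass, carrying the layer's output."""
    raise StopForward(output)


def run_until(model: nn.Module, truncation_point: nn.Module, *args, **kwargs):
    """Run `model` and return the output of `truncation_point`, skipping later layers."""
    handle = truncation_point.register_forward_hook(raise_with_output)
    try:
        model(*args, **kwargs)
    except StopForward as finished:
        return finished.output
    finally:
        handle.remove()  # the hook is only attached for this one call
    raise RuntimeError("The truncation point was never reached.")


model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2))
activations = run_until(model, model[1], torch.randn(1, 10))
```

Because the hook is removed in `finally`, the wrapped model behaves normally when called directly afterwards, and layers after the truncation point are never executed during the truncated call.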
Examples:
Instantiating a TruncatedModule with a binary classification model and a truncation point:
>>> import torch
>>> model = torch.nn.Sequential(
... torch.nn.Linear(10, 20),
... torch.nn.ReLU(),
... torch.nn.Linear(20, 30),
... torch.nn.ReLU(),
... torch.nn.Linear(30, 40),
... torch.nn.ReLU(),
... torch.nn.Linear(40, 2),
... )
>>> truncation_layer = model[1]
>>> truncated_model = TruncatedModule(model, truncation_layer)
Using the TruncatedModule to get the output of the truncation point:
>>> input = torch.randn(1, 10)
>>> output = truncated_model(input)
>>> # Note that the output has the shape of the truncation point's output, not the full model's
>>> assert output.shape == (1, 20)
The base model of the TruncatedModule is completely unaffected by the truncation:
>>> base_output = model(input)
>>> assert base_output.shape == (1, 2) # Binary classification output shape
The base model is also accessible directly through the module attribute of the TruncatedModule:
>>> base_output = truncated_model.module(input)
>>> assert base_output.shape == (1, 2) # Binary classification output shape
Methods:
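| Name | Description |
|---|---|
| __init__ | Initialize the TruncatedModule with the provided module and truncation point. |
| forward | Forward pass of the TruncatedModule that interrupts the forward pass when the truncation point is reached. |
| lazy_register_truncation_hook | Create a prehook that will be added to the truncation point to interrupt the forward pass when the truncation point is reached. |
| truncation_hook | Intercept the output of the truncation point and raise a TruncationExecutionFinished exception containing that output. |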
__init__
__init__(module: ModuleT, truncation_point: Module) -> None
Initialize the TruncatedModule with the provided module and truncation point.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| module | ModuleT | The module to wrap. | required |
| truncation_point | Module | The submodule of the provided module at which to interrupt the forward pass. | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If the truncation point is not a submodule of the provided module. |
forward
Forward pass of the TruncatedModule that interrupts the forward pass when the truncation point is reached.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
|
Any
|
The positional arguments to pass to the wrapped module. |
required |
|
Any
|
The keyword arguments to pass to the wrapped module. |
required |
Returns:
| Type | Description |
|---|---|
| Any | The output of the truncation point submodule. |
Raises:
| Type | Description |
|---|---|
| HookNotCalledError | If the truncation hook is not called, meaning the truncation point was not reached. |
lazy_register_truncation_hook
Create a prehook that will be added to the truncation point to interrupt the forward pass when the truncation point is reached.
Returns:
| Type | Description |
|---|---|
| _HandlerWrapper | A handler wrapper that contains the hook that was added to the truncation point. |
truncation_hook
staticmethod
Intercept the output of the truncation point and raise a TruncationExecutionFinished exception containing that output.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
|
Module
|
The truncation point submodule. Unused. |
required |
|
Any
|
The arguments passed to the truncation point. Unused. |
required |
|
Tensor
|
The output of the truncation point. This is the output that will be returned by the |
required |
Raises:
| Type | Description |
|---|---|
TruncationExecutionFinished
|
Always, in order to interrupt the wrapped model's |