trainer
StainedGlassTrainer
Bases: Trainer, Generic[M, NLP, NL]
Simple subclass of transformers.Trainer to add StainedGlass-specific optimizers and loss functions.
Contains all the features of the original Trainer class, including training and evaluation loops for PyTorch, optimized for 🤗 Transformers.
TODO: include examples for model vs. model_init, specifying the optimizer at initialization, etc.
__init__
__init__(
model: NoisyModel[M, NLP, NL],
*args: Any,
alpha: float,
save_only_noise_layer: bool = False,
**kwargs: Any
) -> None
Initialize a StainedGlassTrainer.
For additional arguments, see the documentation for transformers.Trainer: https://huggingface.co/docs/transformers/main_classes/trainer
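Because `alpha` appears after `*args` in the signature, it is keyword-only and has no default, so callers must pass it by name; all other positional and keyword arguments are forwarded to `transformers.Trainer`. The minimal sketch below illustrates that calling convention only. `BaseTrainer` is a hypothetical stand-in for `transformers.Trainer` (so the example runs without 🤗 Transformers installed), and `SketchTrainer` and the argument values are illustrative assumptions, not the library's code.

```python
from typing import Any


class BaseTrainer:
    """Hypothetical stand-in for transformers.Trainer; it only records its arguments."""

    def __init__(self, model: Any, *args: Any, **kwargs: Any) -> None:
        self.model = model
        self.args = args
        self.kwargs = kwargs


class SketchTrainer(BaseTrainer):
    """Mirrors the documented __init__ signature; not the real StainedGlassTrainer."""

    def __init__(
        self,
        model: Any,
        *args: Any,
        alpha: float,
        save_only_noise_layer: bool = False,
        **kwargs: Any,
    ) -> None:
        # alpha follows *args, so it is keyword-only; omitting it raises TypeError.
        super().__init__(model, *args, **kwargs)
        self.alpha = alpha
        self.save_only_noise_layer = save_only_noise_layer


trainer = SketchTrainer("noisy-model", "training-args", alpha=0.5, eval_dataset=None)
print(trainer.alpha, trainer.args, trainer.kwargs)  # 0.5 ('training-args',) {'eval_dataset': None}
```

A call such as `SketchTrainer("noisy-model", 0.5)` would fail, because `0.5` is swallowed by `*args` and the required keyword-only `alpha` is missing.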
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | NoisyModel[M, NLP, NL] | The model to train. | required |
alpha | float | The alpha parameter for the noise loss function. | required |
save_only_noise_layer | bool | Whether to save only the noise layer when saving the model. | False |
*args | Any | Arguments to pass to the superclass. | () |
**kwargs | Any | Keyword arguments to pass to the superclass. | {} |
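One plausible way to honor `save_only_noise_layer` is to filter the model's state dict down to the noise-layer entries before writing a checkpoint. The sketch below assumes, hypothetically, that those entries share a `noise_layer.` key prefix and uses plain floats in place of tensors; the real key names and saving logic in Stained Glass may differ, and `filter_noise_layer_state` is an illustrative helper, not a library function.

```python
from typing import Any


def filter_noise_layer_state(
    state_dict: dict[str, Any], prefix: str = "noise_layer."
) -> dict[str, Any]:
    """Keep only entries whose key starts with `prefix` (a hypothetical naming convention)."""
    return {key: value for key, value in state_dict.items() if key.startswith(prefix)}


# Toy state dict; floats stand in for parameter tensors.
full_state = {
    "noise_layer.weight": 0.1,
    "noise_layer.bias": 0.2,
    "base_model.weight": 0.3,
}
print(filter_noise_layer_state(full_state))  # {'noise_layer.weight': 0.1, 'noise_layer.bias': 0.2}
```

Saving only the noise layer keeps checkpoints small when the base model's weights are frozen and available elsewhere, which is the usual motivation for this kind of option.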
compute_loss
compute_loss(
model: Module, inputs: Any, return_outputs: bool = False
) -> Tensor | tuple[Tensor, Any]
Compute the loss for training Stained Glass Transforms.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | Module | The model to compute the loss for. | required |
inputs | Any | The inputs to the model. | required |
return_outputs | bool | Whether to return the outputs of the model alongside the loss. | False |
Returns:
Type | Description |
---|---|
Tensor \| tuple[Tensor, Any] | The loss for training Stained Glass Transforms, or a (loss, outputs) tuple if return_outputs is True. |
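The `Tensor | tuple[Tensor, Any]` return type follows the `transformers.Trainer.compute_loss` convention: the loss alone by default, or a `(loss, outputs)` pair when `return_outputs=True`. The sketch below shows only that contract, using plain floats in place of tensors and a hypothetical `toy_model` that returns a precomputed `"loss"` entry; it does not reproduce the Stained Glass loss itself or how `alpha` enters it.

```python
from typing import Any, Callable


def compute_loss_sketch(
    model: Callable[..., dict[str, Any]],
    inputs: dict[str, Any],
    return_outputs: bool = False,
) -> Any:
    """Illustrates the compute_loss return convention; not the StainedGlassTrainer implementation."""
    outputs = model(**inputs)
    # Assumes the model output is a mapping with a "loss" entry.
    loss = outputs["loss"]
    return (loss, outputs) if return_outputs else loss


def toy_model(x: float, target: float) -> dict[str, Any]:
    """Hypothetical model: squared error between x and target."""
    return {"loss": (x - target) ** 2, "prediction": x}


print(compute_loss_sketch(toy_model, {"x": 3.0, "target": 1.0}))  # 4.0
print(compute_loss_sketch(toy_model, {"x": 3.0, "target": 1.0}, return_outputs=True))
# (4.0, {'loss': 4.0, 'prediction': 3.0})
```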
mock_inspect_signature
mock_inspect_signature(
func: Callable[..., Any],
noisy_model: NoisyModel[M, NLP, NL],
) -> Signature
Mock inspect.signature so that, when it is called on a NoisyModel's forward method, it returns the signature of the base model's forward function; for any other function, it returns the original signature.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
func | Callable[..., Any] | The function to inspect. | required |
noisy_model | NoisyModel[M, NLP, NL] | The NoisyModel to check against. | required |
Returns:
Type | Description |
---|---|
Signature | The base model's forward signature if func is the NoisyModel's forward method; otherwise, the original signature of func. |