
trainer

StainedGlassTrainer

Bases: Trainer, Generic[M, NLP, NL]

A subclass of transformers.Trainer that adds Stained Glass-specific optimizers and loss functions.

It retains all the features of the original Trainer class, including the PyTorch training and evaluation loops optimized for 🤗 Transformers.

TODO: Include examples for model vs. model_init, specifying the optimizer at initialization, etc.

__init__

__init__(model: NoisyModel[M, NLP, NL], *args: Any, alpha: float, save_only_noise_layer: bool = False, **kwargs: Any) -> None

Initialize a StainedGlassTrainer.

For additional arguments, see the documentation for transformers.Trainer: https://huggingface.co/docs/transformers/main_classes/trainer
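Because alpha and save_only_noise_layer follow *args in the signature, they are keyword-only, and the remaining positional and keyword arguments are forwarded to the superclass. The following is a minimal sketch of that constructor pattern. BaseTrainer is a hypothetical stand-in for transformers.Trainer so the example runs without transformers installed, and SketchTrainer is likewise hypothetical, not the library's class.

```python
from typing import Any


class BaseTrainer:
    """Hypothetical stand-in for transformers.Trainer, used only for illustration."""

    def __init__(self, model: Any, args: Any = None, **kwargs: Any) -> None:
        self.model = model
        self.args = args
        self.extra = kwargs


class SketchTrainer(BaseTrainer):
    """Hypothetical subclass mirroring the documented __init__ signature."""

    def __init__(
        self,
        model: Any,
        *args: Any,
        alpha: float,
        save_only_noise_layer: bool = False,
        **kwargs: Any,
    ) -> None:
        # Everything except the subclass-specific options is forwarded unchanged.
        super().__init__(model, *args, **kwargs)
        self.alpha = alpha
        self.save_only_noise_layer = save_only_noise_layer


# alpha follows *args, so it is keyword-only and must be passed by name.
trainer = SketchTrainer("model", "training-args", alpha=0.5)
print(trainer.args, trainer.alpha, trainer.save_only_noise_layer)

# Passing alpha positionally leaves the required keyword-only argument unset.
error = None
try:
    SketchTrainer("model", "training-args", 0.5)
except TypeError as exc:
    error = exc
print(type(error).__name__)
```

In the failing call, 0.5 is absorbed into *args and forwarded to the superclass, so Python raises a TypeError for the missing keyword-only alpha before the constructor body runs.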

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | NoisyModel[M, NLP, NL] | The model to train. | required |
| alpha | float | The alpha parameter for the noise loss function. Keyword-only, because it follows *args in the signature. | required |
| save_only_noise_layer | bool | Whether to save only the noise layer when the model being saved is a NoisyTransformerModel. | False |
| *args | Any | Positional arguments to pass to the superclass. | optional |
| **kwargs | Any | Keyword arguments to pass to the superclass. | optional |
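The save_only_noise_layer flag implies a filtering step at save time. The following is a minimal sketch of one way such filtering can work. It assumes parameters are held in a flat dictionary keyed by dotted names (the layout of PyTorch's Module.state_dict()) and that noise-layer entries share a hypothetical noise_layer. prefix. The prefix, the key names, and the helper noise_layer_only are all illustrative, and the actual save logic in StainedGlassTrainer may differ.

```python
from __future__ import annotations

from typing import Any

# Hypothetical key prefix; the real noise-layer attribute name may differ.
NOISE_LAYER_PREFIX = "noise_layer."


def noise_layer_only(state_dict: dict[str, Any], prefix: str = NOISE_LAYER_PREFIX) -> dict[str, Any]:
    """Return only the entries that belong to the (hypothetical) noise-layer submodule."""
    # The trailing dot in the prefix avoids matching sibling names such as "noise_layer2".
    return {key: value for key, value in state_dict.items() if key.startswith(prefix)}


# Toy state dict keyed by dotted parameter names (all names hypothetical).
full_state = {
    "base_model.embed_tokens.weight": [0.1, 0.2],
    "noise_layer.mean.weight": [0.3],
    "noise_layer.std.weight": [0.4],
}

save_only_noise_layer = True
to_save = noise_layer_only(full_state) if save_only_noise_layer else full_state
print(sorted(to_save))  # ['noise_layer.mean.weight', 'noise_layer.std.weight']
```

Saving only the filtered entries keeps the checkpoint small when the base model's weights are frozen and available elsewhere. That rationale is an assumption, not something this page states.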

compute_loss

compute_loss(model: Module, inputs: Any, return_outputs: bool = False) -> torch.Tensor | tuple[torch.Tensor, Any]

Compute the loss for training Stained Glass Transforms.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | Module | The model to compute the loss for. | required |
| inputs | Any | The inputs to the model. | required |
| return_outputs | bool | Whether to return the outputs of the model alongside the loss. | False |

Returns:

| Type | Description |
| --- | --- |
| torch.Tensor or tuple[torch.Tensor, Any] | The loss for training Stained Glass Transforms. If return_outputs is True, a tuple of the loss and the model outputs. |
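The return type of compute_loss depends on return_outputs. The following is a minimal sketch of that contract only. It assumes a hypothetical toy model that reports its loss under a "loss" key, and it does not attempt to reproduce the Stained Glass loss or how alpha enters it. compute_loss_sketch and toy_model are hypothetical names.

```python
from __future__ import annotations

from typing import Any, Callable


def compute_loss_sketch(
    model: Callable[[dict[str, Any]], dict[str, Any]],
    inputs: dict[str, Any],
    return_outputs: bool = False,
) -> float | tuple[float, dict[str, Any]]:
    """Hypothetical illustration of the compute_loss return contract (not the Stained Glass loss)."""
    outputs = model(inputs)
    loss = outputs["loss"]  # assumption: the model reports its loss under "loss"
    return (loss, outputs) if return_outputs else loss


def toy_model(inputs: dict[str, Any]) -> dict[str, Any]:
    # Toy stand-in: mean squared error between predictions and labels.
    preds, labels = inputs["preds"], inputs["labels"]
    loss = sum((p - y) ** 2 for p, y in zip(preds, labels)) / len(labels)
    return {"loss": loss, "preds": preds}


batch = {"preds": [1.0, 2.0], "labels": [1.0, 4.0]}
print(compute_loss_sketch(toy_model, batch))  # 2.0
loss, outputs = compute_loss_sketch(toy_model, batch, return_outputs=True)
print(loss, outputs["preds"])  # 2.0 [1.0, 2.0]
```

Upstream, transformers.Trainer calls compute_loss with return_outputs=True in its prediction step, so an override should honor both branches.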

mock_inspect_signature

mock_inspect_signature(func: Callable[..., Any], noisy_model: NoisyModel[M, NLP, NL]) -> inspect.Signature

Mock inspect.signature: if called on the forward function of a NoisyModel, return the signature of the base model's forward function; otherwise, return the original signature.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| func | Callable[..., Any] | The function to inspect. | required |
| noisy_model | NoisyModel[M, NLP, NL] | The NoisyModel to check against. | required |

Returns:

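| Type | Description |
| --- | --- |
| inspect.Signature | The signature of the base model's forward function if func is the forward function of noisy_model; otherwise, the signature of func. |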