transform
Module for loss transformation utilities.
Functions:
| Name | Description |
|---|---|
| alphaless_loss | Compute the gradients for each loss, and for any parameters shared by both losses, modify the gradient such that it is orthogonal to the smaller of the two computed gradients. |
| composite_loss | Interpolate between the task loss and the noise loss term. |
| hook_loss_wrapper | Attach hooks that extract activation tensors from the corresponding layers, and return a wrapped loss function with the same interface as the original that utilizes these activations. |
alphaless_loss
alphaless_loss(
    loss1: Tensor,
    loss2: Tensor,
    /,
    backward_wrapper: _BackwardWrapper | None = None,
    grad_scaler: _GradScaler | None = None,
) -> torch.Tensor
Compute the gradients for each loss, and for any parameters shared by both losses, modify the gradient such that it is orthogonal to the smaller of the two computed gradients.
This loss function is symmetric with respect to the order of the given losses.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| loss1 | Tensor | One of the loss tensors to combine. | required |
| loss2 | Tensor | One of the loss tensors to combine. | required |
| backward_wrapper | _BackwardWrapper \| None | The managed grad scaler to use for the backward pass (like accelerate.Accelerator or lightning.fabric.fabric.Fabric). | None |
| grad_scaler | _GradScaler \| None | The gradient scaler to use for the backward pass (like torch.cuda.amp.GradScaler or … | None |
Raises:
| Type | Description |
|---|---|
| ValueError | If both … |
| ValueError | If the two losses have no trainable parameters in common. |
Returns:
| Type | Description |
|---|---|
| torch.Tensor | The sum of the two losses, detached from the original graph so that a subsequent call to … |
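As a rough illustration of the gradient adjustment described above, the sketch below works on plain Python lists and shows one possible reading: the larger gradient has its component along the smaller one removed. The helper name project_larger_orthogonal is hypothetical, and how the library combines or applies the resulting gradients is not stated on this page, so treat this as an assumption rather than the actual implementation.

```python
def dot(u: list[float], v: list[float]) -> float:
    """Plain dot product of two equally sized vectors."""
    return sum(x * y for x, y in zip(u, v))


def project_larger_orthogonal(
    grad_a: list[float], grad_b: list[float]
) -> tuple[list[float], list[float]]:
    """Return (smaller, adjusted_larger), where adjusted_larger is orthogonal to smaller.

    Hypothetical helper: this is an assumed interpretation of the docstring,
    not the library's code.
    """
    # Pick the smaller gradient by squared norm, so argument order does not matter
    # (ties aside).
    if dot(grad_a, grad_a) <= dot(grad_b, grad_b):
        small, large = grad_a, grad_b
    else:
        small, large = grad_b, grad_a

    small_sq = dot(small, small)
    if small_sq == 0.0:
        # A zero gradient defines no direction, so there is nothing to remove.
        return list(small), list(large)

    # Subtract the component of the larger gradient that lies along the smaller one.
    coeff = dot(large, small) / small_sq
    adjusted = [l - coeff * s for l, s in zip(large, small)]
    return list(small), adjusted


small, adjusted = project_larger_orthogonal([1.0, 0.0], [3.0, 4.0])
```

Swapping the two arguments gives the same pair here, which mirrors the symmetry noted in the description.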
composite_loss
Interpolate between the task loss and the noise loss term.
Higher values of alpha weight the noise loss term more heavily, while lower values weight the task loss more heavily.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| | Tensor | The loss tensor for the task-specific criterion. | required |
| | Tensor | The loss tensor for the noise layer regularization term. | required |
| alpha | float | The interpolation factor between the task loss and the noise loss. Should be in the range [0, 1]. | required |
Returns:
| Type | Description |
|---|---|
| torch.Tensor | The composite loss tensor. |
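The interpolation can be pictured with a minimal sketch on plain floats, assuming it is a standard linear interpolation. The function name interpolate_losses and its parameter names are hypothetical, and the exact formula used by composite_loss is not shown on this page.

```python
def interpolate_losses(task_loss: float, noise_loss: float, alpha: float) -> float:
    """Linearly interpolate between two loss values (assumed formula).

    alpha = 0 returns the task loss alone; alpha = 1 returns the noise loss alone.
    """
    return (1.0 - alpha) * task_loss + alpha * noise_loss


mostly_task = interpolate_losses(2.0, 4.0, 0.25)  # leans toward the task loss
```

With tensors in place of floats the same expression would keep the result attached to both loss graphs, so a single backward pass would reach both terms.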
hook_loss_wrapper
hook_loss_wrapper(
    loss_function: Callable[
        LossFunctionP, LossFunctionReturnT
    ],
    noise_loss_function: Callable[
        [
            LossFunctionReturnT,
            ActivationsDict,
            ComponentLossesDict,
            HyperparametersDict,
        ],
        LossFunctionReturnT,
    ],
    activation_hooks: ActivationHooksDict,
) -> tuple[
    Callable[LossFunctionP, LossFunctionReturnT],
    Callable[[], ComponentLossesDict],
    Callable[[], HyperparametersDict],
]
Attach hooks that extract activation tensors from the corresponding layers, and return a wrapped loss function with the same interface as the original that utilizes these activations.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| loss_function | Callable[LossFunctionP, LossFunctionReturnT] | The original loss function. | required |
| noise_loss_function | Callable[[LossFunctionReturnT, ActivationsDict, ComponentLossesDict, HyperparametersDict], LossFunctionReturnT] | A function that accepts the output of … | required |
| activation_hooks | ActivationHooksDict | A mapping of activation names to tuples of modules and their corresponding activation tensor-returning hooks. | required |
Returns:
| Type | Description |
|---|---|
| tuple[Callable[LossFunctionP, LossFunctionReturnT], Callable[[], ComponentLossesDict], Callable[[], HyperparametersDict]] | A wrapped loss function with the same signature as … |
Examples:
>>> from typing import Any
>>> import torch
>>> from torch import nn
>>> class MiniConvModel(nn.Module):
...     def __init__(self) -> None:
...         super().__init__()
...         self.model = nn.Sequential(
...             nn.MaxPool2d(kernel_size=(28, 28)),
...             nn.Conv2d(3, 1, kernel_size=(1, 1)),
...             nn.ReLU(),
...             nn.Flatten(),
...         )
...         # A 224x224 input pools to 8x8, so Flatten yields 1 * 8 * 8 = 64 features.
...         self.classifier = nn.Linear(in_features=64, out_features=2, bias=True)
...
...     def forward(self, input: torch.Tensor) -> torch.Tensor:
...         return self.classifier(self.model(input))
>>> model = MiniConvModel()
>>> loss_function = nn.functional.cross_entropy
>>> activation_hooks = {
...     "conv_weight": (
...         model.model[1],
...         lambda module, args, kwargs, output: module.weight,
...     ),
...     "classifier_weight": (
...         model.classifier,
...         lambda module, args, kwargs, output: module.weight,
...     ),
... }
>>> def weight_norm_loss_function(
...     loss: torch.Tensor,
...     activations: dict[str, torch.Tensor],
...     losses: dict[str, float],
...     hyperparameters: dict[str, Any],
... ) -> torch.Tensor:
...     hyperparameters["p"] = "fro"  # codespell:ignore fro
...     hyperparameters["beta"] = 0.1
...     losses["conv_weight_norm"] = activations["conv_weight"].norm(
...         p=hyperparameters["p"]
...     )
...     hyperparameters["gamma"] = 0.2
...     losses["classifier_weight_norm"] = activations["classifier_weight"].norm(
...         p=hyperparameters["p"]
...     )
...     return (
...         loss
...         + hyperparameters["beta"] * losses["conv_weight_norm"]
...         + hyperparameters["gamma"] * losses["classifier_weight_norm"]
...     )
>>> wrapped_loss, get_losses, get_hyperparameters = hook_loss_wrapper(
...     loss_function,
...     weight_norm_loss_function,
...     activation_hooks,
... )
>>> input = torch.randn(1, 3, 224, 224)
>>> target = torch.randint(0, 2, (1,))
>>> output = model(input)
>>> loss = wrapped_loss(output, target)
>>> loss.backward()
>>> get_losses()
{'conv_weight_norm': tensor(...), 'classifier_weight_norm': tensor(...)}
>>> get_hyperparameters()
{'p': 'fro', 'beta': 0.1, 'gamma': 0.2}
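To make the three returned callables easier to picture, here is a torch-free sketch of the closure pattern that such a wrapper might follow. Everything in it is an assumption for illustration: wrap_with_state is a hypothetical name, and a plain dict stands in for the activations that the real function collects through module hooks.

```python
from typing import Any, Callable


def wrap_with_state(
    loss_function: Callable[..., float],
    noise_loss_function: Callable[
        [float, dict[str, Any], dict[str, Any], dict[str, Any]], float
    ],
    activations: dict[str, Any],
) -> tuple[
    Callable[..., float],
    Callable[[], dict[str, Any]],
    Callable[[], dict[str, Any]],
]:
    """Hypothetical stand-in: wrap a loss and expose what the noise term recorded."""
    losses: dict[str, Any] = {}
    hyperparameters: dict[str, Any] = {}

    def wrapped(*args: Any, **kwargs: Any) -> float:
        # Same call signature as the original loss function.
        base = loss_function(*args, **kwargs)
        # The noise loss function fills in `losses` and `hyperparameters` as a side effect.
        return noise_loss_function(base, activations, losses, hyperparameters)

    # The getters return copies so callers cannot mutate the internal state.
    return wrapped, lambda: dict(losses), lambda: dict(hyperparameters)


def abs_penalty(
    loss: float,
    activations: dict[str, Any],
    losses: dict[str, Any],
    hyperparameters: dict[str, Any],
) -> float:
    hyperparameters["beta"] = 0.5
    losses["w_abs"] = abs(activations["w"])
    return loss + hyperparameters["beta"] * losses["w_abs"]


wrapped, get_losses, get_hyperparameters = wrap_with_state(
    lambda prediction, target: abs(prediction - target),
    abs_penalty,
    {"w": 3.0},
)
total = wrapped(2.0, 1.0)
```

The getters only report values after the wrapped loss has been called at least once, since the noise loss function is what populates them.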