transform

Functions:

| Name | Description |
|------|-------------|
| `alphaless_loss` | Compute the gradients for each loss and, for any parameters shared by both losses, modify the gradient so that it is orthogonal to the smaller of the two computed gradients. |
| `composite_loss` | Interpolate between the task loss and the noise loss term. |
| `hook_loss_wrapper` | Attach hooks that extract activation tensors from the corresponding layers, and return a wrapped loss function with the same interface as the original. |

alphaless_loss

alphaless_loss(
    loss1: Tensor,
    loss2: Tensor,
    /,
    backward_wrapper: _BackwardWrapper | None = None,
    grad_scaler: _GradScaler | None = None,
) -> torch.Tensor

Compute the gradients for each loss, and for any parameters shared by both losses, modify the gradient such that it is orthogonal to the smaller of the two computed gradients.

This loss function is symmetric with respect to the order of the given losses.
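
The projection step itself is ordinary linear algebra. As a rough sketch for intuition only (`project_out` is a hypothetical helper, not this function's implementation), removing from one flattened gradient its component along another yields a vector orthogonal to the second:

>>> import torch
>>> def project_out(grad: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
...     # Hypothetical helper for illustration: return `grad` with its component
...     # along `reference` removed, leaving a vector orthogonal to `reference`.
...     denom = reference.dot(reference)
...     return grad if denom == 0 else grad - (grad.dot(reference) / denom) * reference
>>> g1, g2 = torch.tensor([1.0, 1.0]), torch.tensor([1.0, 0.0])
>>> torch.isclose(project_out(g1, g2).dot(g2), torch.tensor(0.0))
tensor(True)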

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `loss1` | `Tensor` | One of the loss tensors to combine. | *required* |
| `loss2` | `Tensor` | One of the loss tensors to combine. | *required* |
| `backward_wrapper` | `_BackwardWrapper \| None` | The object that wraps and manages the backward pass (like `accelerate.Accelerator` or `lightning.fabric.fabric.Fabric`). Mutually exclusive with `grad_scaler`. | `None` |
| `grad_scaler` | `_GradScaler \| None` | The gradient scaler to use for the backward pass (like `torch.cuda.amp.GradScaler` or `torch.cpu.amp.GradScaler`). | `None` |

Raises:

| Type | Description |
|------|-------------|
| `ValueError` | If both `backward_wrapper` and `grad_scaler` are provided. |
| `ValueError` | If the two losses have no trainable parameters in common. |

Returns:

| Type | Description |
|------|-------------|
| `torch.Tensor` | The sum of the two losses, detached from the original graph so that a subsequent call to `backward()` is a no-op. |
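
A minimal usage sketch follows; the model, losses, and optimizer are placeholders chosen for illustration, and only the call to `alphaless_loss` comes from this module. Since the gradients are computed during the call, the returned detached tensor is mainly useful for logging:

>>> import torch
>>> from torch import nn
>>> model = nn.Linear(4, 2)
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
>>> input, target = torch.randn(8, 4), torch.randint(0, 2, (8,))
>>> output = model(input)
>>> task_loss = nn.functional.cross_entropy(output, target)
>>> magnitude_loss = output.pow(2).mean()  # placeholder second loss sharing the same parameters
>>> optimizer.zero_grad()
>>> combined = alphaless_loss(task_loss, magnitude_loss)  # gradients are written during this call
>>> optimizer.step()
>>> logged = combined.item()  # detached scalar, useful for logging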

Added in version 0.55.0.

composite_loss

composite_loss(
    task_loss: Tensor, noise_loss: Tensor, alpha: float
) -> torch.Tensor

Interpolate between the task loss and the noise loss term.

Higher values of alpha weight the noise loss term more heavily, while lower values weight the task loss more heavily.
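
Read this way, the result behaves like a convex combination of the two terms. A sketch of that interpretation (the actual implementation may differ in detail, e.g. in how alpha is validated):

>>> def composite_loss_sketch(task_loss, noise_loss, alpha):
...     # Hedged sketch: weight the noise term by alpha and the task term by (1 - alpha).
...     return (1.0 - alpha) * task_loss + alpha * noise_loss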

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `task_loss` | `Tensor` | The loss tensor for the task-specific criterion. | *required* |
| `noise_loss` | `Tensor` | The loss tensor for the noise layer regularization term. | *required* |
| `alpha` | `float` | The interpolation factor between the task loss and the noise loss. Should be in the range [0, 1]. | *required* |

Returns:

| Type | Description |
|------|-------------|
| `torch.Tensor` | The composite loss tensor. |
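
For a fixed pair of losses, sweeping alpha shifts the emphasis between the two terms (the values below are arbitrary and for illustration only):

>>> import torch
>>> task_loss = torch.tensor(0.8)
>>> noise_loss = torch.tensor(0.2)
>>> mostly_task = composite_loss(task_loss, noise_loss, alpha=0.1)   # emphasizes the task loss
>>> mostly_noise = composite_loss(task_loss, noise_loss, alpha=0.9)  # emphasizes the noise loss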

Added in version 0.55.0.

hook_loss_wrapper

hook_loss_wrapper(
    loss_function: Callable[
        LossFunctionP, LossFunctionReturnT
    ],
    noise_loss_function: Callable[
        [
            LossFunctionReturnT,
            ActivationsDict,
            ComponentLossesDict,
            HyperparametersDict,
        ],
        LossFunctionReturnT,
    ],
    activation_hooks: ActivationHooksDict,
) -> tuple[
    Callable[LossFunctionP, LossFunctionReturnT],
    Callable[[], ComponentLossesDict],
    Callable[[], HyperparametersDict],
]

Attach hooks that extract activation tensors from the corresponding layers, and return a wrapped loss function with the same interface as the original that utilizes these activations.
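
Each entry in `activation_hooks` pairs a module with a forward-hook-style callable whose return value becomes the stored activation. A minimal named-function equivalent of the lambdas used in the example below (the `(module, args, kwargs, output)` signature is assumed from that example):

>>> def capture_output(module, args, kwargs, output):
...     # Return the tensor to record under this activation name; here the
...     # module's own output, whereas the example below captures weights instead.
...     return output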

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `loss_function` | `Callable[LossFunctionP, LossFunctionReturnT]` | The original loss function. | *required* |
| `noise_loss_function` | `Callable[[LossFunctionReturnT, ActivationsDict, ComponentLossesDict, HyperparametersDict], LossFunctionReturnT]` | A function that accepts the output of `loss_function`, a dictionary of activation tensors, a dictionary for tracking individual loss components, and a dictionary for tracking hyperparameters. Must return the same type of output as `loss_function`. | *required* |
| `activation_hooks` | `ActivationHooksDict` | A mapping of activation names to tuples of modules and their corresponding activation tensor-returning hooks. | *required* |

Returns:

| Type | Description |
|------|-------------|
| `tuple[Callable[LossFunctionP, LossFunctionReturnT], Callable[[], ComponentLossesDict], Callable[[], HyperparametersDict]]` | A wrapped loss function with the same signature as `loss_function` that utilizes the activations extracted by the hooks, a function that returns the tracked component losses, and a function that returns the tracked hyperparameters. |

Examples:

>>> from typing import Any
>>> import torch
>>> from torch import nn
>>> class MiniConvModel(nn.Module):
...     def __init__(self) -> None:
...         super().__init__()
...         self.model = nn.Sequential(
...             nn.MaxPool2d(kernel_size=(28, 28)),
...             nn.Conv2d(3, 1, kernel_size=(1, 1)),
...             nn.ReLU(),
...             nn.Flatten(),
...         )
...         self.classifier = nn.Linear(in_features=64, out_features=2, bias=True)
...
...     def forward(self, input: torch.Tensor) -> torch.Tensor:
...         return self.classifier(self.model(input))
>>> model = MiniConvModel()
>>> loss_function = nn.functional.cross_entropy
>>> activation_hooks = {
...     "conv_weight": (
...         model.model[1],
...         lambda module, args, kwargs, output: module.weight,
...     ),
...     "classifier_weight": (
...         model.classifier,
...         lambda module, args, kwargs, output: module.weight,
...     ),
... }
>>> def weight_norm_loss_function(
...     loss: torch.Tensor,
...     activations: dict[str, torch.Tensor],
...     losses: dict[str, float],
...     hyperparameters: dict[str, Any],
... ) -> torch.Tensor:
...     hyperparameters["p"] = "fro"
...     hyperparameters["beta"] = 0.1
...     losses["conv_weight_norm"] = activations["conv_weight"].norm(
...         p=hyperparameters["p"]
...     )
...     hyperparameters["gamma"] = 0.2
...     losses["classifier_weight_norm"] = activations["classifier_weight"].norm(
...         p=hyperparameters["p"]
...     )
...     return (
...         loss
...         + hyperparameters["beta"] * losses["conv_weight_norm"]
...         + hyperparameters["gamma"] * losses["classifier_weight_norm"]
...     )
>>> wrapped_loss, get_losses, get_hyperparameters = hook_loss_wrapper(
...     loss_function,
...     weight_norm_loss_function,
...     activation_hooks,
... )
>>> input = torch.randn(1, 3, 224, 224)
>>> target = torch.randint(0, 2, (1,))
>>> output = model(input)
>>> loss = wrapped_loss(output, target)
>>> loss.backward()
>>> get_losses()
{'conv_weight_norm': tensor(...), 'classifier_weight_norm': tensor(...)}
>>> get_hyperparameters()
{'p': 'fro', 'beta': 0.1, 'gamma': 0.2}

Changed in version 0.71.0: `noise_loss_function` is now expected to additionally accept a dictionary for tracking individual loss components and hyperparameters. `hook_loss_wrapper` now also returns 2 additional functions, which return these dictionaries.

Changed in version 0.66.0: The `module` parameter was removed. `activation_hooks` now accepts a mapping of activation names to `tuple` of `Module` and their corresponding activation tensor-returning hooks. Modules are no longer required to be submodules of a single parent module.

Added in version 0.55.0.