cloak

Functions:

composite_cloak_loss_factory: Create a loss function to train a Stained Glass Transform using negative_log_mean.

composite_cloak_loss_factory

composite_cloak_loss_factory(
    noisy_model: NoisyModel[ModuleT, ..., NoiseLayerT],
    loss_function: Callable[LossFunctionP, Tensor],
    alpha: float,
    respect_std_mask: bool = True,
) -> tuple[
    Callable[sg_transform_loss.LossFunctionP, torch.Tensor],
    Callable[[], ComponentLossesDict],
    Callable[[], HyperparametersDict],
]

Create a loss function to train a Stained Glass Transform using negative_log_mean.
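Taking the name negative_log_mean literally, the transform-side term is presumably the negative logarithm of the mean noise standard deviation produced by the NoiseLayer. The sketch below illustrates that reading only. The helper name `negative_log_mean` and the plain-list input are hypothetical, not the library's API, which works on torch tensors. Because -log is decreasing, lowering this term pushes the mean std upward, which means a stronger transformation.

```python
import math
from collections.abc import Sequence


def negative_log_mean(std: Sequence[float]) -> float:
    """Hypothetical noise loss: -log(mean(std)).

    Assumption: this mirrors what "negative_log_mean" refers to in the
    docstring; the real implementation operates on torch tensors.
    """
    if not std:
        raise ValueError("std must contain at least one element")
    mean_std = sum(std) / len(std)
    if mean_std <= 0:
        raise ValueError("mean std must be positive to take its log")
    return -math.log(mean_std)


# A larger noise std yields a smaller (better) noise loss.
weak = negative_log_mean([0.1, 0.1, 0.1])
strong = negative_log_mean([1.0, 2.0, 3.0])
print(weak, strong)
```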

Parameters:

noisy_model: NoisyModel[ModuleT, ..., NoiseLayerT] (required)
The model containing both the base model and the Stained Glass Transform.

loss_function: Callable[LossFunctionP, Tensor] (required)
The base model task loss function to wrap.

alpha: float (required)
The interpolation factor between the task loss (maximizing task performance) and the Stained Glass Transform loss (maximizing transformation strength). Should be in the range [0, 1], where 0 favors task performance and 1 favors transformation strength.

respect_std_mask: bool (default: True)
Some NoiseLayers' std_estimator returns a mask tensor that determines which elements of the inputs are masked out as part of the transformation. If True, the loss function respects this mask when computing the noise loss (i.e., only unmasked elements contribute to the noise loss). If False, the loss function ignores the mask and computes the noise loss over all elements. This parameter is ignored if the NoiseLayer does not provide a mask. For historical reasons, it is also ignored for NoiseLayers that inherit from PatchCloakNoiseLayer, whose mask is never respected.

Returns:

tuple[Callable[sg_transform_loss.LossFunctionP, torch.Tensor], Callable[[], ComponentLossesDict], Callable[[], HyperparametersDict]]

A tuple of 3 functions: the composite loss function, a function to retrieve the loss components, and a function to retrieve the hyperparameters. These functions may be called at most once each after a forward pass through both models.
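The three-function return shape is a closure-based factory pattern. The sketch below shows one way such a factory could work, using plain floats in place of tensors. The loss closure stores its components, and each getter returns and then clears them, which enforces the "at most once per forward pass" rule. `make_composite_loss` and its arguments are hypothetical and the linear alpha blend is an assumption. This is not the library's implementation.

```python
from collections.abc import Callable
from typing import Any


def make_composite_loss(
    task_loss_fn: Callable[..., float],
    noise_loss_fn: Callable[[], float],
    alpha: float,
) -> tuple[
    Callable[..., float],
    Callable[[], dict[str, float]],
    Callable[[], dict[str, float]],
]:
    """Hypothetical factory returning (loss, get_components, get_hyperparameters)."""
    components: dict[str, float] = {}
    hyperparameters: dict[str, float] = {}

    def loss(*args: Any, **kwargs: Any) -> float:
        task = task_loss_fn(*args, **kwargs)
        noise = noise_loss_fn()  # e.g. computed from the NoiseLayer's std
        total = (1.0 - alpha) * task + alpha * noise
        # Stash values so the getters can report them after this call.
        components.update(task_loss=task, noise_loss=noise, composite_loss=total)
        hyperparameters.update(alpha=alpha)
        return total

    def get_components() -> dict[str, float]:
        if not components:
            raise RuntimeError("call the loss function after a forward pass first")
        result = dict(components)
        components.clear()  # a second call before the next loss call fails
        return result

    def get_hyperparameters() -> dict[str, float]:
        if not hyperparameters:
            raise RuntimeError("call the loss function after a forward pass first")
        result = dict(hyperparameters)
        hyperparameters.clear()
        return result

    return loss, get_components, get_hyperparameters


loss_fn, get_losses, get_hparams = make_composite_loss(
    lambda pred, target: (pred - target) ** 2,  # toy task loss
    lambda: 0.5,  # toy noise loss
    alpha=0.5,
)
total = loss_fn(3.0, 1.0)
logged = get_losses()
hparams = get_hparams()
```

Returning closures over shared state keeps the loss function's signature identical to the wrapped task loss. Logging code can still pull the individual terms afterwards.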