cloak
Functions:
Name | Description |
---|---|
composite_cloak_loss_factory | Create a loss function to train a Stained Glass Transform using negative_log_mean. |
composite_cloak_loss_factory
```python
composite_cloak_loss_factory(
    noisy_model: NoisyModel[ModuleT, ..., NoiseLayerT],
    loss_function: Callable[LossFunctionP, Tensor],
    alpha: float,
    respect_std_mask: bool = True,
) -> tuple[
    Callable[sg_transform_loss.LossFunctionP, torch.Tensor],
    Callable[[], ComponentLossesDict],
    Callable[[], HyperparametersDict],
]
```
Create a loss function to train a Stained Glass Transform using negative_log_mean.
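The definition of negative_log_mean is not given on this page. A minimal framework-free sketch, assuming it returns the negative logarithm of the mean of its input (for example, the noise standard deviations produced by the transform), shows why minimizing this term pushes toward a stronger transformation. Plain Python floats stand in for torch tensors, and this negative_log_mean is a hypothetical re-implementation, not the library function.

```python
import math


def negative_log_mean(values: list[float]) -> float:
    """Hypothetical stand-in for negative_log_mean: returns -log(mean(values))."""
    if not values:
        raise ValueError("values must be non-empty")
    mean = sum(values) / len(values)
    if mean <= 0:
        raise ValueError("the mean must be positive to take its logarithm")
    return -math.log(mean)


# Larger standard deviations give a smaller loss, so minimizing this
# term encourages the transform to inject stronger noise.
weak = negative_log_mean([0.1, 0.1, 0.1])
strong = negative_log_mean([1.0, 2.0, 3.0])
print(weak > strong)  # True
```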
Parameters:
Name | Type | Description | Default |
---|---|---|---|
noisy_model | NoisyModel[ModuleT, ..., NoiseLayerT] | The model containing both the base model and the Stained Glass Transform. | required |
loss_function | Callable[LossFunctionP, Tensor] | The base model task loss function to wrap. | required |
alpha | float | The interpolation factor between the task loss (maximizing task performance) and the Stained Glass Transform loss (maximizing transformation strength). Should be in the range [0, 1]: values closer to 0 favor task performance, and values closer to 1 favor transformation strength. | required |
respect_std_mask | bool | Some NoiseLayers' | True |
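The alpha description implies a weighted blend of the two losses. The exact formula is not stated here; a minimal sketch, assuming a convex combination (1 - alpha) * task_loss + alpha * transform_loss, shows how the endpoints of [0, 1] select one term or the other. The function name interpolate_losses is hypothetical.

```python
def interpolate_losses(task_loss: float, transform_loss: float, alpha: float) -> float:
    """Hypothetical convex combination of the task and transform losses."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f"alpha must be in [0, 1], got {alpha}")
    # alpha = 0 keeps only the task loss; alpha = 1 keeps only the transform loss.
    return (1.0 - alpha) * task_loss + alpha * transform_loss


print(interpolate_losses(2.0, -1.0, 0.0))   # 2.0
print(interpolate_losses(2.0, -1.0, 1.0))   # -1.0
print(interpolate_losses(2.0, -1.0, 0.25))  # 1.25
```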
Returns:
Type | Description |
---|---|
tuple[Callable[sg_transform_loss.LossFunctionP, torch.Tensor], Callable[[], ComponentLossesDict], Callable[[], HyperparametersDict]] | A tuple of three functions: the composite loss function, a function to retrieve the loss components, and a function to retrieve the hyperparameters. Each of these functions may be called at most once after a forward pass through both models. |
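The returned tuple follows a factory-of-closures pattern: one function computes the loss and records intermediate values, and two getters hand those values back. The sketch below is a framework-free illustration of that shape, not the library implementation. The names make_composite_loss and transform_loss_fn, the dictionary keys, and the convex-combination formula are all hypothetical assumptions; the getters clear their stored values so that a second call before the next forward pass fails loudly, mirroring the "at most once" rule.

```python
from typing import Callable


def _take_once(store: dict[str, float], name: str) -> dict[str, float]:
    """Return a copy of the recorded values and clear them (single retrieval)."""
    if not store:
        raise RuntimeError(f"no {name} recorded since the last retrieval")
    result = dict(store)
    store.clear()
    return result


def make_composite_loss(
    task_loss_fn: Callable[..., float],
    transform_loss_fn: Callable[[], float],
    alpha: float,
) -> tuple[Callable[..., float], Callable[[], dict[str, float]], Callable[[], dict[str, float]]]:
    """Hypothetical factory mirroring the shape of the returned tuple."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f"alpha must be in [0, 1], got {alpha}")
    components: dict[str, float] = {}
    hyperparameters: dict[str, float] = {}

    def composite_loss(*args, **kwargs) -> float:
        task = task_loss_fn(*args, **kwargs)
        transform = transform_loss_fn()
        total = (1.0 - alpha) * task + alpha * transform
        # Record this forward pass's values so the getters can return them once.
        components.clear()
        components.update(task_loss=task, transform_loss=transform, composite_loss=total)
        hyperparameters.clear()
        hyperparameters.update(alpha=alpha)
        return total

    def get_components() -> dict[str, float]:
        return _take_once(components, "loss components")

    def get_hyperparameters() -> dict[str, float]:
        return _take_once(hyperparameters, "hyperparameters")

    return composite_loss, get_components, get_hyperparameters


loss_fn, get_losses, get_hparams = make_composite_loss(
    task_loss_fn=lambda pred, target: (pred - target) ** 2,
    transform_loss_fn=lambda: -0.5,
    alpha=0.5,
)
loss = loss_fn(3.0, 1.0)  # 0.5 * 4.0 + 0.5 * -0.5
recorded = get_losses()
hparams = get_hparams()
print(loss)  # 1.75
```

Clearing on retrieval is one way to enforce single use; a real implementation may instead tie the stored values to the tensors of the current step.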