image_similarity

Module for image-similarity privacy losses.

Provides differentiable losses that quantify how much structural and linear-dependence information remains between a noisy (Stained Glass-transformed) image and its clean counterpart, after a battery of non-learned attacker transforms (clamp, min-max, Sobel, Gaussian blur, gamma, wavelet denoise, z-score-sigmoid). Use these as a privacy regularizer when training image cloaks.

Design notes — alternatives considered.

Pixel-space metrics (L1 / L2). Direct |noisy - clean| or (noisy - clean)^2 were considered first. They are trivially defeated by an attacker that learns a linear de-noising filter, and they treat structural and textural information identically. We chose structural similarity metrics instead because they capture perceptual differences the way a downstream classifier or vision encoder would.

Single MS-SSIM term only. Earlier prototypes compared the bare denormalized noisy_01 against clean_01 with a single MS-SSIM call. In experiments this term drove MS-SSIM to very small values while the obfuscation remained weak: an attacker could still recover recognizable structure from the noisy image via simple post-processing (e.g. min-max normalization, gamma correction, denoising). The current multi-attacker formulation evaluates MS-SSIM (and squared Pearson) after applying each candidate attacker transform to the noisy image (and, for lossy transforms, to the clean image as well), so the loss penalizes residue that survives those attacks. The squared-Pearson family then catches residual linear dependence in directions that MS-SSIM's structural measure only partially penalizes (Pearson is affine-invariant, so it covers a complementary failure mode).

Other perceptual metrics (LPIPS, DISTS, FSIM, PSNR). These are not included, which keeps the loss lightweight and free of any learned reference model whose weights could be reverse-engineered. They remain candidates for future additions if a use case appears.

Functions:

  • multi_attacker_privacy_loss: Compute a 15-component privacy loss against non-learned image attackers.
  • pearson_sq_mean: Mean across the batch of the squared Pearson correlation between pred and target.
  • safe_ms_ssim: Compute functional MS-SSIM between pred and target, replacing NaN/Inf with zero.

Attributes:

  • DEFAULT_GAMMA_ATTACK_VALUES (Final[tuple[float, ...]]): Gamma values used by the gamma-correction attacker family (covers both dark- and bright-region recovery).
  • DEFAULT_GAUSSIAN_BLUR_SIGMA (Final[float]): Sigma of the Gaussian-blur attacker, in pixels of the [0, 1]-denormalized image space.
  • DEFAULT_WAVELET_SOFT_THRESHOLD (Final[float]): Soft-threshold magnitude applied to the detail wavelet coefficients in the wavelet-denoise attacker.
  • MS_SSIM_COMPONENT_NAMES (Final[tuple[str, ...]]): Names of the nine MS-SSIM privacy components produced by multi_attacker_privacy_loss.
  • PEARSON_COMPONENT_NAMES (Final[tuple[str, ...]]): Names of the six squared-Pearson privacy components produced by multi_attacker_privacy_loss.

DEFAULT_GAMMA_ATTACK_VALUES module-attribute

DEFAULT_GAMMA_ATTACK_VALUES: Final[tuple[float, ...]] = (
    0.5,
    2.0,
)

Gamma values used by the gamma-correction attacker family (covers both dark- and bright-region recovery).

DEFAULT_GAUSSIAN_BLUR_SIGMA module-attribute

DEFAULT_GAUSSIAN_BLUR_SIGMA: Final[float] = 2.0

Sigma of the Gaussian-blur attacker, in pixels of the [0, 1]-denormalized image space.

DEFAULT_WAVELET_SOFT_THRESHOLD module-attribute

DEFAULT_WAVELET_SOFT_THRESHOLD: Final[float] = 0.1

Soft-threshold magnitude applied to the detail wavelet coefficients in the wavelet-denoise attacker.

MS_SSIM_COMPONENT_NAMES module-attribute

MS_SSIM_COMPONENT_NAMES: Final[tuple[str, ...]] = (
    "rgb_clamp",
    "rgb_minmax",
    "gray_clamp",
    "gray_minmax",
    "sobel_gray",
    "blur_gray",
    "standardize_gray",
    "gamma_gray",
    "wavelet_gray",
)

Names of the nine MS-SSIM privacy components produced by multi_attacker_privacy_loss.

PEARSON_COMPONENT_NAMES module-attribute

PEARSON_COMPONENT_NAMES: Final[tuple[str, ...]] = (
    "pearson_rgb",
    "pearson_gray",
    "pearson_sobel",
    "pearson_blur",
    "pearson_gamma",
    "pearson_wavelet",
)

Names of the six squared-Pearson privacy components produced by multi_attacker_privacy_loss.

multi_attacker_privacy_loss

multi_attacker_privacy_loss(
    noisy_01: Tensor,
    clean_01: Tensor,
    *,
    gamma_values: Sequence[float] = (0.5, 2.0),
    blur_sigma: float = 2.0,
    wavelet_threshold: float = 0.1,
    compute_ms_ssim: bool = True,
    compute_pearson: bool = True
) -> dict[str, torch.Tensor]

Compute a 15-component privacy loss against non-learned image attackers.

Both inputs must already be in the [0, 1] range; denormalize from any model normalization (e.g. ImageNet) before calling. The noisy image is compared against the clean image through fifteen attacker-specific terms: nine MS-SSIM terms that measure structural similarity and six squared-Pearson terms that measure residual linear dependence. Within each family the terms are averaged with equal weights into that family's aggregate key (ms_ssim_total for the nine MS-SSIM terms, pearson_total for the six squared-Pearson terms). For "lossy" transforms (blur, sobel, gamma, wavelet, standardize), the same transform is applied to the clean side too, so a noise-free input still scores exactly 1.0 and the gradient signal is driven purely by noise residue.

Structural similarity terms (MS-SSIM). Contrast-recovery attackers compared directly against clean/clean_gray:

  • rgb_clamp: display-level RGB — clamp noisy RGB to [0, 1].
  • rgb_minmax: tensor-level RGB — min-max normalize noisy RGB.
  • gray_clamp: display-level luma — clamp noisy luma.
  • gray_minmax: tensor-level luma — reduce to luma then min-max normalize.
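To make the display-level versus tensor-level distinction concrete, here is a minimal pure-Python sketch on a flat list of pixel values. The helpers clamp01 and minmax01 are hypothetical names used only for illustration; the module itself operates on tensors, and whether its min-max is taken per image or per channel is not stated here, so the single global min/max below is an assumption.

```python
def clamp01(values):
    """Display-level attacker: clip every value into [0, 1]."""
    return [min(1.0, max(0.0, v)) for v in values]


def minmax01(values, eps=1e-12):
    """Tensor-level attacker: rescale so the minimum maps to 0 and the maximum to 1."""
    lo, hi = min(values), max(values)
    span = max(hi - lo, eps)  # eps guard for constant inputs (an assumption)
    return [(v - lo) / span for v in values]


# A low-contrast signal squeezed into a narrow band around 0.5.
hidden = [0.50, 0.51, 0.52, 0.51, 0.50]

clamped = clamp01(hidden)      # unchanged: everything is already inside [0, 1]
stretched = minmax01(hidden)   # contrast stretched to span the full [0, 1] range
```

Clamping leaves the faint pattern as faint as it was, while min-max normalization amplifies it to full contrast, which is why the two attackers are scored separately.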

Lossy-transform attackers (T applied to both sides, compared in the transformed space):

  • sobel_gray: Sobel gradient magnitude of luma (min-max normalized).
  • blur_gray: Gaussian low-pass of luma at blur_sigma pixels.
  • standardize_gray: per-image z-score + sigmoid of luma.
  • gamma_gray: mean MS-SSIM across gamma_values.
  • wavelet_gray: single-level Haar wavelet soft-threshold denoising.
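As an illustration of the wavelet-denoise attacker, the following is a one-dimensional sketch of single-level Haar soft-threshold denoising in plain Python. It is a simplification under stated assumptions: the module works on 2-D luma tensors, and the function names haar_denoise_1d and soft_threshold are hypothetical, not part of the API.

```python
import math


def soft_threshold(x, t):
    """Shrink x toward zero by t; any value with |x| <= t becomes zero."""
    return math.copysign(max(abs(x) - t, 0.0), x)


def haar_denoise_1d(signal, threshold):
    """Single-level Haar transform, soft-threshold the details, then invert.

    Assumes an even-length signal; a trailing unpaired sample is dropped.
    """
    s = math.sqrt(2.0)
    out = []
    for i in range(0, len(signal) - 1, 2):
        a, b = signal[i], signal[i + 1]
        approx = (a + b) / s                              # low-pass coefficient, kept as-is
        detail = soft_threshold((a - b) / s, threshold)   # high-pass coefficient, shrunk
        out.extend([(approx + detail) / s, (approx - detail) / s])
    return out


noisy = [0.50, 0.54, 0.20, 0.90]
denoised = haar_denoise_1d(noisy, threshold=0.1)
```

The small difference in the first pair falls below the threshold and is flattened to the pair's mean, while the large edge in the second pair survives with reduced amplitude. This is the kind of structure-preserving cleanup the wavelet_gray term is meant to anticipate.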

Linear-dependence terms (squared Pearson). Pearson is affine-invariant, so applying it to min-max / standardize / other affine transforms is redundant with applying it to the raw luma — those variants are intentionally omitted. The included transforms all break linearity:

  • pearson_rgb: flattened RGB, using clamp(noisy_rgb) vs clean_rgb.
  • pearson_gray: flattened luma, using clamp(noisy_gray) vs clean_gray.
  • pearson_sobel: Sobel magnitudes of both sides.
  • pearson_blur: Gaussian-blurred luma on both sides.
  • pearson_gamma: mean over gamma_values of squared Pearson after gamma correction on both sides.
  • pearson_wavelet: wavelet-denoised luma on both sides.

If noisy_01 contains NaN/Inf values, returns a dict of zeros (and emits a warning) rather than propagating invalid values through the backward pass. The root cause (e.g. an unbounded mean estimator output) should be addressed at the noise layer.
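The guard described above might look roughly like the following plain-Python sketch. It only illustrates the control flow (check for non-finite input, warn, return zeros); the function guarded_loss, the shortened COMPONENTS tuple, and the warning text are all hypothetical, and the real function works on tensors and returns all fifteen components plus the two aggregate keys.

```python
import math
import warnings

# Shortened, illustrative key list; the real function returns more keys.
COMPONENTS = ("rgb_clamp", "pearson_rgb")


def guarded_loss(noisy_values, compute):
    """Return all-zero components if any input value is NaN or Inf."""
    if not all(math.isfinite(v) for v in noisy_values):
        warnings.warn("non-finite values in noisy input; returning zeros")
        return {name: 0.0 for name in COMPONENTS}
    return compute(noisy_values)
```

Returning zeros keeps a single bad batch from injecting NaN gradients, at the cost of contributing no learning signal for that batch, which is why the source of the invalid values should still be fixed upstream.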

Privacy: this loss handles user-provided image data. Callers should ensure that the invocation context does not log or persist clean_01 or noisy_01.

Parameters:

  • noisy_01 (Tensor, required): Stained-Glass-transformed image tensor of shape [B, 3, H, W] with values in [0, 1]. Gradients flow through this argument.
  • clean_01 (Tensor, required): Clean image tensor of shape [B, 3, H, W] with values in [0, 1].
  • gamma_values (Sequence[float], default (0.5, 2.0)): Gamma values used for the gamma_gray and pearson_gamma attackers.
  • blur_sigma (float, default 2.0): Sigma of the Gaussian blur applied for the blur_gray and pearson_blur attackers.
  • wavelet_threshold (float, default 0.1): Soft-threshold magnitude applied to wavelet detail bands for the wavelet_gray and pearson_wavelet attackers.
  • compute_ms_ssim (bool, default True): If False, skip the nine MS-SSIM metric calls; the corresponding components and ms_ssim_total are returned as zero tensors. Useful when only the squared-Pearson family is being weighted in the composite loss.
  • compute_pearson (bool, default True): If False, skip the six squared-Pearson metric calls; the corresponding components and pearson_total are returned as zero tensors. Useful when only the MS-SSIM family is being weighted in the composite loss.

Returns:

dict[str, torch.Tensor]: A dict with one key per component listed above, plus two aggregate keys:

  • ms_ssim_total: unweighted mean of the nine structural-similarity components.
  • pearson_total: unweighted mean of the six linear-dependence components.

All values are scalar tensors in [0, 1].
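The relationship between the per-component keys and the two aggregate keys can be sketched as follows. Plain floats stand in for scalar tensors, the component values are made up, and the composite helper with its weights is a hypothetical example of how a caller might fold the totals into a training loss; none of this is the module's own code.

```python
from statistics import fmean

MS_SSIM_COMPONENT_NAMES = (
    "rgb_clamp", "rgb_minmax", "gray_clamp", "gray_minmax", "sobel_gray",
    "blur_gray", "standardize_gray", "gamma_gray", "wavelet_gray",
)
PEARSON_COMPONENT_NAMES = (
    "pearson_rgb", "pearson_gray", "pearson_sobel",
    "pearson_blur", "pearson_gamma", "pearson_wavelet",
)


def aggregate(components):
    """Unweighted per-family means, mirroring the two aggregate keys."""
    return {
        "ms_ssim_total": fmean(components[n] for n in MS_SSIM_COMPONENT_NAMES),
        "pearson_total": fmean(components[n] for n in PEARSON_COMPONENT_NAMES),
    }


def composite(totals, ms_ssim_weight=1.0, pearson_weight=1.0):
    """Hypothetical weighted sum a caller might add to a training loss."""
    return (ms_ssim_weight * totals["ms_ssim_total"]
            + pearson_weight * totals["pearson_total"])


# Made-up component values, for illustration only.
components = {name: 0.25 for name in MS_SSIM_COMPONENT_NAMES}
components.update({name: 0.5 for name in PEARSON_COMPONENT_NAMES})

totals = aggregate(components)
loss = composite(totals, ms_ssim_weight=1.0, pearson_weight=0.5)
```

Keeping the two families as separate totals lets the caller weight structural similarity and linear dependence independently, or disable one family entirely via compute_ms_ssim / compute_pearson.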

Raises:

  • ValueError: If gamma_values is empty, blur_sigma is not positive/finite, or wavelet_threshold is negative/non-finite.
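A minimal sketch of argument checks matching the conditions listed above. The function name validate_attack_params and the error messages are hypothetical; only the three conditions themselves come from this page.

```python
import math


def validate_attack_params(gamma_values, blur_sigma, wavelet_threshold):
    """Raise ValueError for the three invalid-argument cases listed above."""
    if len(gamma_values) == 0:
        raise ValueError("gamma_values must not be empty")
    if not (math.isfinite(blur_sigma) and blur_sigma > 0):
        raise ValueError("blur_sigma must be positive and finite")
    if not (math.isfinite(wavelet_threshold) and wavelet_threshold >= 0):
        raise ValueError("wavelet_threshold must be non-negative and finite")


# The documented defaults pass; a threshold of exactly zero is also accepted.
validate_attack_params((0.5, 2.0), 2.0, 0.1)
validate_attack_params((0.5, 2.0), 2.0, 0.0)
```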

Added in version v3.39.0. 15-component multi-attacker privacy loss.

pearson_sq_mean

pearson_sq_mean(
    pred: Tensor, target: Tensor, *, eps: float = 1e-12
) -> torch.Tensor

Mean across the batch of the squared Pearson correlation between pred and target.

Pearson correlation is affine-invariant (invariant to shift/scale of either argument), so this term captures residual linear dependence in directions that MS-SSIM's structural measure only partially penalizes. Squaring maps the result into [0, 1] with a smooth minimum at 0 (no linear dependence) and a smooth maximum at 1 (perfect linear dependence up to sign), keeping gradients finite across the entire range — unlike |pearson| which is non-differentiable at zero.

Both the mean-subtraction and the variance terms are kept attached to the graph (not detached), so the gradient with respect to pred reflects the true Pearson derivative.
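For intuition, here is a plain-Python sketch of the computation on flat lists. It follows the description on this page (centered inputs, variance product floored at eps before the square root, result squared, then averaged over the batch), but it is not the module's implementation: pearson_sq and pearson_sq_mean_sketch are hypothetical names, and using sums rather than means for the covariance and variances is an assumption that changes only how eps scales.

```python
import math


def pearson_sq(pred, target, eps=1e-12):
    """Squared Pearson correlation of two equal-length flat sequences."""
    n = len(pred)
    mean_p = sum(pred) / n
    mean_t = sum(target) / n
    dp = [p - mean_p for p in pred]
    dt = [t - mean_t for t in target]
    cov = sum(a * b for a, b in zip(dp, dt))
    var_p = sum(a * a for a in dp)
    var_t = sum(b * b for b in dt)
    # Floor the variance product before the sqrt, as described for eps.
    r = cov / math.sqrt(max(var_p * var_t, eps))
    return r * r


def pearson_sq_mean_sketch(pred_batch, target_batch, eps=1e-12):
    """Mean over the batch of the per-sample squared correlation."""
    scores = [pearson_sq(p, t, eps) for p, t in zip(pred_batch, target_batch)]
    return sum(scores) / len(scores)


clean = [0.1, 0.4, 0.35, 0.8]
scaled = [-2.0 * c + 3.0 for c in clean]   # affine copy with flipped sign
flat = [0.5, 0.5, 0.5, 0.5]                # constant: no linear dependence

batch_score = pearson_sq_mean_sketch([scaled, flat], [clean, clean])
```

The affine copy scores approximately 1 despite the sign flip and offset, the constant input scores 0 without dividing by zero thanks to the eps floor, and the batch mean lands halfway between the two.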

Parameters:

  • pred (Tensor, required): Tensor of shape [B, ...]. Gradients flow through this argument.
  • target (Tensor, required): Tensor of shape [B, ...] with the same non-batch numel as pred.
  • eps (float, default 1e-12): Floor applied to the product of variances before the sqrt in the Pearson denominator, so the sqrt's gradient stays finite for near-constant inputs.

Returns:

torch.Tensor: Scalar tensor in [0, 1]. Lower means less linear dependence between pred and target.

Added in version v3.39.0. Image-similarity privacy loss family.

safe_ms_ssim

safe_ms_ssim(
    pred: Tensor,
    target: Tensor,
    *,
    data_range: float = 1.0,
    normalize: Literal["relu", "simple"] | None = "relu"
) -> torch.Tensor

Compute functional MS-SSIM between pred and target, replacing NaN/Inf with zero.

Wraps torchmetrics.functional.image.multiscale_structural_similarity_index_measure. Returns 0 (and emits a warning) on NaN/Inf so a single bad batch does not poison training via a NaN gradient.

Parameters:

  • pred (Tensor, required): Predicted image tensor of shape [B, C, H, W] with values in [0, data_range].
  • target (Tensor, required): Target image tensor of shape [B, C, H, W] with values in [0, data_range].
  • data_range (float, default 1.0): The dynamic range of the inputs, forwarded to torchmetrics. The default of 1.0 assumes both arguments are pre-denormalized into [0, 1].
  • normalize (Literal['relu', 'simple'] | None, default 'relu'): Normalization mode for negative MS-SSIM contributions, forwarded to torchmetrics.

Returns:

torch.Tensor: Scalar MS-SSIM tensor with the same dtype/device as pred.

Added in version v3.39.0. Image-similarity privacy loss family.