TensorFlow Adapter
This notebook demonstrates how to use Stained Glass Core to create a Stained Glass Transform for a TensorFlow model using TensorFlowAdapter, a wrapper that exposes a TensorFlow model as a PyTorch module.
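At a high level, the adapter puts a Keras model behind the nn.Module interface so the rest of the PyTorch stack (Lightning, optimizers, torchmetrics) can treat it like any other submodule. The following is a conceptual sketch only, under a hypothetical ToyTensorFlowAdapter name; the real TensorFlowAdapter must also bridge gradients across the framework boundary (required for the training below), which this toy forward-only version does not:

import tensorflow as tf
import torch
from torch import nn


class ToyTensorFlowAdapter(nn.Module):
    """Conceptual sketch: call a Keras model from PyTorch (forward pass only)."""

    def __init__(self, tf_model: tf.keras.Model) -> None:
        super().__init__()
        self.tf_model = tf_model

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # NCHW (PyTorch) -> NHWC (TensorFlow)
        batch = input.detach().cpu().numpy().transpose(0, 2, 3, 1)
        logits = self.tf_model(tf.convert_to_tensor(batch), training=False)
        return torch.from_numpy(logits.numpy()).to(input.device)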
Installation
Finding versions of PyTorch and TensorFlow that were compiled against compatible CUDA versions is difficult. The following version combinations have been verified to be compatible:
torch~=2.6.0 and tensorflow[and-cuda]~=2.17.0:
pip install uv
uv pip install "tensorflow>=2.17.0,<2.18.0" ultralytics
uv pip install -e .[all,torch-2-6]
uv pip install "tensorflow[and-cuda]>=2.17.0,<2.18.0"
torch~=2.5.0 and tensorflow[and-cuda]~=2.16.0:
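By analogy with the commands above (the torch-2-5 extra below is an assumption; verify it against the project's declared extras before running):

pip install uv
uv pip install "tensorflow>=2.16.0,<2.17.0" ultralytics
uv pip install -e .[all,torch-2-5]
uv pip install "tensorflow[and-cuda]>=2.16.0,<2.17.0"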
In [1]:
from __future__ import annotations
import functools
import math
import os
import random
import re
import types
from typing import Literal
import datasets
import keras
import lightning
import lightning.pytorch.loggers
import tensorflow as tf
import torch
import torch.utils.tensorboard
import torchmetrics
import torchvision.transforms.v2
import ultralytics
import ultralytics.engine.results
from torch import nn
from typing_extensions import override
from stainedglass_core import (
    model as sg_model,
    noise_layer as sg_noise_layer,
    utils as sg_utils,
)
from stainedglass_core.integrations import tensorflow as sg_tensorflow
from stainedglass_core.loss import cloak as sg_cloak_loss
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
physical_devices = tf.config.list_physical_devices("GPU")
print(physical_devices)
print("Num GPUs Available: ", len(tf.config.list_physical_devices("GPU")))
tf.config.experimental.set_memory_growth(physical_devices[0], True)
/home/matthew/.conda/envs/core39/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
2025-03-14 12:58:49.115331: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-03-14 12:58:49.135288: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-03-14 12:58:49.157477: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-03-14 12:58:49.163992: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-14 12:58:49.180946: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 AVX512_FP16 AVX_VNNI AMX_TILE AMX_INT8 AMX_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-03-14 12:58:50.047310: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
Num GPUs Available:  1
In [2]:
from typing import TypedDict
class DataPoint(TypedDict):
    image: torch.Tensor
class BatchedDataPoint(TypedDict):
    image: list[torch.Tensor]
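These two shapes correspond to a single example versus the batched, list-per-column form that datasets.Dataset.map(..., batched=True) passes to its mapping function. A quick illustration:

single: DataPoint = {"image": torch.rand(3, 224, 224)}
batched: BatchedDataPoint = {
    "image": [torch.rand(3, 224, 224), torch.rand(3, 224, 224)]
}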
In [ ]:
import dataclasses
import inspect
from typing import Any
from typing_extensions import Self
@dataclasses.dataclass
class Hyperparameters:
    @classmethod
    def from_dict(cls, **kwargs: Any) -> Self:
        cls_parameters = inspect.signature(cls).parameters
        instance = cls(
            **{k: kwargs.pop(k) for k in list(kwargs) if k in cls_parameters}
        )
        for k, v in kwargs.items():
            setattr(instance, k, v)
        return instance
@dataclasses.dataclass
class TransformHyperparameters(Hyperparameters):
    percent_to_mask: float = 0.0
    scale: tuple[float, float] = (1e-4, 1.0)
    shallow: float = 1.0
    rhos_init: float = -4.0
    seed: int | None = None
@dataclasses.dataclass
class LossHyperparameters(Hyperparameters):
    reduction: str = "mean"
@dataclasses.dataclass
class TransformLossHyperparameters(Hyperparameters):
    alpha: float = 0.5
@dataclasses.dataclass
class OptimizerHyperparameters(Hyperparameters):
    lr: float = 3e-3
    fused: bool = True
@dataclasses.dataclass
class MetricHyperparameters(Hyperparameters): ...
@dataclasses.dataclass
class DatasetHyperparameters(Hyperparameters):
    path: str = "detection-datasets/coco"
    """The name of a dataset on https://huggingface.co/datasets or the path to a local dataset."""
    cache_dir: str | None = "/data_fast"
    """The directory to store the downloaded dataset in."""
    num_proc: int | None = 2
    """The number of workers to use when resizing the dataset."""
@dataclasses.dataclass
class DataLoaderHyperparameters(Hyperparameters):
    batch_size: int = 24
    num_workers: int | None = 0
@dataclasses.dataclass
class DetectionHyperparameters(Hyperparameters):
    num_crops: int = 4
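Hyperparameters.from_dict pops the keyword arguments that match the dataclass's constructor parameters and attaches any leftovers as plain attributes, which lets one flat configuration dictionary feed several hyperparameter groups. For example:

transform = TransformHyperparameters.from_dict(
    percent_to_mask=0.25,
    seed=42,
    notes="not a dataclass field",  # leftover key becomes a plain attribute
)
assert transform.percent_to_mask == 0.25
assert transform.notes == "not a dataclass field"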
In [4]:
from typing_extensions import TypedDict
from stainedglass_core.utils.torch import nn as sg_nn
class Metrics(TypedDict):
    """Models the metrics collected in the different modes of `StainedGlassTensorflowVisionModule`."""
    train: TrainingStepMetrics
    valid: TrainingStepMetrics
class TrainingStepMetrics(nn.Module):
    """Common metrics to be collected during `_step` in both training and validation."""
    def __init__(self) -> None:
        super().__init__()
        self.mean_losses = sg_nn.ModuleDefaultDict(torchmetrics.MeanMetric)
        self.accuracy = torchmetrics.Accuracy("multiclass", num_classes=1000)
        self.precision = torchmetrics.Precision("multiclass", num_classes=1000)
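In use, the step code feeds per-batch statistics into these metrics and reads epoch aggregates back out. A small sketch, assuming sg_nn.ModuleDefaultDict behaves like collections.defaultdict (constructing a fresh MeanMetric on first access to an unseen key):

metrics = TrainingStepMetrics()
preds = torch.randint(0, 1000, (8,))
labels = torch.randint(0, 1000, (8,))
batch_accuracy = metrics.accuracy(preds, labels)  # per-batch value
metrics.mean_losses["total_loss"].update(torch.tensor(0.73))
epoch_mean_loss = metrics.mean_losses["total_loss"].compute()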
In [ ]:
class HuggingFaceDataModule(lightning.LightningDataModule):
    def __init__(
        self,
        image_size: tuple[int, int],
        dataset: DatasetHyperparameters,
        train_dataloader: DataLoaderHyperparameters,
        test_dataloader: DataLoaderHyperparameters,
    ):
        super().__init__()
        self.image_size = image_size
        self.dataset_hyperparameters = dataset
        self.train_dataloader_hyperparameters = train_dataloader
        self.test_dataloader_hyperparameters = test_dataloader
    def _prepare_data(self) -> datasets.DatasetDict:
        """Download and prepare the dataset splits."""
        resize_transform = torchvision.transforms.v2.Compose(
            [
                torchvision.transforms.v2.Resize(self.image_size),
                torchvision.transforms.v2.ToDtype(torch.float32, scale=True),
                torchvision.transforms.v2.RGB(),
            ]
        )
        def resize_map(
            sample: DataPoint | BatchedDataPoint,
        ) -> DataPoint | BatchedDataPoint:
            if isinstance(sample["image"], list):
                return {
                    "image": [
                        resize_transform(image) for image in sample["image"]
                    ]
                }
            return {"image": resize_transform(sample["image"])}
        if os.path.isdir(path := self.dataset_hyperparameters.path) or (
            "__file__" in globals()
            and os.path.isdir(
                path := os.path.join(os.path.dirname(__file__), path)
            )
        ):
            dataset = datasets.load_from_disk(path)
        else:
            dataset = datasets.load_dataset(
                self.dataset_hyperparameters.path,
                cache_dir=self.dataset_hyperparameters.cache_dir,
            )
        assert isinstance(dataset, datasets.DatasetDict)
        def _get_cache_dir(dataset: datasets.Dataset) -> str:
            if not dataset.cache_files:
                raise ValueError("The loaded dataset has no cache files.")
            return os.path.dirname(dataset.cache_files[-1]["filename"])
        def _get_cache_file_names() -> dict[str, str | None]:
            return {
                split_name: os.path.join(
                    _get_cache_dir(dataset_split), "cache", f"{split_name}.map"
                )
                for split_name, dataset_split in dataset.items()
            }
        return (
            dataset.remove_columns(["width", "height", "objects"])
            .with_format("torch")
            .map(
                resize_map,
                batched=False,
                load_from_cache_file=True,
                cache_file_names=_get_cache_file_names(),
                num_proc=self.dataset_hyperparameters.num_proc,
            )
        )
    @override
    def prepare_data(self) -> None:
        """Download and prepare the dataset splits.
        This function is only called on the local rank 0 process, so we won't attempt to download or map the dataset multiple times.
        """
        _ = self._prepare_data()
    @override
    def setup(self, stage: str) -> None:
        """Set up the datasets.
        Args:
            stage: Whether we are fitting (training), testing or predicting.
        """
        dataset = self._prepare_data()
        self.train_dataset = dataset["train"]
        self.val_dataset = dataset["val"]
    @override
    def train_dataloader(self) -> torch.utils.data.DataLoader[DataPoint]:
        """Create the training dataloader.
        Returns:
            The training dataloader.
        """
        return torch.utils.data.DataLoader(
            self.train_dataset,  # pyright: ignore[reportArgumentType]
            **dataclasses.asdict(self.train_dataloader_hyperparameters),
            shuffle=True,
            pin_memory=True,
        )
    @override
    def val_dataloader(self) -> torch.utils.data.DataLoader[DataPoint]:
        """Create the training dataloader.
        Returns:
            The training dataloader.
        """
        return torch.utils.data.DataLoader(
            self.val_dataset,  # pyright: ignore[reportArgumentType]
            **dataclasses.asdict(self.test_dataloader_hyperparameters),
            shuffle=False,
            pin_memory=True,
        )
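Outside of a Trainer, the data module can be exercised directly; the sizes below are illustrative (the default hyperparameters point at the full COCO dataset, so substitute a small local dataset path for a quick test):

data_module = HuggingFaceDataModule(
    image_size=(224, 224),
    dataset=DatasetHyperparameters(),
    train_dataloader=DataLoaderHyperparameters(batch_size=8),
    test_dataloader=DataLoaderHyperparameters(batch_size=8),
)
data_module.prepare_data()
data_module.setup(stage="fit")
batch = next(iter(data_module.train_dataloader()))
print(batch["image"].shape)  # batched, resized image tensor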
In [ ]:
class StainedGlassTensorflowVisionModule(lightning.LightningModule):
    def __init__(
        self,
        transform: TransformHyperparameters,
        detection: DetectionHyperparameters,
        loss: LossHyperparameters,
        transform_loss: TransformLossHyperparameters,
        optimizer: OptimizerHyperparameters,
        metric: MetricHyperparameters,
    ) -> None:
        super().__init__()
        self.transform_hyperparameters = transform
        self.detection_hyperparameters = detection
        self.loss_hyperparameters = loss
        self.transform_loss_hyperparameters = transform_loss
        self.optimizer_hyperparameters = optimizer
        self.metric_hyperparameters = metric
    @functools.cached_property
    def image_size(self) -> tuple[int, int]:
        _, height, width, color_channels = (
            self.noisy_model.base_model.tf_model.input_shape
        )
        assert (
            isinstance(height, int)
            and isinstance(width, int)
            and isinstance(color_channels, int)
        )
        assert color_channels in (1, 3)
        return height, width
    @property
    def tb_writer(self) -> torch.utils.tensorboard.SummaryWriter | None:  # pyright: ignore[reportPrivateImportUsage]
        assert self.loggers is not None
        for logger in self.loggers:
            if isinstance(logger, lightning.pytorch.loggers.TensorBoardLogger):
                return logger.experiment
        return None
    @override
    def configure_model(self) -> None:
        """Configure the model assuming a non-ddp parallelism."""
        self.detection_model = ultralytics.YOLO(task="detect")
        # patch ultralytics.engine.model.Model.train with nn.Module.train
        self.detection_model.train = functools.partial(  # pyright: ignore[reportAttributeAccessIssue]
            torch.nn.Module.train,
            self.detection_model,
        )
        with tf.device(
            "/device:GPU:0" if torch.cuda.is_available() else "/CPU:0"
        ):
            base_model_tf = keras.applications.EfficientNetB0(
                include_top=True,
                weights="imagenet",
                input_tensor=None,
                input_shape=None,
                pooling=None,
                classes=1000,
                classifier_activation=None,  # pyright: ignore[reportArgumentType]
            )
        self.classifier = sg_tensorflow.VisionTensorFlowAdapter(base_model_tf)
        self.noisy_model = sg_model.NoisyModel(
            sg_noise_layer.CloakNoiseLayerOneShot,
            self.classifier,
            target_parameter="input",
            **dataclasses.asdict(self.transform_hyperparameters),
        )
        self.task_loss = nn.CrossEntropyLoss(
            **dataclasses.asdict(self.loss_hyperparameters),
        )
        self.loss_function, self.get_loss_components, _ = (
            sg_cloak_loss.composite_cloak_loss_factory(
                self.noisy_model,
                loss_function=self.task_loss,
                **dataclasses.asdict(self.transform_loss_hyperparameters),
            )
        )
        self.task_loss_name = (
            self.task_loss.__name__
            if isinstance(self.task_loss, types.FunctionType)
            else re.sub(
                r"(?<!^)(?=[A-Z])", "_", type(self.task_loss).__name__
            ).lower()
        )
        self.resizer = torchvision.transforms.v2.Resize(self.image_size)
        self.configure_metrics()
    def configure_metrics(self) -> None:
        """Register the metrics used during training and validation."""
        with (
            self.device,
            sg_utils.torch.dtypes.default_dtype(torch.float32),
        ):
            self._train_metrics = TrainingStepMetrics()
            self._val_metrics = TrainingStepMetrics()
        # we cannot use ModuleDict here because the 'train' key conflicts with the 'Module.train' method; a plain TypedDict avoids that collision
        self.metrics: Metrics = {
            "train": self._train_metrics,
            "valid": self._val_metrics,
        }
    @override
    def configure_optimizers(self) -> torch.optim.Optimizer:
        return torch.optim.AdamW(
            self.noisy_model.noise_layer.parameters(),
            **dataclasses.asdict(self.optimizer_hyperparameters),
            weight_decay=0.0,  # nonzero weight_decay for SGT parameters hinders training
        )
    @override
    def training_step(self, batch: DataPoint, batch_idx: int) -> torch.Tensor:
        return self._step(batch, batch_idx, "train")
    @override
    def validation_step(self, batch: DataPoint, batch_idx: int) -> torch.Tensor:
        return self._step(batch, batch_idx, "valid")
    @override
    def on_train_epoch_end(self) -> None:
        return self._on_epoch_end("train")
    @override
    def on_validation_epoch_end(self) -> None:
        return self._on_epoch_end("valid")
    def _crop_batch_to_yolo_bounding_boxes(
        self,
        yolo_results: list[ultralytics.engine.results.Results],
        image_batch: torch.Tensor,
    ) -> torch.Tensor:
        """Crop a batch of images according to the yolo detection results, returning a batch of images of the original size."""
        assert len(yolo_results) == len(image_batch)
        cropped_images: list[torch.Tensor] = []
        for yolo_result, full_image in zip(yolo_results, image_batch):
            if yolo_result.boxes is None or len(yolo_result.boxes.cls) == 0:
                cropped_images.append(full_image)
                continue
            # filter out the "person" class because there are human categories in ImageNet
            non_person_boxes = [
                box
                for cls, box in zip(
                    yolo_result.boxes.cls, yolo_result.boxes.xywh
                )
                if "person" != yolo_result.names[cls.item()]
            ]
            if not non_person_boxes:
                cropped_images.append(full_image)
                continue
            for x_center, y_center, width, height in random.choices(
                non_person_boxes, k=self.detection_hyperparameters.num_crops
            ):
                cropped_image = torchvision.transforms.v2.functional.crop(
                    full_image,
                    top=int((y_center - height / 2).item()),
                    left=int((x_center - width / 2).item()),
                    height=int(height.item()),
                    width=int(width.item()),
                )
                cropped_image = self.resizer(cropped_image)
                cropped_images.append(cropped_image)
        return torch.stack(cropped_images)
    def _step(
        self, batch: DataPoint, batch_idx: int, mode: Literal["train", "valid"]
    ) -> torch.Tensor:
        raw_images = batch["image"]
        with torch.no_grad():
            detection_results = self.detection_model.predict(
                raw_images, verbose=False
            )
            cropped_raw_images = self._crop_batch_to_yolo_bounding_boxes(
                detection_results, raw_images
            )
            raw_logits = self.classifier(cropped_raw_images)
            # use the model's maximum likelihood logit as the label
            labels = raw_logits.argmax(dim=1)
        transformed_images = self.noisy_model.noise_layer(raw_images)
        cropped_transformed_images = self._crop_batch_to_yolo_bounding_boxes(
            detection_results, transformed_images
        )
        transformed_logits = self.classifier(cropped_transformed_images)
        preds = transformed_logits.argmax(dim=1)
        loss = self.loss_function(transformed_logits, target=labels)
        num_batches = (
            self.trainer.num_training_batches
            if self.trainer.num_training_batches != math.inf
            else 0
        )
        global_step = int(self.current_epoch * num_batches + batch_idx)
        metrics = {
            "accuracy": self.metrics[mode].accuracy(preds, labels),
            "precision": self.metrics[mode].precision(preds, labels),
            **{
                self.task_loss_name if name == "task_loss" else name: value
                for name, value in self.get_loss_components().items()
            },
        }
        self.log_dict(metrics, prog_bar=True, logger=False)
        if self.logger is not None:
            self.logger.log_metrics(
                {
                    f"{name}/{mode}/batch": value
                    for name, value in metrics.items()
                },
                step=global_step,
            )
        def side_by_side_grid(
            raw: torch.Tensor, transformed: torch.Tensor
        ) -> torch.Tensor:
            side_by_side = torch.stack([raw, transformed], dim=1).reshape(
                -1, *raw.shape[1:]
            )
            return torchvision.utils.make_grid(side_by_side)
        side_by_side_crops_grid = side_by_side_grid(
            cropped_raw_images, cropped_transformed_images
        )
        side_by_side_full_grid = side_by_side_grid(
            raw_images, transformed_images
        )
        if self.tb_writer is not None:
            self.tb_writer.add_image(
                f"side_by_side_cropped/{mode}/batch",
                side_by_side_crops_grid,
                global_step=global_step,
            )
            self.tb_writer.add_image(
                f"side_by_side_full/{mode}/batch",
                side_by_side_full_grid,
                global_step=global_step,
            )
            self.tb_writer.add_histogram(
                f"std_histogram/{mode}/batch",
                self.noisy_model.noise_layer.std_estimator.module.weight.detach(),
                global_step=global_step,
                bins=512,  # pyright: ignore[reportArgumentType]
            )
            self.tb_writer.add_histogram(
                f"mean_histogram/{mode}/batch",
                self.noisy_model.noise_layer.mean_estimator.module.weight.detach(),
                global_step=global_step,
                bins=512,  # pyright: ignore[reportArgumentType]
            )
        return loss
    def _on_epoch_end(self, mode: Literal["train", "valid"]) -> None:
        """Compute epoch-level training metrics."""
        metrics: dict[str, float] = {
            "epoch": self.current_epoch,
        }
        for loss_name, metric in self.metrics[mode].mean_losses.items():
            metrics[loss_name] = metric.compute().item()
            metric.reset()
        if self.logger is not None and self.trainer.is_global_zero:
            self.logger.log_metrics(
                {
                    f"{name}/{mode}/epoch": value
                    for name, value in metrics.items()
                },
                step=self.current_epoch,
            )
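Once trained, the noise layer alone is enough to apply the learned Stained Glass Transform to new images; a minimal sketch using only the names defined above:

module = StainedGlassTensorflowVisionModule(
    transform=TransformHyperparameters(),
    detection=DetectionHyperparameters(),
    loss=LossHyperparameters(),
    transform_loss=TransformLossHyperparameters(),
    optimizer=OptimizerHyperparameters(),
    metric=MetricHyperparameters(),
)
module.configure_model()
images = torch.rand(2, 3, *module.image_size)
with torch.no_grad():
    transformed_images = module.noisy_model.noise_layer(images)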
In [ ]:
torch.set_float32_matmul_precision("high")
logger = lightning.pytorch.loggers.TensorBoardLogger(
    "tensorflow_efficientnet",
    name="tensorflow_efficientnet_fullcoco",
)
trainer = lightning.Trainer(
    logger=logger,
    devices=[0],
    max_epochs=50,
    log_every_n_steps=30,
    fast_dev_run=True,  # runs a single batch as a smoke test (enabled here for CI); set to False for real training
)
with trainer.init_module():
    training_module = StainedGlassTensorflowVisionModule(
        transform=TransformHyperparameters(),
        detection=DetectionHyperparameters(),
        loss=LossHyperparameters(),
        transform_loss=TransformLossHyperparameters(),
        optimizer=OptimizerHyperparameters(),
        metric=MetricHyperparameters(),
    )
    # we must instantiate the model before we can know what size to reshape the images in the dataset to
    training_module.configure_model()
data_module = HuggingFaceDataModule(
    training_module.image_size,
    dataset=DatasetHyperparameters(path="mini_coco", num_proc=None),
    train_dataloader=DataLoaderHyperparameters(),
    test_dataloader=DataLoaderHyperparameters(),
)
trainer.fit(model=training_module, datamodule=data_module)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.
2025-03-14 12:58:56.178723: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 70576 MB memory: -> device: 0, name: NVIDIA H100 80GB HBM3, pci bus id: 0000:18:00.0, compute capability: 9.0
Map (num_proc=2): 0%| | 0/64 [00:00<?, ? examples/s]
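The per-batch metrics, side-by-side image grids, and noise-parameter histograms logged above can be browsed by pointing TensorBoard at the logger's root directory:

tensorboard --logdir tensorflow_efficientnet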