TensorFlow Adapter¶
This notebook demonstrates how to use Stained Glass Core to create a Stained Glass Transform for a TensorFlow model using TensorFlowAdapter, a PyTorch-based TensorFlow model wrapper.
Installation¶
Finding versions of PyTorch and TensorFlow that were compiled with compatible CUDA versions is difficult. These versions have been verified to be compatible:
torch~=2.6.0 and tensorflow[and-cuda]~=2.17.0:
pip install uv
uv pip install "tensorflow>=2.17.0,<2.18.0" ultralytics
uv pip install -e .[all,torch-2-6]
uv pip install "tensorflow[and-cuda]>=2.17.0,<2.18.0"
torch~=2.5.0 and tensorflow[and-cuda]~=2.16.0:
In [1]:
Copied!
from __future__ import annotations
import functools
import math
import os
import random
import re
import types
from typing import Literal
import datasets
import keras
import lightning
import lightning.pytorch.loggers
import tensorflow as tf
import torch
import torch.utils.tensorboard
import torchmetrics
import torchvision.transforms.v2
import ultralytics
import ultralytics.engine.results
from torch import nn
from typing_extensions import override
from stainedglass_core import (
model as sg_model,
noise_layer as sg_noise_layer,
utils as sg_utils,
)
from stainedglass_core.integrations import tensorflow as sg_tensorflow
from stainedglass_core.loss import cloak as sg_cloak_loss
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
physical_devices = tf.config.list_physical_devices("GPU")
print(physical_devices)
print("Num GPUs Available: ", len(tf.config.list_physical_devices("GPU")))
tf.config.experimental.set_memory_growth(physical_devices[0], True)
from __future__ import annotations
import functools
import math
import os
import random
import re
import types
from typing import Literal
import datasets
import keras
import lightning
import lightning.pytorch.loggers
import tensorflow as tf
import torch
import torch.utils.tensorboard
import torchmetrics
import torchvision.transforms.v2
import ultralytics
import ultralytics.engine.results
from torch import nn
from typing_extensions import override
from stainedglass_core import (
model as sg_model,
noise_layer as sg_noise_layer,
utils as sg_utils,
)
from stainedglass_core.integrations import tensorflow as sg_tensorflow
from stainedglass_core.loss import cloak as sg_cloak_loss
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
physical_devices = tf.config.list_physical_devices("GPU")
print(physical_devices)
print("Num GPUs Available: ", len(tf.config.list_physical_devices("GPU")))
tf.config.experimental.set_memory_growth(physical_devices[0], True)
/home/matthew/.conda/envs/core39/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html from .autonotebook import tqdm as notebook_tqdm 2025-03-14 12:58:49.115331: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`. 2025-03-14 12:58:49.135288: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered 2025-03-14 12:58:49.157477: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered 2025-03-14 12:58:49.163992: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered 2025-03-14 12:58:49.180946: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 AVX512_FP16 AVX_VNNI AMX_TILE AMX_INT8 AMX_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags. 2025-03-14 12:58:50.047310: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')] Num GPUs Available: 1
In [2]:
Copied!
from typing import TypedDict


class DataPoint(TypedDict):
    """A single dataset sample: one image tensor."""

    image: torch.Tensor


class BatchedDataPoint(TypedDict):
    """A batch of dataset samples: a list of image tensors (images may have
    different sizes before resizing, so they are not stacked)."""

    image: list[torch.Tensor]
In [ ]:
Copied!
import dataclasses
import inspect
from typing import Any

from typing_extensions import Self


@dataclasses.dataclass
class Hyperparameters:
    """Base class for hyperparameter groups, constructible from a flat dict."""

    @classmethod
    def from_dict(cls, **kwargs: Any) -> Self:
        """Build an instance from ``kwargs``.

        Keys matching the dataclass fields are consumed by the constructor;
        any remaining keys are attached to the instance as extra attributes
        instead of raising, so hyperparameter dicts can carry additional
        experiment metadata.
        """
        cls_parameters = inspect.signature(cls).parameters
        instance = cls(
            **{k: kwargs.pop(k) for k in list(kwargs) if k in cls_parameters}
        )
        for k, v in kwargs.items():
            setattr(instance, k, v)
        return instance


@dataclasses.dataclass
class TransformHyperparameters(Hyperparameters):
    """Hyperparameters of the Stained Glass Transform noise layer."""

    percent_to_mask: float = 0.0
    scale: tuple[float, float] = (1e-4, 1.0)
    shallow: float = 1.0
    rhos_init: float = -4.0
    seed: int | None = None


@dataclasses.dataclass
class LossHyperparameters(Hyperparameters):
    """Hyperparameters of the task loss (``nn.CrossEntropyLoss``)."""

    reduction: str = "mean"


@dataclasses.dataclass
class TransformLossHyperparameters(Hyperparameters):
    """Hyperparameters of the composite cloak loss."""

    alpha: float = 0.5


@dataclasses.dataclass
class OptimizerHyperparameters(Hyperparameters):
    """Hyperparameters of the AdamW optimizer."""

    lr: float = 3e-3
    fused: bool = True


@dataclasses.dataclass
class MetricHyperparameters(Hyperparameters): ...


@dataclasses.dataclass
class DatasetHyperparameters(Hyperparameters):
    path: str = "detection-datasets/coco"
    """The name of a dataset on https://huggingface.co/datasets or the path to a local dataset."""

    cache_dir: str | None = "/data_fast"
    """The directory to store the downloaded dataset in."""

    num_proc: int | None = 2
    """The number of workers to use when resizing the dataset."""


@dataclasses.dataclass
class DataLoaderHyperparameters(Hyperparameters):
    """Hyperparameters forwarded to ``torch.utils.data.DataLoader``."""

    batch_size: int = 24
    num_workers: int | None = 0


@dataclasses.dataclass
class DetectionHyperparameters(Hyperparameters):
    """Hyperparameters of the YOLO-detection-based cropping."""

    num_crops: int = 4
In [4]:
Copied!
from typing_extensions import TypedDict

from stainedglass_core.utils.torch import nn as sg_nn


class Metrics(TypedDict):
    """Models the metrics collected in the different modes of `StainedGlassTensorflowVisionModule`."""

    train: TrainingStepMetrics
    valid: TrainingStepMetrics


class TrainingStepMetrics(nn.Module):
    """Common metrics to be collected during `_step` in both training and validation."""

    def __init__(self) -> None:
        super().__init__()
        # Lazily creates one MeanMetric per loss-component name on first access.
        self.mean_losses = sg_nn.ModuleDefaultDict(torchmetrics.MeanMetric)
        # 1000 classes: the ImageNet label space of the EfficientNetB0 classifier.
        self.accuracy = torchmetrics.Accuracy("multiclass", num_classes=1000)
        self.precision = torchmetrics.Precision("multiclass", num_classes=1000)
In [ ]:
Copied!
class HuggingFaceDataModule(lightning.LightningDataModule):
    """Lightning data module that loads a HuggingFace image dataset and
    resizes every image to the size expected by the classifier.
    """

    def __init__(
        self,
        image_size: tuple[int, int],
        dataset: DatasetHyperparameters,
        train_dataloader: DataLoaderHyperparameters,
        test_dataloader: DataLoaderHyperparameters,
    ):
        super().__init__()
        self.image_size = image_size
        self.dataset_hyperparameters = dataset
        self.train_dataloader_hyperparameters = train_dataloader
        self.test_dataloader_hyperparameters = test_dataloader

    def _prepare_data(self) -> datasets.DatasetDict:
        """Download and prepare the dataset splits."""
        resize_transform = torchvision.transforms.v2.Compose(
            [
                torchvision.transforms.v2.Resize(self.image_size),
                torchvision.transforms.v2.ToDtype(torch.float32, scale=True),
                torchvision.transforms.v2.RGB(),
            ]
        )

        def resize_map(
            sample: DataPoint | BatchedDataPoint,
        ) -> DataPoint | BatchedDataPoint:
            # Handles both single samples and batched samples.
            if isinstance(sample["image"], list):
                return {
                    "image": [
                        resize_transform(image) for image in sample["image"]
                    ]
                }
            return {"image": resize_transform(sample["image"])}

        # Prefer a local dataset directory (absolute, or relative to this
        # file) before falling back to downloading from the HuggingFace Hub.
        if os.path.isdir(path := self.dataset_hyperparameters.path) or (
            "__file__" in globals()
            and os.path.isdir(
                path := os.path.join(os.path.dirname(__file__), path)
            )
        ):
            dataset = datasets.load_from_disk(path)
        else:
            dataset = datasets.load_dataset(
                self.dataset_hyperparameters.path,
                cache_dir=self.dataset_hyperparameters.cache_dir,
            )
        assert isinstance(dataset, datasets.DatasetDict)

        def _get_cache_dir(dataset: datasets.Dataset) -> str:
            # The last cache file's directory is where new cache files belong.
            if not dataset.cache_files:
                raise ValueError("The loaded dataset has no cache files.")
            return os.path.dirname(dataset.cache_files[-1]["filename"])

        def _get_cache_file_names() -> dict[str, str | None]:
            return {
                split_name: os.path.join(
                    _get_cache_dir(dataset_split), "cache", f"{split_name}.map"
                )
                for split_name, dataset_split in dataset.items()
            }

        return (
            dataset.remove_columns(["width", "height", "objects"])
            .with_format("torch")
            .map(
                resize_map,
                batched=False,
                load_from_cache_file=True,
                cache_file_names=_get_cache_file_names(),
                num_proc=self.dataset_hyperparameters.num_proc,
            )
        )

    @override
    def prepare_data(self) -> None:
        """Download and prepare the dataset splits.

        This function is only called on the local rank 0 process, so we won't
        attempt to download or map the dataset multiple times here.
        """
        _ = self._prepare_data()

    @override
    def setup(self, stage: str) -> None:
        """Set up the datasets.

        Args:
            stage: Whether we are fitting (training), testing or predicting.
        """
        dataset = self._prepare_data()
        self.train_dataset = dataset["train"]
        self.val_dataset = dataset["val"]

    @override
    def train_dataloader(self) -> torch.utils.data.DataLoader[DataPoint]:
        """Create the training dataloader.

        Returns:
            The training dataloader.
        """
        return torch.utils.data.DataLoader(
            self.train_dataset,  # pyright: ignore[reportArgumentType]
            **dataclasses.asdict(self.train_dataloader_hyperparameters),
            shuffle=True,
            pin_memory=True,
        )

    @override
    def val_dataloader(self) -> torch.utils.data.DataLoader[DataPoint]:
        """Create the validation dataloader.

        Returns:
            The validation dataloader.
        """
        return torch.utils.data.DataLoader(
            self.val_dataset,  # pyright: ignore[reportArgumentType]
            **dataclasses.asdict(self.test_dataloader_hyperparameters),
            shuffle=False,
            pin_memory=True,
        )
In [ ]:
Copied!
class StainedGlassTensorflowVisionModule(lightning.LightningModule):
    """Trains a Stained Glass Transform for a TensorFlow EfficientNetB0
    classifier wrapped in a PyTorch adapter, using YOLO detections to crop
    training images to objects of interest.
    """

    def __init__(
        self,
        transform: TransformHyperparameters,
        detection: DetectionHyperparameters,
        loss: LossHyperparameters,
        transform_loss: TransformLossHyperparameters,
        optimizer: OptimizerHyperparameters,
        metric: MetricHyperparameters,
    ) -> None:
        super().__init__()
        self.transform_hyperparameters = transform
        self.detection_hyperparameters = detection
        self.loss_hyperparameters = loss
        self.transform_loss_hyperparameters = transform_loss
        self.optimizer_hyperparameters = optimizer
        self.metric_hyperparameters = metric

    @functools.cached_property
    def image_size(self) -> tuple[int, int]:
        """The (height, width) input size expected by the TensorFlow model."""
        # Keras vision models use NHWC input shapes.
        _, height, width, color_channels = (
            self.noisy_model.base_model.tf_model.input_shape
        )
        assert (
            isinstance(height, int)
            and isinstance(width, int)
            and isinstance(color_channels, int)
        )
        assert color_channels in (1, 3)
        return height, width

    @property
    def tb_writer(self) -> torch.utils.tensorboard.SummaryWriter | None:  # pyright: ignore[reportPrivateImportUsage]
        """The TensorBoard summary writer, if a TensorBoard logger is attached."""
        assert self.loggers is not None
        for logger in self.loggers:
            if isinstance(logger, lightning.pytorch.loggers.TensorBoardLogger):
                return logger.experiment
        return None

    @override
    def configure_model(self) -> None:
        """Configure the model assuming a non-ddp parallelism."""
        self.detection_model = ultralytics.YOLO(task="detect")
        # patch ultralytics.engine.model.Model.train with nn.Module.train
        # (YOLO overrides .train() to mean "run a training job").
        self.detection_model.train = functools.partial(  # pyright: ignore[reportAttributeAccessIssue]
            torch.nn.Module.train,
            self.detection_model,
        )
        with tf.device(
            "/device:GPU:0" if torch.cuda.is_available() else "/CPU:0"
        ):
            base_model_tf = keras.applications.EfficientNetB0(
                include_top=True,
                weights="imagenet",
                input_tensor=None,
                input_shape=None,
                pooling=None,
                classes=1000,
                classifier_activation=None,  # pyright: ignore[reportArgumentType]
            )
        self.classifier = sg_tensorflow.VisionTensorFlowAdapter(base_model_tf)
        self.noisy_model = sg_model.NoisyModel(
            sg_noise_layer.CloakNoiseLayerOneShot,
            self.classifier,
            target_parameter="input",
            **dataclasses.asdict(self.transform_hyperparameters),
        )
        self.task_loss = nn.CrossEntropyLoss(
            **dataclasses.asdict(self.loss_hyperparameters),
        )
        self.loss_function, self.get_loss_components, _ = (
            sg_cloak_loss.composite_cloak_loss_factory(
                self.noisy_model,
                loss_function=self.task_loss,
                **dataclasses.asdict(self.transform_loss_hyperparameters),
            )
        )
        # e.g. CrossEntropyLoss -> "cross_entropy_loss" for metric names.
        self.task_loss_name = (
            self.task_loss.__name__
            if isinstance(self.task_loss, types.FunctionType)
            else re.sub(
                r"(?<!^)(?=[A-Z])", "_", type(self.task_loss).__name__
            ).lower()
        )
        self.resizer = torchvision.transforms.v2.Resize(self.image_size)
        self.configure_metrics()

    def configure_metrics(self) -> None:
        """Register the metrics used during training and validation."""
        with (
            self.device,
            sg_utils.torch.dtypes.default_dtype(torch.float32),
        ):
            self._train_metrics = TrainingStepMetrics()
            self._val_metrics = TrainingStepMetrics()
        # we cannot use ModuleDict here because the 'train' key conflicts
        # with the 'Module.train' method, also TypedDict is based
        self.metrics: Metrics = {
            "train": self._train_metrics,
            "valid": self._val_metrics,
        }

    @override
    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Create the optimizer over only the noise-layer (SGT) parameters."""
        return torch.optim.AdamW(
            self.noisy_model.noise_layer.parameters(),
            **dataclasses.asdict(self.optimizer_hyperparameters),
            weight_decay=0.0,  # nonzero weight_decay for SGT parameters hinders training
        )

    @override
    def training_step(self, batch: DataPoint, batch_idx: int) -> torch.Tensor:
        return self._step(batch, batch_idx, "train")

    @override
    def validation_step(self, batch: DataPoint, batch_idx: int) -> torch.Tensor:
        return self._step(batch, batch_idx, "valid")

    @override
    def on_train_epoch_end(self) -> None:
        return self._on_epoch_end("train")

    @override
    def on_validation_epoch_end(self) -> None:
        return self._on_epoch_end("valid")

    def _crop_batch_to_yolo_bounding_boxes(
        self,
        yolo_results: list[ultralytics.engine.results.Results],
        image_batch: torch.Tensor,
    ) -> torch.Tensor:
        """Crop a batch of images according to the yolo detection results,
        returning a batch of images of the original size.

        Images with no usable detections pass through uncropped; images with
        detections contribute ``num_crops`` randomly chosen box crops each, so
        the output batch size generally differs from the input batch size.
        """
        assert len(yolo_results) == len(image_batch)
        cropped_images: list[torch.Tensor] = []
        for yolo_result, full_image in zip(yolo_results, image_batch):
            if yolo_result.boxes is None or len(yolo_result.boxes.cls) == 0:
                cropped_images.append(full_image)
                continue
            # filter out the "person" class because there are human categories in ImageNet
            non_person_boxes = [
                box
                for cls, box in zip(
                    yolo_result.boxes.cls, yolo_result.boxes.xywh
                )
                if "person" != yolo_result.names[cls.item()]
            ]
            if not non_person_boxes:
                cropped_images.append(full_image)
                continue
            for x_center, y_center, width, height in random.choices(
                non_person_boxes, k=self.detection_hyperparameters.num_crops
            ):
                # xywh boxes are center-based; convert to top-left for crop.
                cropped_image = torchvision.transforms.v2.functional.crop(
                    full_image,
                    top=int((y_center - height / 2).item()),
                    left=int((x_center - width / 2).item()),
                    height=int(height.item()),
                    width=int(width.item()),
                )
                cropped_image = self.resizer(cropped_image)
                cropped_images.append(cropped_image)
        return torch.stack(cropped_images)

    def _step(
        self, batch: DataPoint, batch_idx: int, mode: Literal["train", "valid"]
    ) -> torch.Tensor:
        """Run one training/validation step and return the composite loss.

        The classifier's own argmax on the raw images serves as the label, so
        the transform is trained to preserve the classifier's predictions.
        """
        raw_images = batch["image"]
        # The label computation must not contribute gradients.
        with torch.no_grad():
            detection_results = self.detection_model.predict(
                raw_images, verbose=False
            )
            cropped_raw_images = self._crop_batch_to_yolo_bounding_boxes(
                detection_results, raw_images
            )
            raw_logits = self.classifier(cropped_raw_images)
            # use the model's maximum likelihood logit as the label
            labels = raw_logits.argmax(dim=1)
        transformed_images = self.noisy_model.noise_layer(raw_images)
        cropped_transformed_images = self._crop_batch_to_yolo_bounding_boxes(
            detection_results, transformed_images
        )
        transformed_logits = self.classifier(cropped_transformed_images)
        preds = transformed_logits.argmax(dim=1)
        loss = self.loss_function(transformed_logits, target=labels)
        num_batches = (
            self.trainer.num_training_batches
            if self.trainer.num_training_batches != math.inf
            else 0
        )
        global_step = int(self.current_epoch * num_batches + batch_idx)
        metrics = {
            "accuracy": self.metrics[mode].accuracy(preds, labels),
            "precision": self.metrics[mode].precision(preds, labels),
            **{
                self.task_loss_name if name == "task_loss" else name: value
                for name, value in self.get_loss_components().items()
            },
        }
        self.log_dict(metrics, prog_bar=True, logger=False)
        if self.logger is not None:
            self.logger.log_metrics(
                {
                    f"{name}/{mode}/batch": value
                    for name, value in metrics.items()
                },
                step=global_step,
            )

        def side_by_side_grid(
            raw: torch.Tensor, transformed: torch.Tensor
        ) -> torch.Tensor:
            # Interleave raw/transformed pairs so each pair renders adjacently.
            side_by_side = torch.stack([raw, transformed], dim=1).reshape(
                -1, *raw.shape[1:]
            )
            return torchvision.utils.make_grid(side_by_side)

        side_by_side_crops_grid = side_by_side_grid(
            cropped_raw_images, cropped_transformed_images
        )
        side_by_side_full_grid = side_by_side_grid(
            raw_images, transformed_images
        )
        if self.tb_writer is not None:
            self.tb_writer.add_image(
                f"side_by_side_cropped/{mode}/batch",
                side_by_side_crops_grid,
                global_step=global_step,
            )
            self.tb_writer.add_image(
                f"side_by_side_full/{mode}/batch",
                side_by_side_full_grid,
                global_step=global_step,
            )
            self.tb_writer.add_histogram(
                f"std_histogram/{mode}/batch",
                self.noisy_model.noise_layer.std_estimator.module.weight.detach(),
                global_step=global_step,
                bins=512,  # pyright: ignore[reportArgumentType]
            )
            self.tb_writer.add_histogram(
                f"mean_histogram/{mode}/batch",
                self.noisy_model.noise_layer.mean_estimator.module.weight.detach(),
                global_step=global_step,
                bins=512,  # pyright: ignore[reportArgumentType]
            )
        return loss

    def _on_epoch_end(self, mode: Literal["train", "valid"]) -> None:
        """Compute epoch-level training metrics."""
        metrics: dict[str, float] = {
            "epoch": self.current_epoch,
        }
        # NOTE(review): nothing in `_step` visibly updates `mean_losses`
        # (entries are created lazily on first access), so this loop may
        # iterate over nothing — verify loss means are recorded somewhere.
        for loss_name, metric in self.metrics[mode].mean_losses.items():
            metrics[loss_name] = metric.compute().item()
            metric.reset()
        if self.logger is not None and self.trainer.is_global_zero:
            self.logger.log_metrics(
                {
                    f"{name}/{mode}/epoch": value
                    for name, value in metrics.items()
                },
                step=self.current_epoch,
            )
In [ ]:
Copied!
# TF32 matmuls: faster on Ampere+ GPUs with negligible accuracy impact here.
torch.set_float32_matmul_precision("high")

logger = lightning.pytorch.loggers.TensorBoardLogger(
    "tensorflow_efficientnet",
    name="tensorflow_efficientnet_fullcoco",
)
trainer = lightning.Trainer(
    logger=logger,
    devices=[0],
    max_epochs=50,
    log_every_n_steps=30,
    fast_dev_run=True,  # only enable in CI
)
with trainer.init_module():
    training_module = StainedGlassTensorflowVisionModule(
        transform=TransformHyperparameters(),
        detection=DetectionHyperparameters(),
        loss=LossHyperparameters(),
        transform_loss=TransformLossHyperparameters(),
        optimizer=OptimizerHyperparameters(),
        metric=MetricHyperparameters(),
    )
    # we must instantiate the model before we can know what size to reshape
    # the images in the dataset to
    training_module.configure_model()
data_module = HuggingFaceDataModule(
    training_module.image_size,
    dataset=DatasetHyperparameters(path="mini_coco", num_proc=None),
    train_dataloader=DataLoaderHyperparameters(),
    test_dataloader=DataLoaderHyperparameters(),
)
trainer.fit(model=training_module, datamodule=data_module)
GPU available: True (cuda), used: True TPU available: False, using: 0 TPU cores HPU available: False, using: 0 HPUs Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed. 2025-03-14 12:58:56.178723: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 70576 MB memory: -> device: 0, name: NVIDIA H100 80GB HBM3, pci bus id: 0000:18:00.0, compute capability: 9.0 Map (num_proc=2): 0%| | 0/64 [00:00<?, ? examples/s]