TensorFlow Adapter¶
The notebook demonstrates how to use Stained Glass Core to create a Stained Glass Transform for a TensorFlow model using TensorFlowAdapter, a PyTorch-based TensorFlow model wrapper.
Installation¶
Finding versions of PyTorch and TensorFlow that were compiled with compatible CUDA versions is sometimes difficult. In general, PyTorch tends to be more forgiving with CUDA versions, while TensorFlow is more strict. These versions have been verified to be compatible:
torch~=2.9.0 and tensorflow[and-cuda]==2.20.0:
In [ ]:
Copied!
from __future__ import annotations
import functools
import math
import os
import pathlib
import random
import re
import types
from typing import Literal
import datasets
import keras
import lightning
import lightning.pytorch.loggers
import rfdetr
import supervision as sv
import tensorflow as tf
import torch
import torch.utils.tensorboard
import torchmetrics
import torchvision.transforms.v2
from torch import nn
from typing_extensions import override
from stainedglass_core import model as sg_model, noise_layer as sg_noise_layer, utils as sg_utils
from stainedglass_core.integrations import tensorflow as sg_tensorflow
from stainedglass_core.loss import cloak as sg_cloak_loss
# Expose only the first GPU so TensorFlow and PyTorch share a single device.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
physical_devices = tf.config.list_physical_devices("GPU")
print(physical_devices)
print("Num GPUs Available: ", len(physical_devices))
# Guard against CPU-only machines: indexing an empty device list raises IndexError.
if physical_devices:
    # Let TensorFlow allocate GPU memory on demand instead of claiming it all up front,
    # leaving room for PyTorch on the same device.
    tf.config.experimental.set_memory_growth(physical_devices[0], True)
from __future__ import annotations
import functools
import math
import os
import pathlib
import random
import re
import types
from typing import Literal
import datasets
import keras
import lightning
import lightning.pytorch.loggers
import rfdetr
import supervision as sv
import tensorflow as tf
import torch
import torch.utils.tensorboard
import torchmetrics
import torchvision.transforms.v2
from torch import nn
from typing_extensions import override
from stainedglass_core import model as sg_model, noise_layer as sg_noise_layer, utils as sg_utils
from stainedglass_core.integrations import tensorflow as sg_tensorflow
from stainedglass_core.loss import cloak as sg_cloak_loss
# Expose only the first GPU so TensorFlow and PyTorch share a single device.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
physical_devices = tf.config.list_physical_devices("GPU")
print(physical_devices)
print("Num GPUs Available: ", len(physical_devices))
# Guard against CPU-only machines: indexing an empty device list raises IndexError.
if physical_devices:
    # Let TensorFlow allocate GPU memory on demand instead of claiming it all up front,
    # leaving room for PyTorch on the same device.
    tf.config.experimental.set_memory_growth(physical_devices[0], True)
In [ ]:
Copied!
from typing import TypedDict
class DataPoint(TypedDict):
    """A single dataset sample: one image tensor."""

    image: torch.Tensor


class BatchedDataPoint(TypedDict):
    """A collated batch of samples: a list of image tensors."""

    image: list[torch.Tensor]
from typing import TypedDict
class DataPoint(TypedDict):
    """A single dataset sample: one image tensor."""

    image: torch.Tensor


class BatchedDataPoint(TypedDict):
    """A collated batch of samples: a list of image tensors."""

    image: list[torch.Tensor]
In [ ]:
Copied!
import dataclasses
import inspect
from typing import Any, Final
from typing_extensions import Self
NO_CACHE_FILES_ERR_MSG: Final[str] = "The loaded dataset has no cache files."


@dataclasses.dataclass
class Hyperparameters:
    """Base class for hyperparameter groups constructible from loose keyword arguments."""

    @classmethod
    def from_dict(cls, **kwargs: Any) -> Self:
        """Create an instance from a dictionary of arguments.

        Keyword arguments matching the constructor signature are passed to the
        constructor; any remaining ones are set as plain attributes afterwards.

        Args:
            **kwargs: Keyword arguments for initialization and attribute setting.

        Returns:
            An initialized instance of the class.
        """
        init_parameters = inspect.signature(cls).parameters
        accepted = {name: kwargs.pop(name) for name in tuple(kwargs) if name in init_parameters}
        instance = cls(**accepted)
        for name, value in kwargs.items():
            setattr(instance, name, value)
        return instance
@dataclasses.dataclass
class TransformHyperparameters(Hyperparameters):
    """Arguments forwarded (via `dataclasses.asdict`) to the noise layer of `sg_model.NoisyModel`."""

    percent_to_mask: float = 0.0
    scale: tuple[float, float] = (1e-4, 1.0)
    shallow: float = 1.0
    rhos_init: float = -4.0
    seed: int | None = None  # None leaves the transform seed unset


@dataclasses.dataclass
class LossHyperparameters(Hyperparameters):
    """Constructor arguments for the task loss (`nn.CrossEntropyLoss`)."""

    reduction: str = "mean"


@dataclasses.dataclass
class TransformLossHyperparameters(Hyperparameters):
    """Arguments forwarded to `sg_cloak_loss.composite_cloak_loss_factory`."""

    alpha: float = 0.5


@dataclasses.dataclass
class OptimizerHyperparameters(Hyperparameters):
    """Constructor arguments for `torch.optim.AdamW`."""

    lr: float = 3e-3
    fused: bool = True  # use AdamW's fused CUDA implementation


@dataclasses.dataclass
class MetricHyperparameters(Hyperparameters): ...  # no tunable metric settings yet


@dataclasses.dataclass
class DatasetHyperparameters(Hyperparameters):
    """Source and preprocessing settings for the Hugging Face dataset."""

    path: str = "detection-datasets/coco"
    """The name of a dataset on https://huggingface.co/datasets or the path to a local dataset."""
    cache_dir: str | None = "/data_fast"
    """The directory to store the downloaded dataset in."""
    num_proc: int | None = 2
    """The number of workers to use when resizing the dataset."""


@dataclasses.dataclass
class DataLoaderHyperparameters(Hyperparameters):
    """Keyword arguments forwarded to `torch.utils.data.DataLoader`."""

    batch_size: int = 24
    num_workers: int | None = 0


@dataclasses.dataclass
class DetectionHyperparameters(Hyperparameters):
    """Settings for cropping images to detected objects."""

    num_crops: int = 4  # detection boxes sampled (with replacement) per image
import dataclasses
import inspect
from typing import Any, Final
from typing_extensions import Self
NO_CACHE_FILES_ERR_MSG: Final[str] = "The loaded dataset has no cache files."


@dataclasses.dataclass
class Hyperparameters:
    """Base class for hyperparameter groups constructible from loose keyword arguments."""

    @classmethod
    def from_dict(cls, **kwargs: Any) -> Self:
        """Create an instance from a dictionary of arguments.

        Keyword arguments matching the constructor signature are passed to the
        constructor; any remaining ones are set as plain attributes afterwards.

        Args:
            **kwargs: Keyword arguments for initialization and attribute setting.

        Returns:
            An initialized instance of the class.
        """
        init_parameters = inspect.signature(cls).parameters
        accepted = {name: kwargs.pop(name) for name in tuple(kwargs) if name in init_parameters}
        instance = cls(**accepted)
        for name, value in kwargs.items():
            setattr(instance, name, value)
        return instance
@dataclasses.dataclass
class TransformHyperparameters(Hyperparameters):
    """Arguments forwarded (via `dataclasses.asdict`) to the noise layer of `sg_model.NoisyModel`."""

    percent_to_mask: float = 0.0
    scale: tuple[float, float] = (1e-4, 1.0)
    shallow: float = 1.0
    rhos_init: float = -4.0
    seed: int | None = None  # None leaves the transform seed unset


@dataclasses.dataclass
class LossHyperparameters(Hyperparameters):
    """Constructor arguments for the task loss (`nn.CrossEntropyLoss`)."""

    reduction: str = "mean"


@dataclasses.dataclass
class TransformLossHyperparameters(Hyperparameters):
    """Arguments forwarded to `sg_cloak_loss.composite_cloak_loss_factory`."""

    alpha: float = 0.5


@dataclasses.dataclass
class OptimizerHyperparameters(Hyperparameters):
    """Constructor arguments for `torch.optim.AdamW`."""

    lr: float = 3e-3
    fused: bool = True  # use AdamW's fused CUDA implementation


@dataclasses.dataclass
class MetricHyperparameters(Hyperparameters): ...  # no tunable metric settings yet


@dataclasses.dataclass
class DatasetHyperparameters(Hyperparameters):
    """Source and preprocessing settings for the Hugging Face dataset."""

    path: str = "detection-datasets/coco"
    """The name of a dataset on https://huggingface.co/datasets or the path to a local dataset."""
    cache_dir: str | None = "/data_fast"
    """The directory to store the downloaded dataset in."""
    num_proc: int | None = 2
    """The number of workers to use when resizing the dataset."""


@dataclasses.dataclass
class DataLoaderHyperparameters(Hyperparameters):
    """Keyword arguments forwarded to `torch.utils.data.DataLoader`."""

    batch_size: int = 24
    num_workers: int | None = 0


@dataclasses.dataclass
class DetectionHyperparameters(Hyperparameters):
    """Settings for cropping images to detected objects."""

    num_crops: int = 4  # detection boxes sampled (with replacement) per image
In [ ]:
Copied!
from typing_extensions import TypedDict
from stainedglass_core.utils.torch import nn as sg_nn
class Metrics(TypedDict):
    """Models the metrics collected in the different modes of `StainedGlassTensorflowVisionModule`."""

    train: TrainingStepMetrics
    valid: TrainingStepMetrics
class TrainingStepMetrics(nn.Module):
    """Common metrics to be collected during `_step` in both training and validation."""

    def __init__(self) -> None:
        super().__init__()
        # presumably defaultdict-like: creates a MeanMetric per loss name on first access — confirm in sg_nn
        self.mean_losses = sg_nn.ModuleDefaultDict(torchmetrics.MeanMetric)
        # 1000 classes matches EfficientNetB0's ImageNet classification head
        self.accuracy = torchmetrics.Accuracy("multiclass", num_classes=1000)
        self.precision = torchmetrics.Precision("multiclass", num_classes=1000)
from typing_extensions import TypedDict
from stainedglass_core.utils.torch import nn as sg_nn
class Metrics(TypedDict):
    """Models the metrics collected in the different modes of `StainedGlassTensorflowVisionModule`."""

    train: TrainingStepMetrics
    valid: TrainingStepMetrics
class TrainingStepMetrics(nn.Module):
    """Common metrics to be collected during `_step` in both training and validation."""

    def __init__(self) -> None:
        super().__init__()
        # presumably defaultdict-like: creates a MeanMetric per loss name on first access — confirm in sg_nn
        self.mean_losses = sg_nn.ModuleDefaultDict(torchmetrics.MeanMetric)
        # 1000 classes matches EfficientNetB0's ImageNet classification head
        self.accuracy = torchmetrics.Accuracy("multiclass", num_classes=1000)
        self.precision = torchmetrics.Precision("multiclass", num_classes=1000)
In [ ]:
Copied!
class HuggingFaceDataModule(lightning.LightningDataModule):
    """Lightning data module that loads, resizes, and serves a Hugging Face image dataset."""

    def __init__(
        self,
        image_size: tuple[int, int],
        dataset: DatasetHyperparameters,
        train_dataloader: DataLoaderHyperparameters,
        test_dataloader: DataLoaderHyperparameters,
    ) -> None:
        """Store the target image size and the dataset/dataloader hyperparameters.

        Args:
            image_size: Target (height, width) every image is resized to.
            dataset: Where to find the dataset and how to preprocess it.
            train_dataloader: DataLoader settings for the training split.
            test_dataloader: DataLoader settings for the validation split.
        """
        super().__init__()
        self.image_size = image_size
        self.dataset_hyperparameters = dataset
        self.train_dataloader_hyperparameters = train_dataloader
        self.test_dataloader_hyperparameters = test_dataloader

    def _prepare_data(self) -> datasets.DatasetDict:
        """Download and prepare the dataset splits.

        Returns:
            The dataset splits with images resized to `self.image_size`, converted to
            float32 RGB tensors, and the detection columns dropped.
        """
        resize_transform = torchvision.transforms.v2.Compose(
            [
                torchvision.transforms.v2.Resize(self.image_size),
                torchvision.transforms.v2.ToDtype(torch.float32, scale=True),
                torchvision.transforms.v2.RGB(),
            ]
        )

        def resize_map(
            sample: DataPoint | BatchedDataPoint,
        ) -> DataPoint | BatchedDataPoint:
            # Handles both a single sample and a batched sample (list of images).
            if isinstance(sample["image"], list):
                return {"image": [resize_transform(image) for image in sample["image"]]}
            return {"image": resize_transform(sample["image"])}

        # Resolve the dataset: a local directory (absolute, or relative to this file),
        # otherwise treat the path as a Hugging Face Hub dataset name.
        path = pathlib.Path(self.dataset_hyperparameters.path)
        if path.is_dir():
            dataset = datasets.load_from_disk(str(path))
        elif "__file__" in globals() and (pathlib.Path(__file__).parent / path).is_dir():
            path = pathlib.Path(__file__).parent / path
            dataset = datasets.load_from_disk(str(path))
        else:
            dataset = datasets.load_dataset(
                self.dataset_hyperparameters.path,
                cache_dir=self.dataset_hyperparameters.cache_dir,
            )
        assert isinstance(dataset, datasets.DatasetDict)

        def _get_cache_dir(dataset: datasets.Dataset) -> pathlib.Path:
            # Derive the cache directory from the split's newest cache file.
            if not dataset.cache_files:
                raise ValueError(NO_CACHE_FILES_ERR_MSG)
            return pathlib.Path(dataset.cache_files[-1]["filename"]).parent

        def _get_cache_file_names() -> dict[str, str | None]:
            # Pin map() cache files next to the dataset's own cache so re-runs reuse them.
            cache_files: dict[str, str | None] = {
                str(split_name): str(_get_cache_dir(dataset_split) / "cache" / f"{split_name}.map")
                for split_name, dataset_split in dataset.items()
            }
            return cache_files

        return (
            dataset.remove_columns(["width", "height", "objects"])
            .with_format("torch")
            .map(
                resize_map,
                batched=False,
                load_from_cache_file=True,
                cache_file_names=_get_cache_file_names(),
                num_proc=self.dataset_hyperparameters.num_proc,
            )
        )

    @override
    def prepare_data(self) -> None:
        """Download and prepare the dataset splits.

        This function is only called on the local rank 0 process, so we won't attempt to download or map dataset multiple times here.
        """
        _ = self._prepare_data()

    @override
    def setup(self, stage: str) -> None:
        """Set up the datasets.

        Args:
            stage: Whether we are fitting (training), testing or predicting.
        """
        # Both splits are prepared regardless of stage; the map work is cached by prepare_data.
        dataset = self._prepare_data()
        self.train_dataset = dataset["train"]
        self.val_dataset = dataset["val"]

    @override
    def train_dataloader(self) -> torch.utils.data.DataLoader[DataPoint]:
        """Create the training dataloader.

        Returns:
            The training dataloader.
        """
        return torch.utils.data.DataLoader(
            self.train_dataset,  # pyright: ignore[reportArgumentType]
            **dataclasses.asdict(self.train_dataloader_hyperparameters),
            shuffle=True,
            pin_memory=True,
        )

    @override
    def val_dataloader(self) -> torch.utils.data.DataLoader[DataPoint]:
        """Create the validation dataloader.

        Returns:
            The validation dataloader.
        """
        return torch.utils.data.DataLoader(
            self.val_dataset,  # pyright: ignore[reportArgumentType]
            **dataclasses.asdict(self.test_dataloader_hyperparameters),
            shuffle=False,
            pin_memory=True,
        )
class HuggingFaceDataModule(lightning.LightningDataModule):
    """Lightning data module that loads, resizes, and serves a Hugging Face image dataset."""

    def __init__(
        self,
        image_size: tuple[int, int],
        dataset: DatasetHyperparameters,
        train_dataloader: DataLoaderHyperparameters,
        test_dataloader: DataLoaderHyperparameters,
    ) -> None:
        """Store the target image size and the dataset/dataloader hyperparameters.

        Args:
            image_size: Target (height, width) every image is resized to.
            dataset: Where to find the dataset and how to preprocess it.
            train_dataloader: DataLoader settings for the training split.
            test_dataloader: DataLoader settings for the validation split.
        """
        super().__init__()
        self.image_size = image_size
        self.dataset_hyperparameters = dataset
        self.train_dataloader_hyperparameters = train_dataloader
        self.test_dataloader_hyperparameters = test_dataloader

    def _prepare_data(self) -> datasets.DatasetDict:
        """Download and prepare the dataset splits.

        Returns:
            The dataset splits with images resized to `self.image_size`, converted to
            float32 RGB tensors, and the detection columns dropped.
        """
        resize_transform = torchvision.transforms.v2.Compose(
            [
                torchvision.transforms.v2.Resize(self.image_size),
                torchvision.transforms.v2.ToDtype(torch.float32, scale=True),
                torchvision.transforms.v2.RGB(),
            ]
        )

        def resize_map(
            sample: DataPoint | BatchedDataPoint,
        ) -> DataPoint | BatchedDataPoint:
            # Handles both a single sample and a batched sample (list of images).
            if isinstance(sample["image"], list):
                return {"image": [resize_transform(image) for image in sample["image"]]}
            return {"image": resize_transform(sample["image"])}

        # Resolve the dataset: a local directory (absolute, or relative to this file),
        # otherwise treat the path as a Hugging Face Hub dataset name.
        path = pathlib.Path(self.dataset_hyperparameters.path)
        if path.is_dir():
            dataset = datasets.load_from_disk(str(path))
        elif "__file__" in globals() and (pathlib.Path(__file__).parent / path).is_dir():
            path = pathlib.Path(__file__).parent / path
            dataset = datasets.load_from_disk(str(path))
        else:
            dataset = datasets.load_dataset(
                self.dataset_hyperparameters.path,
                cache_dir=self.dataset_hyperparameters.cache_dir,
            )
        assert isinstance(dataset, datasets.DatasetDict)

        def _get_cache_dir(dataset: datasets.Dataset) -> pathlib.Path:
            # Derive the cache directory from the split's newest cache file.
            if not dataset.cache_files:
                raise ValueError(NO_CACHE_FILES_ERR_MSG)
            return pathlib.Path(dataset.cache_files[-1]["filename"]).parent

        def _get_cache_file_names() -> dict[str, str | None]:
            # Pin map() cache files next to the dataset's own cache so re-runs reuse them.
            cache_files: dict[str, str | None] = {
                str(split_name): str(_get_cache_dir(dataset_split) / "cache" / f"{split_name}.map")
                for split_name, dataset_split in dataset.items()
            }
            return cache_files

        return (
            dataset.remove_columns(["width", "height", "objects"])
            .with_format("torch")
            .map(
                resize_map,
                batched=False,
                load_from_cache_file=True,
                cache_file_names=_get_cache_file_names(),
                num_proc=self.dataset_hyperparameters.num_proc,
            )
        )

    @override
    def prepare_data(self) -> None:
        """Download and prepare the dataset splits.

        This function is only called on the local rank 0 process, so we won't attempt to download or map dataset multiple times here.
        """
        _ = self._prepare_data()

    @override
    def setup(self, stage: str) -> None:
        """Set up the datasets.

        Args:
            stage: Whether we are fitting (training), testing or predicting.
        """
        # Both splits are prepared regardless of stage; the map work is cached by prepare_data.
        dataset = self._prepare_data()
        self.train_dataset = dataset["train"]
        self.val_dataset = dataset["val"]

    @override
    def train_dataloader(self) -> torch.utils.data.DataLoader[DataPoint]:
        """Create the training dataloader.

        Returns:
            The training dataloader.
        """
        return torch.utils.data.DataLoader(
            self.train_dataset,  # pyright: ignore[reportArgumentType]
            **dataclasses.asdict(self.train_dataloader_hyperparameters),
            shuffle=True,
            pin_memory=True,
        )

    @override
    def val_dataloader(self) -> torch.utils.data.DataLoader[DataPoint]:
        """Create the validation dataloader.

        Returns:
            The validation dataloader.
        """
        return torch.utils.data.DataLoader(
            self.val_dataset,  # pyright: ignore[reportArgumentType]
            **dataclasses.asdict(self.test_dataloader_hyperparameters),
            shuffle=False,
            pin_memory=True,
        )
In [ ]:
Copied!
class StainedGlassTensorflowVisionModule(lightning.LightningModule):
    """Trains a Stained Glass Transform against a TensorFlow (Keras) image classifier.

    The Keras classifier is wrapped with `VisionTensorFlowAdapter` so it can participate in
    a PyTorch/Lightning training loop; only the noise layer's parameters are optimized.
    """

    def __init__(
        self,
        transform: TransformHyperparameters,
        detection: DetectionHyperparameters,
        loss: LossHyperparameters,
        transform_loss: TransformLossHyperparameters,
        optimizer: OptimizerHyperparameters,
        metric: MetricHyperparameters,
    ) -> None:
        """Store hyperparameter groups; model construction is deferred to `configure_model`."""
        super().__init__()
        self.transform_hyperparameters = transform
        self.detection_hyperparameters = detection
        self.loss_hyperparameters = loss
        self.transform_loss_hyperparameters = transform_loss
        self.optimizer_hyperparameters = optimizer
        self.metric_hyperparameters = metric

    @functools.cached_property
    def image_size(self) -> tuple[int, int]:
        """Get the input image dimensions from the underlying model.

        Returns:
            A tuple of (height, width).
        """
        # Keras models report channels-last input shapes: (batch, height, width, channels).
        _, height, width, color_channels = self.noisy_model.base_model.tf_model.input_shape
        assert isinstance(height, int) and isinstance(width, int) and isinstance(color_channels, int)
        assert color_channels in (1, 3)
        return height, width

    @property
    def tb_writer(self) -> torch.utils.tensorboard.SummaryWriter | None:  # pyright: ignore[reportPrivateImportUsage]
        """Access the TensorBoard SummaryWriter from the available loggers.

        Returns:
            The TensorBoard experiment instance if found, otherwise None.
        """
        assert self.loggers is not None
        for logger in self.loggers:
            if isinstance(logger, lightning.pytorch.loggers.TensorBoardLogger):
                return logger.experiment
        return None

    @override
    def configure_model(self) -> None:
        """Configure the model assuming a non-ddp parallelism."""
        self.detection_model = rfdetr.RFDETRNano()
        # patch RF-DETR model train method with nn.Module.train
        # Build the Keras classifier on the same device PyTorch will use.
        with tf.device("/device:GPU:0" if torch.cuda.is_available() else "/CPU:0"):
            base_model_tf = keras.applications.EfficientNetB0(
                include_top=True,
                weights="imagenet",
                input_tensor=None,
                input_shape=None,
                pooling=None,
                classes=1000,
                classifier_activation=None,  # pyright: ignore[reportArgumentType]
            )
        self.classifier = sg_tensorflow.VisionTensorFlowAdapter(base_model_tf)
        self.noisy_model = sg_model.NoisyModel(
            sg_noise_layer.CloakNoiseLayerOneShot,
            self.classifier,
            target_parameter="input",
            **dataclasses.asdict(self.transform_hyperparameters),
        )
        self.task_loss = nn.CrossEntropyLoss(
            **dataclasses.asdict(self.loss_hyperparameters),
        )
        self.loss_function, self.get_loss_components, _ = sg_cloak_loss.composite_cloak_loss_factory(
            self.noisy_model,
            loss_function=self.task_loss,
            **dataclasses.asdict(self.transform_loss_hyperparameters),
        )
        # Derive a snake_case metric name from the loss, e.g. CrossEntropyLoss -> cross_entropy_loss.
        self.task_loss_name = (
            self.task_loss.__name__
            if isinstance(self.task_loss, types.FunctionType)
            else re.sub(r"(?<!^)(?=[A-Z])", "_", type(self.task_loss).__name__).lower()
        )
        self.resizer = torchvision.transforms.v2.Resize(self.image_size)
        self.configure_metrics()

    def configure_metrics(self) -> None:
        """Register the metrics used during training and validation."""
        with (
            self.device,
            sg_utils.torch.dtypes.default_dtype(torch.float32),
        ):
            self._train_metrics = TrainingStepMetrics()
            self._val_metrics = TrainingStepMetrics()
        # we cannot use ModuleDict here because the 'train' key conflicts with the 'Module.train' method, also TypedDict is based
        self.metrics: Metrics = {
            "train": self._train_metrics,
            "valid": self._val_metrics,
        }

    @override
    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Create the optimizer over the noise-layer parameters only (the classifier stays frozen)."""
        return torch.optim.AdamW(
            self.noisy_model.noise_layer.parameters(),
            **dataclasses.asdict(self.optimizer_hyperparameters),
            weight_decay=0.0,  # nonzero weight_decay for SGT parameters hinders training
        )

    @override
    def training_step(self, batch: DataPoint, batch_idx: int) -> torch.Tensor:
        """Run one training step; see `_step`."""
        return self._step(batch, batch_idx, "train")

    @override
    def validation_step(self, batch: DataPoint, batch_idx: int) -> torch.Tensor:
        """Run one validation step; see `_step`."""
        return self._step(batch, batch_idx, "valid")

    @override
    def on_train_epoch_end(self) -> None:
        """Log epoch-level training metrics."""
        return self._on_epoch_end("train")

    @override
    def on_validation_epoch_end(self) -> None:
        """Log epoch-level validation metrics."""
        return self._on_epoch_end("valid")

    def _crop_batch_to_detection_boxes(
        self,
        detection_results: list[sv.Detections],
        image_batch: torch.Tensor,
    ) -> torch.Tensor:
        """Crop a batch of images according to detector results, returning crops resized to the original size."""
        # NOTE(review): images with detections contribute num_crops entries while images
        # without detections contribute one, so the output batch size varies — confirm intended.
        assert len(detection_results) == len(image_batch)
        cropped_images: list[torch.Tensor] = []
        for detections, full_image in zip(detection_results, image_batch, strict=True):
            # Fall back to the whole image when the detector found nothing.
            if len(detections) == 0:
                cropped_images.append(full_image)
                continue
            xyxy_boxes = detections.xyxy
            if xyxy_boxes is None or len(xyxy_boxes) == 0:
                cropped_images.append(full_image)
                continue
            image_height, image_width = full_image.shape[1:]
            # Sample boxes with replacement so each detected image yields exactly num_crops crops.
            sampled_indices = random.choices(  # noqa: S311
                range(len(xyxy_boxes)), k=self.detection_hyperparameters.num_crops
            )
            for box_index in sampled_indices:
                x1, y1, x2, y2 = xyxy_boxes[box_index]
                # Clamp the box to the image bounds and guarantee a crop of at least 1x1 pixels.
                top = max(0, min(int(y1), image_height - 1))
                left = max(0, min(int(x1), image_width - 1))
                bottom = max(top + 1, min(int(y2), image_height))
                right = max(left + 1, min(int(x2), image_width))
                cropped_image = torchvision.transforms.v2.functional.crop(
                    full_image,
                    top=top,
                    left=left,
                    height=bottom - top,
                    width=right - left,
                )
                cropped_image = self.resizer(cropped_image)
                cropped_images.append(cropped_image)
        return torch.stack(cropped_images)

    def _step(self, batch: DataPoint, batch_idx: int, mode: Literal["train", "valid"]) -> torch.Tensor:
        """Shared training/validation step.

        Labels are the classifier's own top-1 predictions on the raw crops, so the loss
        measures how well the transformed images preserve the classifier's behavior.
        """
        raw_images = batch["image"]
        # The raw/label path needs no gradients; only the transformed path is trained.
        with torch.no_grad():
            detection_results = self.detection_model.predict(list(raw_images), threshold=0.5)
            if isinstance(detection_results, sv.Detections):
                detection_results = [detection_results]
            cropped_raw_images = self._crop_batch_to_detection_boxes(detection_results, raw_images)
            raw_logits = self.classifier(cropped_raw_images)
            # use the model's maximum likelihood logit as the label
            labels = raw_logits.argmax(dim=1)
        transformed_images = self.noisy_model.noise_layer(raw_images)
        cropped_transformed_images = self._crop_batch_to_detection_boxes(detection_results, transformed_images)
        transformed_logits = self.classifier(cropped_transformed_images)
        preds = transformed_logits.argmax(dim=1)
        loss = self.loss_function(transformed_logits, target=labels)
        # Reconstruct a monotonically increasing global step for the loggers.
        num_batches = self.trainer.num_training_batches if self.trainer.num_training_batches != math.inf else 0
        global_step = int(self.current_epoch * num_batches + batch_idx)
        metrics = {
            "accuracy": self.metrics[mode].accuracy(preds, labels),
            "precision": self.metrics[mode].precision(preds, labels),
            **{self.task_loss_name if name == "task_loss" else name: value for name, value in self.get_loss_components().items()},
        }
        # NOTE(review): self.metrics[mode].mean_losses is never updated here, so
        # `_on_epoch_end` iterates an empty mapping unless losses are recorded elsewhere — confirm.
        self.log_dict(metrics, prog_bar=True, logger=False)
        if self.logger is not None:
            self.logger.log_metrics(
                {f"{name}/{mode}/batch": value for name, value in metrics.items()},
                step=global_step,
            )

        def side_by_side_grid(raw: torch.Tensor, transformed: torch.Tensor) -> torch.Tensor:
            # Interleave raw/transformed pairs so they render next to each other in the grid.
            side_by_side = torch.stack([raw, transformed], dim=1).reshape(-1, *raw.shape[1:])
            return torchvision.utils.make_grid(side_by_side)

        side_by_side_crops_grid = side_by_side_grid(cropped_raw_images, cropped_transformed_images)
        side_by_side_full_grid = side_by_side_grid(raw_images, transformed_images)
        if self.tb_writer is not None:
            self.tb_writer.add_image(
                f"side_by_side_cropped/{mode}/batch",
                side_by_side_crops_grid,
                global_step=global_step,
            )
            self.tb_writer.add_image(
                f"side_by_side_full/{mode}/batch",
                side_by_side_full_grid,
                global_step=global_step,
            )
            self.tb_writer.add_histogram(
                f"std_histogram/{mode}/batch",
                self.noisy_model.noise_layer.std_estimator.module.weight.detach(),
                global_step=global_step,
                bins=512,  # pyright: ignore[reportArgumentType]
            )
            self.tb_writer.add_histogram(
                f"mean_histogram/{mode}/batch",
                self.noisy_model.noise_layer.mean_estimator.module.weight.detach(),
                global_step=global_step,
                bins=512,  # pyright: ignore[reportArgumentType]
            )
        return loss

    def _on_epoch_end(self, mode: Literal["train", "valid"]) -> None:
        """Compute epoch-level training metrics."""
        metrics: dict[str, float] = {
            "epoch": self.current_epoch,
        }
        # Fold each accumulated mean loss into the epoch metrics, then reset for the next epoch.
        for loss_name, metric in self.metrics[mode].mean_losses.items():
            metrics[loss_name] = metric.compute().item()
            metric.reset()
        if self.logger is not None and self.trainer.is_global_zero:
            self.logger.log_metrics(
                {f"{name}/{mode}/epoch": value for name, value in metrics.items()},
                step=self.current_epoch,
            )
class StainedGlassTensorflowVisionModule(lightning.LightningModule):
    """Trains a Stained Glass Transform against a TensorFlow (Keras) image classifier.

    The Keras classifier is wrapped with `VisionTensorFlowAdapter` so it can participate in
    a PyTorch/Lightning training loop; only the noise layer's parameters are optimized.
    """

    def __init__(
        self,
        transform: TransformHyperparameters,
        detection: DetectionHyperparameters,
        loss: LossHyperparameters,
        transform_loss: TransformLossHyperparameters,
        optimizer: OptimizerHyperparameters,
        metric: MetricHyperparameters,
    ) -> None:
        """Store hyperparameter groups; model construction is deferred to `configure_model`."""
        super().__init__()
        self.transform_hyperparameters = transform
        self.detection_hyperparameters = detection
        self.loss_hyperparameters = loss
        self.transform_loss_hyperparameters = transform_loss
        self.optimizer_hyperparameters = optimizer
        self.metric_hyperparameters = metric

    @functools.cached_property
    def image_size(self) -> tuple[int, int]:
        """Get the input image dimensions from the underlying model.

        Returns:
            A tuple of (height, width).
        """
        # Keras models report channels-last input shapes: (batch, height, width, channels).
        _, height, width, color_channels = self.noisy_model.base_model.tf_model.input_shape
        assert isinstance(height, int) and isinstance(width, int) and isinstance(color_channels, int)
        assert color_channels in (1, 3)
        return height, width

    @property
    def tb_writer(self) -> torch.utils.tensorboard.SummaryWriter | None:  # pyright: ignore[reportPrivateImportUsage]
        """Access the TensorBoard SummaryWriter from the available loggers.

        Returns:
            The TensorBoard experiment instance if found, otherwise None.
        """
        assert self.loggers is not None
        for logger in self.loggers:
            if isinstance(logger, lightning.pytorch.loggers.TensorBoardLogger):
                return logger.experiment
        return None

    @override
    def configure_model(self) -> None:
        """Configure the model assuming a non-ddp parallelism."""
        self.detection_model = rfdetr.RFDETRNano()
        # patch RF-DETR model train method with nn.Module.train
        # Build the Keras classifier on the same device PyTorch will use.
        with tf.device("/device:GPU:0" if torch.cuda.is_available() else "/CPU:0"):
            base_model_tf = keras.applications.EfficientNetB0(
                include_top=True,
                weights="imagenet",
                input_tensor=None,
                input_shape=None,
                pooling=None,
                classes=1000,
                classifier_activation=None,  # pyright: ignore[reportArgumentType]
            )
        self.classifier = sg_tensorflow.VisionTensorFlowAdapter(base_model_tf)
        self.noisy_model = sg_model.NoisyModel(
            sg_noise_layer.CloakNoiseLayerOneShot,
            self.classifier,
            target_parameter="input",
            **dataclasses.asdict(self.transform_hyperparameters),
        )
        self.task_loss = nn.CrossEntropyLoss(
            **dataclasses.asdict(self.loss_hyperparameters),
        )
        self.loss_function, self.get_loss_components, _ = sg_cloak_loss.composite_cloak_loss_factory(
            self.noisy_model,
            loss_function=self.task_loss,
            **dataclasses.asdict(self.transform_loss_hyperparameters),
        )
        # Derive a snake_case metric name from the loss, e.g. CrossEntropyLoss -> cross_entropy_loss.
        # (This regex was mangled by the HTML export in this copy; restored from the intact copy above.)
        self.task_loss_name = (
            self.task_loss.__name__
            if isinstance(self.task_loss, types.FunctionType)
            else re.sub(r"(?<!^)(?=[A-Z])", "_", type(self.task_loss).__name__).lower()
        )
        self.resizer = torchvision.transforms.v2.Resize(self.image_size)
        self.configure_metrics()

    def configure_metrics(self) -> None:
        """Register the metrics used during training and validation."""
        with (
            self.device,
            sg_utils.torch.dtypes.default_dtype(torch.float32),
        ):
            self._train_metrics = TrainingStepMetrics()
            self._val_metrics = TrainingStepMetrics()
        # we cannot use ModuleDict here because the 'train' key conflicts with the 'Module.train' method, also TypedDict is based
        self.metrics: Metrics = {
            "train": self._train_metrics,
            "valid": self._val_metrics,
        }

    @override
    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Create the optimizer over the noise-layer parameters only (the classifier stays frozen)."""
        return torch.optim.AdamW(
            self.noisy_model.noise_layer.parameters(),
            **dataclasses.asdict(self.optimizer_hyperparameters),
            weight_decay=0.0,  # nonzero weight_decay for SGT parameters hinders training
        )

    @override
    def training_step(self, batch: DataPoint, batch_idx: int) -> torch.Tensor:
        """Run one training step; see `_step`."""
        return self._step(batch, batch_idx, "train")

    @override
    def validation_step(self, batch: DataPoint, batch_idx: int) -> torch.Tensor:
        """Run one validation step; see `_step`."""
        return self._step(batch, batch_idx, "valid")

    @override
    def on_train_epoch_end(self) -> None:
        """Log epoch-level training metrics."""
        return self._on_epoch_end("train")

    @override
    def on_validation_epoch_end(self) -> None:
        """Log epoch-level validation metrics."""
        return self._on_epoch_end("valid")

    def _crop_batch_to_detection_boxes(
        self,
        detection_results: list[sv.Detections],
        image_batch: torch.Tensor,
    ) -> torch.Tensor:
        """Crop a batch of images according to detector results, returning crops resized to the original size."""
        # NOTE(review): images with detections contribute num_crops entries while images
        # without detections contribute one, so the output batch size varies — confirm intended.
        assert len(detection_results) == len(image_batch)
        cropped_images: list[torch.Tensor] = []
        for detections, full_image in zip(detection_results, image_batch, strict=True):
            # Fall back to the whole image when the detector found nothing.
            if len(detections) == 0:
                cropped_images.append(full_image)
                continue
            xyxy_boxes = detections.xyxy
            if xyxy_boxes is None or len(xyxy_boxes) == 0:
                cropped_images.append(full_image)
                continue
            image_height, image_width = full_image.shape[1:]
            # Sample boxes with replacement so each detected image yields exactly num_crops crops.
            sampled_indices = random.choices(  # noqa: S311
                range(len(xyxy_boxes)), k=self.detection_hyperparameters.num_crops
            )
            for box_index in sampled_indices:
                x1, y1, x2, y2 = xyxy_boxes[box_index]
                # Clamp the box to the image bounds and guarantee a crop of at least 1x1 pixels.
                top = max(0, min(int(y1), image_height - 1))
                left = max(0, min(int(x1), image_width - 1))
                bottom = max(top + 1, min(int(y2), image_height))
                right = max(left + 1, min(int(x2), image_width))
                cropped_image = torchvision.transforms.v2.functional.crop(
                    full_image,
                    top=top,
                    left=left,
                    height=bottom - top,
                    width=right - left,
                )
                cropped_image = self.resizer(cropped_image)
                cropped_images.append(cropped_image)
        return torch.stack(cropped_images)

    def _step(self, batch: DataPoint, batch_idx: int, mode: Literal["train", "valid"]) -> torch.Tensor:
        """Shared training/validation step.

        Labels are the classifier's own top-1 predictions on the raw crops, so the loss
        measures how well the transformed images preserve the classifier's behavior.
        """
        raw_images = batch["image"]
        # The raw/label path needs no gradients; only the transformed path is trained.
        with torch.no_grad():
            detection_results = self.detection_model.predict(list(raw_images), threshold=0.5)
            if isinstance(detection_results, sv.Detections):
                detection_results = [detection_results]
            cropped_raw_images = self._crop_batch_to_detection_boxes(detection_results, raw_images)
            raw_logits = self.classifier(cropped_raw_images)
            # use the model's maximum likelihood logit as the label
            labels = raw_logits.argmax(dim=1)
        transformed_images = self.noisy_model.noise_layer(raw_images)
        cropped_transformed_images = self._crop_batch_to_detection_boxes(detection_results, transformed_images)
        transformed_logits = self.classifier(cropped_transformed_images)
        preds = transformed_logits.argmax(dim=1)
        loss = self.loss_function(transformed_logits, target=labels)
        # Reconstruct a monotonically increasing global step for the loggers.
        num_batches = self.trainer.num_training_batches if self.trainer.num_training_batches != math.inf else 0
        global_step = int(self.current_epoch * num_batches + batch_idx)
        metrics = {
            "accuracy": self.metrics[mode].accuracy(preds, labels),
            "precision": self.metrics[mode].precision(preds, labels),
            **{self.task_loss_name if name == "task_loss" else name: value for name, value in self.get_loss_components().items()},
        }
        # NOTE(review): self.metrics[mode].mean_losses is never updated here, so
        # `_on_epoch_end` iterates an empty mapping unless losses are recorded elsewhere — confirm.
        self.log_dict(metrics, prog_bar=True, logger=False)
        if self.logger is not None:
            self.logger.log_metrics(
                {f"{name}/{mode}/batch": value for name, value in metrics.items()},
                step=global_step,
            )

        def side_by_side_grid(raw: torch.Tensor, transformed: torch.Tensor) -> torch.Tensor:
            # Interleave raw/transformed pairs so they render next to each other in the grid.
            side_by_side = torch.stack([raw, transformed], dim=1).reshape(-1, *raw.shape[1:])
            return torchvision.utils.make_grid(side_by_side)

        side_by_side_crops_grid = side_by_side_grid(cropped_raw_images, cropped_transformed_images)
        side_by_side_full_grid = side_by_side_grid(raw_images, transformed_images)
        if self.tb_writer is not None:
            self.tb_writer.add_image(
                f"side_by_side_cropped/{mode}/batch",
                side_by_side_crops_grid,
                global_step=global_step,
            )
            self.tb_writer.add_image(
                f"side_by_side_full/{mode}/batch",
                side_by_side_full_grid,
                global_step=global_step,
            )
            self.tb_writer.add_histogram(
                f"std_histogram/{mode}/batch",
                self.noisy_model.noise_layer.std_estimator.module.weight.detach(),
                global_step=global_step,
                bins=512,  # pyright: ignore[reportArgumentType]
            )
            self.tb_writer.add_histogram(
                f"mean_histogram/{mode}/batch",
                self.noisy_model.noise_layer.mean_estimator.module.weight.detach(),
                global_step=global_step,
                bins=512,  # pyright: ignore[reportArgumentType]
            )
        return loss

    def _on_epoch_end(self, mode: Literal["train", "valid"]) -> None:
        """Compute epoch-level training metrics."""
        metrics: dict[str, float] = {
            "epoch": self.current_epoch,
        }
        # Fold each accumulated mean loss into the epoch metrics, then reset for the next epoch.
        for loss_name, metric in self.metrics[mode].mean_losses.items():
            metrics[loss_name] = metric.compute().item()
            metric.reset()
        if self.logger is not None and self.trainer.is_global_zero:
            self.logger.log_metrics(
                {f"{name}/{mode}/epoch": value for name, value in metrics.items()},
                step=self.current_epoch,
            )
In [ ]:
Copied!
# Allow TF32 matmuls for faster float32 training on supported GPUs.
torch.set_float32_matmul_precision("high")
logger = lightning.pytorch.loggers.TensorBoardLogger(
    "tensorflow_efficientnet",
    name="tensorflow_efficientnet_fullcoco",
)
trainer = lightning.Trainer(
    logger=logger,
    devices=[0],
    max_epochs=50,
    log_every_n_steps=30,
    fast_dev_run=True,  # only enable in CI
)
with trainer.init_module():
    training_module = StainedGlassTensorflowVisionModule(
        transform=TransformHyperparameters(),
        detection=DetectionHyperparameters(),
        loss=LossHyperparameters(),
        transform_loss=TransformLossHyperparameters(),
        optimizer=OptimizerHyperparameters(),
        metric=MetricHyperparameters(),
    )
# we must instantiate the model before we can know what size to reshape the images in the dataset to
training_module.configure_model()
data_module = HuggingFaceDataModule(
    training_module.image_size,
    dataset=DatasetHyperparameters(path="mini_coco", num_proc=None),
    train_dataloader=DataLoaderHyperparameters(batch_size=1),
    test_dataloader=DataLoaderHyperparameters(batch_size=1),
)
trainer.fit(model=training_module, datamodule=data_module)
# Allow TF32 matmuls for faster float32 training on supported GPUs.
torch.set_float32_matmul_precision("high")
logger = lightning.pytorch.loggers.TensorBoardLogger(
    "tensorflow_efficientnet",
    name="tensorflow_efficientnet_fullcoco",
)
trainer = lightning.Trainer(
    logger=logger,
    devices=[0],
    max_epochs=50,
    log_every_n_steps=30,
    fast_dev_run=True,  # only enable in CI
)
with trainer.init_module():
    training_module = StainedGlassTensorflowVisionModule(
        transform=TransformHyperparameters(),
        detection=DetectionHyperparameters(),
        loss=LossHyperparameters(),
        transform_loss=TransformLossHyperparameters(),
        optimizer=OptimizerHyperparameters(),
        metric=MetricHyperparameters(),
    )
# we must instantiate the model before we can know what size to reshape the images in the dataset to
training_module.configure_model()
data_module = HuggingFaceDataModule(
    training_module.image_size,
    dataset=DatasetHyperparameters(path="mini_coco", num_proc=None),
    train_dataloader=DataLoaderHyperparameters(batch_size=1),
    test_dataloader=DataLoaderHyperparameters(batch_size=1),
)
trainer.fit(model=training_module, datamodule=data_module)
In [ ]:
Copied!