Cancer image classification¶
The purpose of this notebook is to demonstrate how to use Stained Glass Core to create a Stained Glass Transform for a pretrained base model using a pure distillation loss.
Adding Stained Glass Core to a training/testing loop requires only the following changes, in this order:
- Wrap the base model in a Stained Glass Model after the base model is initialized.
- Create the distillation loss function after the base criterion/loss is initialized.
In addition to those changes above for training/testing, this notebook also demonstrates how to load a Stained Glass Model from a checkpoint, and how to visualize images after applying a prepared Stained Glass Transform.
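For orientation, here is a minimal sketch of those two changes, using the same NoisyModel wrapper and distillation loss factory that appear in full later in this notebook. The name base_model here stands in for any initialized vision model.
# Sketch only: both steps are demonstrated in full in the training section below.
from stainedglass_core import model as sg_model
from stainedglass_core.loss import distillation
from stainedglass_core.noise_layer import cloak_noise

# 1. Wrap the initialized base model in a Stained Glass Model.
noisy_model = sg_model.NoisyModel(
    cloak_noise.CloakNoiseLayer1,
    base_model,
    target_parameter="pixel_values",
    scale=(1e-8, 1.0),
)

# 2. Create the distillation loss function (weights as used later in this notebook).
loss_fn, get_loss_components, get_hyperparameters = (
    distillation.mutual_information_cross_entropy_vision_distillation_factory(
        noisy_model,
        mutual_information_loss_weight=1e-1,
        distillation_cross_entropy_loss_weight=1.0,
        distillation_reverse_kld_loss_weight=1.0,
        absolute_cosine_similarity_penalty_weight=0.0,
        combine_noise_masks_for_mutual_information_sub_batches=True,
    )
)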
Dataset¶
This tutorial uses the Lung and Colon Cancer Histopathological Images dataset, available from Kaggle. The dataset contains 25,000 histopathological images across 5 classes. All images are 768x768 pixels and in JPEG format.
The five classes, each with 5,000 images, are:
- Lung benign tissue
- Lung adenocarcinoma
- Lung squamous cell carcinoma
- Colon adenocarcinoma
- Colon benign tissue
Original article: Borkowski AA, Bui MM, Thomas LB, Wilson CP, DeLand LA, Mastorides SM. Lung and Colon Cancer Histopathological Image Dataset (LC25000). arXiv:1912.12142v1 [eess.IV], 2019.
Relevant links: https://arxiv.org/abs/1912.12142v1 and https://github.com/tampapath/lung_colon_image_set
Since we will be using DunnBC22/vit-base-patch16-224-in21k_lung_and_colon_cancer, we will pre-process the dataset as described in its training notebook.
First, download the dataset from Kaggle and unzip it. Then follow the guide below.
from __future__ import annotations
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "6"
import copy
from typing import Final
import datasets
import lightning
import PIL
import PIL.Image
import PIL.ImageDraw
import torch
import torch.nn.functional as F
import transformers
torch.set_float32_matmul_precision("high")
# Update DATASET_DIR to the location of the unzipped data downloaded from kaggle.
DATASET_DIR: Final[str] = "/data/lung_colon_image_set"
# Enable or disable logging to Weights & Biases.
USE_WANDB: Final[bool] = True
Load the Dataset¶
# Load the dataset using HuggingFace datasets library
dataset = datasets.load_dataset(
"imagefolder", data_dir=DATASET_DIR, split="train", drop_labels=False
)
Split the dataset into training and test¶
dataset = dataset.shuffle(seed=42)
train_split = dataset.train_test_split(train_size=0.80, seed=42)
ds = datasets.DatasetDict(
{"train": train_split["train"], "test": train_split["test"]}
)
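As a quick sanity check, we can confirm the 80/20 split sizes. With 25,000 images total, we expect 20,000 training and 5,000 test examples.
# Verify the split sizes produced by `train_test_split`.
print(len(ds["train"]), len(ds["test"]))  # expected: 20000 5000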
Display a few samples from the dataset¶
from typing import Any
def show_grid_of_examples(
    ds: datasets.DatasetDict, examples_per_class: int = 5, size: tuple[int, int] = (90, 90)
) -> PIL.Image.Image:
"""Display a grid of examples from the dataset.
Args:
ds: The dataset to display.
examples_per_class: The number of examples to display for each class.
size: The size to resize each example to.
Returns:
A PIL Image containing the grid of examples.
"""
w, h = size
labels = ds["train"].features["label"].names
grid = PIL.Image.new(
mode="RGB", size=(examples_per_class * w, len(labels) * h)
)
draw = PIL.ImageDraw.Draw(grid)
print(labels)
examples = {label_id: [] for label_id in range(len(labels))}
for example in ds["test"]:
if not any(
len(value) < examples_per_class for value in examples.values()
):
break
label_id = example["label"]
if len(examples[label_id]) < examples_per_class:
examples[label_id].append(example)
for label_id, label in enumerate(labels):
ds_slice = examples[label_id]
for i, example in enumerate(ds_slice):
image = example["image"]
idx = examples_per_class * label_id + i
height_idx, width_idx = divmod(idx, examples_per_class)
box = (width_idx * w, height_idx * h)
grid.paste(image.resize(size), box=box)
draw.text(box, label, (0, 0, 0))
return grid
show_grid_of_examples(ds)
['colon_aca', 'colon_n', 'lung_aca', 'lung_n', 'lung_scc']
Load Pretrained Base Model and Process Dataset¶
We will use the DunnBC22/vit-base-patch16-224-in21k_lung_and_colon_cancer model from the HuggingFace Hub. This model is a Vision Transformer (ViT) fine-tuned on this dataset.
We will also use the loaded pretrained feature extractor to pre-process the dataset.
PRETRAINED_MODEL_TAG = (
"DunnBC22/vit-base-patch16-224-in21k_lung_and_colon_cancer"
)
BASE_MODEL = transformers.ViTForImageClassification.from_pretrained(
PRETRAINED_MODEL_TAG
)
FEATURE_EXTRACTOR = transformers.ViTImageProcessor.from_pretrained(
PRETRAINED_MODEL_TAG
)
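It can be helpful to peek at the processor configuration to see the input size and normalization constants it will apply. A quick sketch; the example values in the comments are typical for this checkpoint, not verified here.
# Inspect the preprocessing configuration of the loaded image processor.
print(FEATURE_EXTRACTOR.size)  # e.g. {'height': 224, 'width': 224}
print(FEATURE_EXTRACTOR.image_mean, FEATURE_EXTRACTOR.image_std)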
from typing import cast
def preprocess_images(sample_batch: dict) -> dict:
"""Transform a batch of samples from the dataset's dictionary of PIL Images and labels into a dictionary of tensors.
Args:
sample_batch: batch of images and labels loaded from disk via HuggingFace Datasets
Returns:
Processed images.
"""
inputs = FEATURE_EXTRACTOR(list(sample_batch["image"]), return_tensors="pt")
inputs["labels"] = sample_batch["label"]
return cast(dict, inputs)
prepared_dataset = ds.map(
preprocess_images, batched=True, load_from_cache_file=True
)
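We can spot-check one processed example. Assuming the standard ViT-Base/16 processor settings, each image should come out as a 3x224x224 tensor.
# Inspect a single preprocessed sample.
sample = prepared_dataset["train"][0]
print(torch.tensor(sample["pixel_values"]).shape)  # expected: torch.Size([3, 224, 224])
print(sample["labels"])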
def data_collator(batch: list[dict]) -> dict:
    """Collate the data from the HuggingFace dataset; to be used by a torch DataLoader.
    Args:
        batch: A list of samples from the HuggingFace dataset.
Returns:
Dictionary with the prepared batch of pixel value and label tensors.
"""
return {
"pixel_values": torch.stack(
[torch.tensor(x["pixel_values"]) for x in batch]
),
"labels": torch.tensor([x["labels"] for x in batch]),
}
Initialize Dataloaders¶
Dataloader parameters are configurable for the system doing training.
BATCH_SIZE: Final[int] = 64
NUM_WORKERS: Final[int] = 4
train_loader = torch.utils.data.DataLoader(
prepared_dataset["train"],
batch_size=BATCH_SIZE,
collate_fn=data_collator,
num_workers=NUM_WORKERS,
shuffle=True,
multiprocessing_context="fork",
pin_memory=True,
persistent_workers=True,
)
test_loader = torch.utils.data.DataLoader(
prepared_dataset["test"],
batch_size=BATCH_SIZE,
collate_fn=data_collator,
num_workers=NUM_WORKERS,
shuffle=False,
multiprocessing_context="fork",
pin_memory=True,
persistent_workers=True,
)
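Before training, it is worth pulling one batch through the loader to confirm the collated shapes (assuming the 3x224x224 inputs noted above).
# Peek at one collated batch from the test loader.
batch = next(iter(test_loader))
print(batch["pixel_values"].shape)  # expected: torch.Size([64, 3, 224, 224])
print(batch["labels"].shape)  # expected: torch.Size([64])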
Test Pretrained Base Model¶
We will use PyTorch Lightning to manage training and validation. To do this, we will create a BaseModelLightningModule. Although we will not train using this Lightning module, specifying the training functions here will make it easier to highlight the changes needed later when training with Stained Glass Transform.
We will also test the pretrained base model using this Lightning Module.
from torch import nn
class BaseModelLightningModule(lightning.LightningModule):
"""Pytorch Lightning Module used to control training and testing of a Pretrained Vision Classification model.
Attributes:
lr: Learning Rate to be used during training.
model: Model to be trained.
loss_function: Loss function to be used for training.
"""
def __init__(
    self,
    base_model: nn.Module,
    lr: float,
    *,
    model_parameters: dict | None = None,
    loss_parameters: dict | None = None,
) -> None:
super().__init__()
self.lr = lr
base_model = copy.deepcopy(base_model).to("cpu")
self.model = self._prep_model(base_model, **(model_parameters or {}))
self.loss_function = self._prep_loss_function(**(loss_parameters or {}))
@staticmethod
def _prep_model(base_model: nn.Module) -> nn.Module:
"""Prepare the model for training.
This method needs to be overridden (wrap the model) for training Stained Glass Transform.
Args:
base_model: Base Model.
Returns:
Prepared model.
"""
return base_model
def _prep_loss_function(
self,
):
"""Prepare the loss function for training.
This method needs to be overridden (wrap the loss function) for training Stained Glass Transform.
Returns:
Prepared loss function.
"""
# This is just a standard loss function used to train the base model.
# Cross Entropy is a reasonable choice for this classification task.
return F.cross_entropy
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Run the model forward.
Args:
x: Input tensor
Returns:
Model output.
"""
return self.model(x)
def evaluate(
self, logits: torch.Tensor, batch_labels: torch.Tensor, stage: str
) -> dict[str, torch.Tensor]:
"""Evaluate the model performance (from output logits and labels) on various metrics.
Args:
logits: Model outputs as a logits tensor.
batch_labels: Labels tensor.
stage: The stage of the training (train/val/test); used for logging.
Returns:
Dictionary of metrics (keyed by their respective names).
"""
accuracy = torch.mean(
(torch.argmax(logits, dim=1) == batch_labels).float()
)
return {f"{stage}_accuracy": accuracy}
def training_step(
self, batch: dict[str, torch.Tensor], batch_idx: int
) -> torch.Tensor:
"""Perform a single training step.
Args:
batch: Single batch of input data.
batch_idx: Batch index.
Returns:
Loss value.
"""
return self._step(batch, "train", batch_idx=batch_idx)
def validation_step(
self, batch: dict[str, torch.Tensor], batch_idx: int
) -> torch.Tensor:
"""Perform a single validation step.
Args:
batch: Single batch of input data.
batch_idx: Batch index.
Returns:
Loss value.
"""
return self._step(batch, "val")
def test_step(
self, batch: dict[str, torch.Tensor], batch_idx: int
) -> torch.Tensor:
"""Perform a single testing step.
Args:
batch: Single batch of input data.
batch_idx: Batch index.
Returns:
Loss value.
"""
return self._step(batch, "test")
def configure_optimizers(self) -> torch.optim.Optimizer:
"""Configure the optimizer to use for training.
Returns:
Configured optimizer.
"""
optimizer = torch.optim.AdamW(
self.parameters(),
lr=self.lr,
weight_decay=0.0,
fused=True,
betas=(0.0, 0.95),
)
return optimizer
@staticmethod
def freeze_parameters(module: nn.Module) -> None:
"""Freeze all the parameters in a given module.
Args:
module: Module to freeze.
"""
for param in module.parameters():
param.requires_grad = False
@staticmethod
def unfreeze_parameters(module: nn.Module) -> None:
"""Unfreeze all the parameters in a given module.
Args:
module: Module to unfreeze.
"""
for param in module.parameters():
param.requires_grad = True
def _step(
self,
batch: dict[str, torch.Tensor],
stage: str,
batch_idx: int | None = None,
) -> torch.Tensor:
"""Perform a single train/test step.
Note:
This assumes that training and testing steps are the same, just applied to different data.
Args:
batch: Single batch of input data.
stage: The stage of the training (train/val/test); used for logging.
batch_idx: The integer batch id.
Returns:
Loss value.
"""
pixel_values, labels = batch["pixel_values"], batch["labels"]
# Unwrap the logits from the HuggingFace Transformers output
outputs = self(pixel_values).logits
loss = self.loss_function(outputs, labels)
evaluate_results = self.evaluate(outputs, labels, stage)
self.log_dict({f"{stage}_loss": loss, **evaluate_results})
return loss
# Create the Lightning Module
base_model_lightning_module = BaseModelLightningModule(
BASE_MODEL,
lr=3e-4,
)
# PyTorch Lightning testing doesn't work properly without first fitting a Trainer.
# We can do this with a `fit` run with 0 train/val batches.
base_model_trainer = lightning.Trainer(
accelerator="auto", limit_train_batches=0, limit_val_batches=0
)
base_model_trainer.fit(
model=base_model_lightning_module,
train_dataloaders=train_loader,
val_dataloaders=test_loader,
ckpt_path=None,
)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/kyle/.conda/envs/sgc/lib/python3.10/site-packages/lightning/pytorch/loops/utilities.py:73: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [6]

  | Name  | Type                      | Params | Mode
-----------------------------------------------------------
0 | model | ViTForImageClassification | 85.8 M | eval
-----------------------------------------------------------
85.8 M    Trainable params
0         Non-trainable params
85.8 M    Total params
343.210   Total estimated model params size (MB)
0         Modules in train mode
214       Modules in eval mode
# Test
base_model_test_results = base_model_trainer.test(
model=base_model_lightning_module,
dataloaders=test_loader,
)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [6]
Testing DataLoader 0: 100%|██████████| 79/79 [01:10<00:00, 1.12it/s]
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric        ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy       │            1.0            │
│         test_loss         │  0.0003874842659570277    │
└───────────────────────────┴───────────────────────────┘
Train Stained Glass Transform¶
We will use PyTorch Lightning to manage training and validation. To do this, we will create a StainedGlassTrainerLightningModule. In its __init__, we will wrap our base model in a NoisyModel (which applies a Stained Glass Transform during training) and wrap our loss function using hook_loss_wrapper so that we can train the Stained Glass Transform.
import torchvision.transforms.functional as TF
import wandb
from stainedglass_core import model as sg_model, transform as sg_transform
from stainedglass_core.loss import distillation
from stainedglass_core.noise_layer import base, cloak_noise
class StainedGlassTrainerLightningModule(BaseModelLightningModule):
"""Pytorch Lightning Module used to control training and testing of a Pretrained Vision Classification
model with Stained Glass Transform.
This subclass highlights all of the differences necessary for training a Stained Glass Transform compared
to training a base model.
The differences are:
1. Wrap the base model in a `NoisyModel` wrapper.
2. Wrap the loss function using `hook_loss_wrapper`.
Attributes:
lr: Learning Rate to be used during training.
model: Model to be trained.
loss_function: Loss function to be used for training.
Args:
base_model: Model to be trained.
loss_function: Loss function to be used for training.
lr: Learning Rate to be used during training.
scale: The range of possible amplitudes for the stochastic portion of Stained Glass Transform.
percent_to_mask: The percentage of features that will be masked out during the transformation.
"""
def __init__(
self,
base_model: nn.Module,
noise_layer_type: type[
cloak_noise.CloakNoiseLayer1 | cloak_noise.CloakNoiseLayer2
],
lr: float,
scale: tuple[float, float],
percent_to_mask: float | None = None,
value_range: tuple[float | None, float | None] | None = None,
loss_parameters: dict | None = None,
cloak_1_state_dict: dict | None = None,
):
model_parameters: dict[str, Any] = {
"scale": scale,
"noise_layer_type": noise_layer_type,
}
if percent_to_mask is not None:
model_parameters["percent_to_mask"] = percent_to_mask
if value_range is not None:
model_parameters["value_range"] = value_range
super().__init__(
base_model=base_model,
lr=lr,
model_parameters=model_parameters,
loss_parameters=loss_parameters,
)
self.visualization_manager = (
sg_transform.vision.TransformedImageVisualizationManager(
self.model,
max_examples=8,
max_color_channels=0,
)
)
self.cloak_1_state_dict = cloak_1_state_dict
@staticmethod
def _prep_model(
base_model: nn.Module,
scale: tuple[float, float],
noise_layer_type: type[base.BaseNoiseLayer],
percent_to_mask: float | None = None,
value_range: tuple[float | None, float | None] | None = None,
) -> sg_model.NoisyModel:
"""Prepare the model for training.
Stained Glass Core supplies a wrapper around the model that applies the transformation, and makes training easier.
To use it, just wrap the base model with the `NoisyModel` object (passing in the parameters for the Stained Glass Transform).
Args:
base_model: Base Model.
scale: The range of possible amplitudes for the stochastic portion of Stained Glass Transform.
noise_layer_type: The type of noise layer to use.
percent_to_mask: The percentage of features that will be masked out during the transformation.
value_range: The range of values to restrict the image input features to.
Returns:
Prepared model.
"""
# Only forward the optional arguments that the chosen noise layer type accepts.
kwargs: dict[str, Any] = {}
if percent_to_mask is not None:
kwargs["percent_to_mask"] = percent_to_mask
if value_range is not None:
kwargs["value_range"] = value_range
noise_model = sg_model.NoisyModel( # pyright: ignore[reportCallIssue]
noise_layer_type,
base_model,
target_parameter="pixel_values",
scale=scale, # pyright: ignore[reportCallIssue]
**kwargs,
)
return noise_model
def on_train_start(self) -> None:
    """Optionally warm-start the noise layer from a previously trained `CloakNoiseLayer1` state dict."""
if self.cloak_1_state_dict is not None:
assert isinstance(self.model.noise_layer, nn.Module)
self.model.noise_layer.load_state_dict(self.cloak_1_state_dict)
self.model.to(self.device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Run the model forward.
Args:
x: Input tensor
Returns:
Model output.
"""
assert isinstance(self.model, sg_model.NoisyModel)
with self.model.distillation_context():
return self.model(x)
def _prep_loss_function(
    self,
    mutual_information_loss_weight: float = 1e-1,
    distillation_cross_entropy_loss_weight: float = 1.0,
    distillation_reverse_kld_loss_weight: float = 1.0,
    absolute_cosine_similarity_penalty_weight: float = 0.0,
    combine_noise_masks_for_mutual_information_sub_batches: bool = True,
):
"""Prepare the distillation loss function for training.
Returns:
Prepared loss function.
"""
(
distillation_loss,
self.get_loss_components,
self.get_hyperparameters,
) = distillation.mutual_information_cross_entropy_vision_distillation_factory(
self.model, # pyright: ignore[reportArgumentType]
mutual_information_loss_weight=mutual_information_loss_weight
if not isinstance(
self.model.noise_layer, cloak_noise.CloakNoiseLayer2
)
else 0.0,
distillation_cross_entropy_loss_weight=distillation_cross_entropy_loss_weight,
distillation_reverse_kld_loss_weight=distillation_reverse_kld_loss_weight,
absolute_cosine_similarity_penalty_weight=absolute_cosine_similarity_penalty_weight,
combine_noise_masks_for_mutual_information_sub_batches=combine_noise_masks_for_mutual_information_sub_batches,
)
self.get_transformed_output = (
self.model.noise_layer.get_transformed_output_factory()
)
return distillation_loss
def _step(
    self, batch: dict[str, torch.Tensor], stage: str, batch_idx: int | None = None
) -> torch.Tensor:
"""Perform a single train/test step.
Note:
This assumes that training and testing steps are the same, just applied to different data.
Args:
batch: Single batch of input data.
stage: The stage of the training (train/val/test); used for logging.
batch_idx: The integer batch id.
Returns:
Loss value.
"""
pixel_values, labels = batch["pixel_values"], batch["labels"]
outputs = self(pixel_values).logits
loss = self.loss_function(outputs, labels)
evaluate_results = self.evaluate(outputs, labels, stage)
self.log_dict({f"{stage}_loss": loss, **evaluate_results})
if batch_idx is not None and batch_idx % 50 == 0:
with torch.no_grad():
transformed_output = (
self.get_transformed_output()[0, ...]
.squeeze(0)
.detach()
.clone()
)
transformed_output_bucket = {"activation": transformed_output}
self.visualization_manager._clamp_data_and_activations(
    transformed_output_bucket
)
assert isinstance(transformed_output, torch.Tensor)
if USE_WANDB:
wandb.log(
{
"Transformed Output": wandb.Image(
TF.to_pil_image(
transformed_output_bucket["activation"],
mode="RGB",
),
caption="Transformed Output",
)
}
)
return loss
STAINED_GLASS_NUM_EPOCHS = 10
STAINED_GLASS_LR = 6e-3
# PERCENT_TO_MASK is the percentage of features that will be masked out during the transformation.
PERCENT_TO_MASK = 0.5
# SCALE is the possible range of amplitudes of the stochastic portion of Stained Glass Transform
SCALE = (1e-8, 1.0)
# Initialize the Stained Glass Lightning module
import lightning.pytorch.loggers
import wandb.sdk
import stainedglass_core
# Optionally use Weights and Biases Logging
wandb_logger: lightning.pytorch.loggers.WandbLogger | None = None
if USE_WANDB:
    wandb.sdk.login()
    wandb_logger = lightning.pytorch.loggers.WandbLogger(
        project="debug",
        save_dir="saved/",
        log_model="all",
        sync_tensorboard=True,
        tags=[
            f"stainedglass_core=={stainedglass_core.__version__}",
        ],
        job_type="train",
    )
tb_logger = lightning.pytorch.loggers.TensorBoardLogger(
"saved/",
name="debug",
default_hp_metric=False,
version=f"stainedglass_core=={stainedglass_core.__version__}",
)
stained_glass_lightning_module = StainedGlassTrainerLightningModule(
BASE_MODEL,
lr=STAINED_GLASS_LR,
scale=SCALE,
noise_layer_type=cloak_noise.CloakNoiseLayer1,
value_range=(0.0, 1.0),
)
# We can freeze the base model weights when training Stained Glass Transform
stained_glass_lightning_module.freeze_parameters(
stained_glass_lightning_module.model.base_model
)
wandb: Currently logged in as: kyle-protopia-ai (core) to https://protopia.wandb.io. Use `wandb login --relogin` to force relogin
import lightning.pytorch.callbacks
# Train the Stained Glass Transform
stained_glass_trainer = lightning.Trainer(
max_epochs=STAINED_GLASS_NUM_EPOCHS,
accelerator="auto",
num_sanity_val_steps=0,
callbacks=[
lightning.pytorch.callbacks.EarlyStopping(
monitor="val_loss", mode="min"
)
],
logger=[
tb_logger,
*((wandb_logger,) if wandb_logger is not None else ()),
],
precision="32-true",
)
stained_glass_trainer.fit(
model=stained_glass_lightning_module,
train_dataloaders=train_loader,
val_dataloaders=test_loader,
)
if wandb_logger is not None:
wandb_logger.finalize("success")
wandb.finish()
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
saved/wandb/run-20250829_043904-58sl7r96
/home/kyle/.conda/envs/sgc/lib/python3.10/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:701: Checkpoint directory saved/debug/stainedglass_core==2.5.0/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [6]

  | Name  | Type       | Params | Mode
---------------------------------------------
0 | model | NoisyModel | 85.8 M | train
---------------------------------------------
0         Trainable params
85.8 M    Non-trainable params
85.8 M    Total params
343.210   Total estimated model params size (MB)
9         Modules in train mode
214       Modules in eval mode
Epoch 9: 100%|██████████| 313/313 [06:27<00:00, 0.81it/s, v_num=7r96]
`Trainer.fit` stopped: `max_epochs=10` reached.
Epoch 9: 100%|██████████| 313/313 [06:30<00:00, 0.80it/s, v_num=7r96]
Run history:
| epoch | ▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▆▆▆▆▇▇██ |
| global_step | ▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▃▃▃▃▄▄▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇██ |
| train_accuracy | ▆▆▅▂▁▄▅▆▆▇▇▇█▇█▆▇████▇▇██████▇████▇█████ |
| train_loss | █▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
| val_accuracy | ▁▁██████████████████ |
| val_loss | ██▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
Run summary:
| epoch | 9 |
| global_step | 3129 |
| train_accuracy | 0.96875 |
| train_loss | 0.38249 |
| val_accuracy | 0.97 |
| val_loss | 0.53208 |
View project at: https://protopia.wandb.io/core/debug
Synced 6 W&B file(s), 70 media file(s), 23 artifact file(s) and 1 other file(s)
saved/wandb/run-20250829_043904-58sl7r96/logs
wandb_logger: lightning.pytorch.loggers.WandbLogger | None = None
if USE_WANDB:
    wandb.sdk.login()
    wandb_logger = lightning.pytorch.loggers.WandbLogger(
        project="debug",
        save_dir="saved/",
        log_model="all",
        sync_tensorboard=True,
        tags=[
            f"stainedglass_core=={stainedglass_core.__version__}",
        ],
        job_type="train",
    )
tb_logger = lightning.pytorch.loggers.TensorBoardLogger(
"saved/",
name="debug",
default_hp_metric=False,
version=f"stainedglass_core=={stainedglass_core.__version__}",
)
stained_glass_lightning_module_2 = StainedGlassTrainerLightningModule(
BASE_MODEL,
lr=STAINED_GLASS_LR,
scale=SCALE,
noise_layer_type=cloak_noise.CloakNoiseLayer2,
value_range=(0.0, 1.0),
percent_to_mask=PERCENT_TO_MASK,
cloak_1_state_dict=stained_glass_lightning_module.model.noise_layer.state_dict(),
)
# We can freeze the base model weights when training Stained Glass Transform
stained_glass_lightning_module_2.freeze_parameters(
stained_glass_lightning_module_2.model.base_model
)
stained_glass_lightning_module_2.freeze_parameters(
stained_glass_lightning_module_2.model.noise_layer.std_estimator
)
stained_glass_trainer_2 = lightning.Trainer(
max_epochs=STAINED_GLASS_NUM_EPOCHS,
accelerator="auto",
num_sanity_val_steps=0,
callbacks=[
lightning.pytorch.callbacks.EarlyStopping(
monitor="val_loss", mode="min"
)
],
logger=[
tb_logger,
*((wandb_logger,) if wandb_logger is not None else ()),
],
precision="32-true",
)
stained_glass_trainer_2.fit(
model=stained_glass_lightning_module_2,
train_dataloaders=train_loader,
val_dataloaders=test_loader,
)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
saved/wandb/run-20250829_054449-tzz4513n
/home/kyle/.conda/envs/sgc/lib/python3.10/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:701: Checkpoint directory saved/debug/stainedglass_core==2.5.0/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [6]

  | Name  | Type       | Params | Mode
---------------------------------------------
0 | model | NoisyModel | 85.8 M | train
---------------------------------------------
0         Trainable params
85.8 M    Non-trainable params
85.8 M    Total params
343.210   Total estimated model params size (MB)
10        Modules in train mode
214       Modules in eval mode
Epoch 0: 0%| | 0/313 [00:00<?, ?it/s]
/home/kyle/.conda/envs/sgc/lib/python3.10/site-packages/torch/nn/modules/module.py:2575: UserWarning: percent_to_mask not found in CloakNoiseLayer2 loaded state_dict
  module._load_from_state_dict(
Epoch 9: 100%|██████████| 313/313 [06:23<00:00, 0.82it/s, v_num=513n]
`Trainer.fit` stopped: `max_epochs=10` reached.
Epoch 9: 100%|██████████| 313/313 [06:26<00:00, 0.81it/s, v_num=513n]
# Test Stained Glass Transform
stainedglass_test_results = stained_glass_trainer_2.test(
dataloaders=test_loader
)
/home/kyle/.conda/envs/sgc/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py:149: `.test(ckpt_path=None)` was called without a model. The best model of the previous `fit` call will be used. You can pass `.test(ckpt_path='best')` to use the best model or `.test(ckpt_path='last')` to use the last model. If you pass a value, this warning will be silenced.
Restoring states from the checkpoint path at saved/debug/stainedglass_core==2.5.0/checkpoints/epoch=9-step=3130-v1.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [6]
Loaded model weights from the checkpoint at saved/debug/stainedglass_core==2.5.0/checkpoints/epoch=9-step=3130-v1.ckpt
Testing DataLoader 0: 100%|██████████| 79/79 [01:09<00:00, 1.13it/s]
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric        ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy       │    0.9549999833106995     │
│         test_loss         │    0.5747208595275879     │
└───────────────────────────┴───────────────────────────┘
Visualize the Stained Glass Transform¶
from collections.abc import Sequence
import matplotlib.axes
import matplotlib.pyplot as plt
import numpy as np
import torch
def show(imgs: torch.Tensor | Sequence[torch.Tensor]) -> None:
"""Show a tensor or a sequence of tensors in a grid using `matplotlib`."""
if not isinstance(imgs, Sequence):
imgs = [imgs]
_, axs = plt.subplots(ncols=len(imgs), squeeze=False, dpi=160)
for i, img in enumerate(imgs):
img = img.detach()
img = TF.to_pil_image(img)
ax: matplotlib.axes.Axes = axs[0, i]
ax.imshow(np.asarray(img))
ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
import torchvision
assert isinstance(stained_glass_lightning_module_2.model, sg_model.NoisyModel)
visualization_manager = (
sg_transform.vision.TransformedImageVisualizationManager(
stained_glass_lightning_module_2.model,
max_examples=8,
max_color_channels=0,
)
)
output = stained_glass_lightning_module_2.model(**next(iter(test_loader)))
activation_images = visualization_manager.prepare_activation_images()
assert "mask" in activation_images and "masked_std" in activation_images
show(
[
activation_images["activation"],
activation_images["mean"],
activation_images["mask"],
activation_images["std"],
activation_images["masked_std"],
]
)
Inference¶
For inference, the Stained Glass Transform is applied on the Client Side (within their Zone of Trust). The transformed data is then sent to the model provider (Server Side), where the rest of the model uses that transformed data.
Model Provider Preparation¶
The model provider, after creating a Stained Glass Transform for their model, must distribute it to the client. No model weights other than those necessary for the Stained Glass Transform need to be shared.
# Recall that the NoisyModel is the `model` attribute of the LightningModule we used for training.
# Extract the Stained Glass Transform (this is delivered to the client side)
stained_glass_transform = stained_glass_lightning_module_2.model.noise_layer
# Extract the Base Model (retained by the Model Provider)
# In this case, we will copy the weights from the `NoisyModel`
# Note, it is insufficient to simply copy the base model module from the `NoisyModel`; there is a
# hook applied which would apply the Stained Glass Transform again to the input of the model.
base_model = copy.deepcopy(BASE_MODEL)
base_model.load_state_dict(
stained_glass_lightning_module_2.model.base_model.state_dict()
)
<All keys matched successfully>
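In practice the extracted transform would be serialized and shipped to the client. A minimal sketch, assuming the noise layer's weights round-trip through standard torch serialization; the file name is a placeholder.
# Serialize the Stained Glass Transform weights for delivery to the client.
torch.save(stained_glass_transform.state_dict(), "stained_glass_transform.pt")
# On the client, a noise layer constructed with the same configuration would
# load these weights via `load_state_dict` before transforming any data.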
Client Side¶
On the client side, we apply the Stained Glass Transform. The client sends the transformed data; the original data never leaves the client's Zone of Trust.
num_samples = 8
# Transform a few examples
batch = next(iter(test_loader))["pixel_values"][:num_samples]
transformed_batch = stained_glass_transform(batch)
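# Sanity check (added sketch): the transform should preserve the tensor shape;
# only the pixel content is obfuscated.
print(batch.shape, transformed_batch.shape)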
# Let's visualize the transform. Note, the client would only send the transformed data (bottom row).
# The raw images (top row) are there for reference only.
grid = torchvision.utils.make_grid(
torch.concat([batch, transformed_batch]),
nrow=num_samples,
padding=10,
)
show(grid)
# The client would then send the `transformed_batch` to the server
Server Side¶
On the server side, the model provider receives the transformed data and processes it exactly as it would have processed untransformed data.
with torch.no_grad():
results = base_model(transformed_batch).logits
results
tensor([[ 7.8287, -1.5252, -1.4735, -1.6478, -1.6110],
[-1.8735, -2.2622, -0.1534, 7.2228, -1.6506],
[ 7.6522, -1.0738, -1.2559, -2.0188, -1.6745],
[ 7.6966, -1.1802, -1.2043, -1.8493, -1.8165],
[ 6.9394, 0.5846, -1.2439, -2.2039, -2.2797],
[ 7.6448, -0.6858, -1.5229, -1.9463, -1.7846],
[ 7.3873, -0.6193, -0.9628, -2.1339, -2.0693],
[-1.3240, -2.6290, 4.8224, -1.3932, -1.5964]])
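To interpret these logits, we can map the argmax predictions back to class names using the model's own label mapping. A quick sketch; id2label is the standard HuggingFace config attribute.
# Decode predicted class indices into human-readable labels.
predicted_ids = results.argmax(dim=1).tolist()
print([base_model.config.id2label[i] for i in predicted_ids])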
Comparing Test Results¶
We will graph the test accuracy at each stage of training. Running the model with the Stained Glass Transform is comparable to running it on raw data. The biggest benefit, however, is that the Stained Glass Transform enables the model to be used with sensitive data that would otherwise have been completely inaccessible.
labels = [
"Base Model\nPre-training",
"Transform\nTraining",
]
accuracies = [
base_model_test_results[0]["test_accuracy"],
stainedglass_test_results[0]["test_accuracy"],
]
accuracies_rounded = [round(acc, 3) for acc in accuracies]
fig, ax = plt.subplots()
ax.bar(labels, accuracies)
ax.set_title("Test Accuracy During Different Training Phases")
ax.set_ylabel("Test Accuracy")
ax.set_xlabel("Training Phase")
ax.set_ylim(0, 1.1)
rects = ax.patches
for rect, acc in zip(rects, accuracies_rounded):
height = rect.get_height()
ax.text(
rect.get_x() + rect.get_width() / 2,
height,
acc,
ha="center",
va="bottom",
)
plt.show()