Cancer image classification
This notebook demonstrates how to use Stained Glass Core to create a Stained Glass Transform for a pretrained base model.
Adding Stained Glass Core to a training/testing loop only requires the following changes in this order:
- Wrap the base model in a Stained Glass Model after the base model is initialized.
- Wrap the criterion/loss function in a Stained Glass Loss function after the base criterion/loss is initialized.
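The order matters because the loss wrapper needs the already-wrapped model. Later in this notebook, the loss factory is called with the NoisyModel so that it can reach the Transform's parameters. The framework-free toy below is a minimal sketch of that pattern. Every name in it (`TransformedModel`, `make_composite_loss`, the input shift, the transform loss term) is hypothetical and is not the Stained Glass Core API.

```python
from typing import Callable


class TransformedModel:
    """Hypothetical toy wrapper: transform the input, then call the base model."""

    def __init__(self, base_model: Callable[[float], float], shift: float) -> None:
        self.base_model = base_model
        self.shift = shift  # stands in for the transform's trainable parameters

    def transform(self, x: float) -> float:
        return x + self.shift

    def __call__(self, x: float) -> float:
        return self.base_model(self.transform(x))


def make_composite_loss(
    wrapped_model: TransformedModel,
    task_loss: Callable[[float, float], float],
) -> Callable[[float, float], float]:
    """Hypothetical toy loss wrapper: it reads the transform's state from the wrapped model."""

    def loss(output: float, target: float) -> float:
        # Toy transform term: a larger shift (a "stronger" transform) lowers the loss.
        transform_loss = -abs(wrapped_model.shift)
        return task_loss(output, target) + transform_loss

    return loss


# Step 1: wrap the base model. Step 2: wrap the loss, passing in the *wrapped* model.
model = TransformedModel(base_model=lambda x: 2 * x, shift=0.5)
loss_fn = make_composite_loss(model, lambda out, tgt: (out - tgt) ** 2)
output = model(1.0)  # base_model(1.0 + 0.5) == 3.0
print(output, loss_fn(output, 3.0))
```

If the loss were wrapped first, there would be no wrapped model to hand to `make_composite_loss`. That is why the model wrapping comes first.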
In addition to these training/testing changes, this notebook demonstrates how to load a Stained Glass Model from a checkpoint and how to visualize images after applying a trained Stained Glass Transform.
Dataset
This tutorial uses the Lung and Colon Cancer Histopathological Images dataset (LC25000), available from Kaggle. The dataset contains 25,000 histopathological images. All images are 768x768 pixels in size and are in JPEG format.
There are five classes in the dataset, each with 5,000 images:
- Lung benign tissue
- Lung adenocarcinoma
- Lung squamous cell carcinoma
- Colon adenocarcinoma
- Colon benign tissue
Original article: Borkowski AA, Bui MM, Thomas LB, Wilson CP, DeLand LA, Mastorides SM. Lung and Colon Cancer Histopathological Image Dataset (LC25000). arXiv:1912.12142v1 [eess.IV], 2019.
Relevant links: https://arxiv.org/abs/1912.12142v1, https://github.com/tampapath/lung_colon_image_set
Since we will be using a pretrained model from the HuggingFace Hub (introduced below), we will pre-process the dataset as described in that model's training notebook.
First, download the dataset from Kaggle and unzip it. Then follow the guide below.
from __future__ import annotations
import os
%pip -q install datasets transformers lightning matplotlib torchvision
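# Select which GPU to use (adjust or remove for your system); this must be set before torch initializes CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"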
Note: you may need to restart the kernel to use updated packages.
import copy
import datasets
import lightning
import PIL
import PIL.Image
import PIL.ImageDraw
import torch
import torch.nn.functional as F
import transformers
# Update DATASET_DIR to the location of the unzipped data downloaded from Kaggle.
DATASET_DIR = "/data_fast/lung_colon_image_set"
Load the Dataset
# Load the dataset using HuggingFace datasets library
dataset = datasets.load_dataset(
"imagefolder", data_dir=DATASET_DIR, split="train", drop_labels=False
)
Split the dataset into training and test sets
dataset = dataset.shuffle(seed=42)
train_split = dataset.train_test_split(train_size=0.80, seed=42)
ds = datasets.DatasetDict(
{"train": train_split["train"], "test": train_split["test"]}
)
Display a few samples from the dataset
def show_grid_of_examples(
    ds, seed: int = 42, examples_per_class: int = 5, size=(90, 90)
):
    """Display a grid of examples from the dataset.
    Args:
        ds: The dataset to display.
        seed: The random seed to use when selecting examples.
        examples_per_class: The number of examples to display for each class.
        size: The size to resize each example to.
    Returns:
        A PIL Image containing the grid of examples.
    """
    w, h = size
    labels = ds["train"].features["label"].names
    grid = PIL.Image.new(
        mode="RGB", size=(examples_per_class * w, len(labels) * h)
    )
    draw = PIL.ImageDraw.Draw(grid)
    print(labels)
    examples: dict[int, list] = {
        label_id: [] for label_id in range(len(labels))
    }
    # Use the seed so that the selected examples are reproducible.
    for example in ds["test"].shuffle(seed=seed):
        if not any(
            len(value) < examples_per_class for value in examples.values()
        ):
            break
        label_id = example["label"]
        if len(examples[label_id]) < examples_per_class:
            examples[label_id].append(example)
    for label_id, label in enumerate(labels):
        ds_slice = examples[label_id]
        for i, example in enumerate(ds_slice):
            image = example["image"]
            idx = examples_per_class * label_id + i
            height_idx, width_idx = divmod(idx, examples_per_class)
            box = (width_idx * w, height_idx * h)
            grid.paste(image.resize(size), box=box)
            draw.text(box, label, (0, 0, 0))
    return grid
show_grid_of_examples(ds)
['colon_aca', 'colon_n', 'lung_aca', 'lung_n', 'lung_scc']
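In `show_grid_of_examples`, `idx = examples_per_class * label_id + i`, so `divmod(idx, examples_per_class)` always returns `(label_id, i)`. Each row of the grid is one class and each column is one example. The pure-Python sketch below recomputes the paste offsets without loading any images. The helper `paste_offsets` is hypothetical and is included only for illustration.

```python
def paste_offsets(
    num_labels: int, examples_per_class: int, size: tuple[int, int] = (90, 90)
) -> dict[tuple[int, int], tuple[int, int]]:
    """Recompute the (x, y) paste offsets used by show_grid_of_examples."""
    w, h = size
    offsets = {}
    for label_id in range(num_labels):
        for i in range(examples_per_class):
            idx = examples_per_class * label_id + i
            height_idx, width_idx = divmod(idx, examples_per_class)
            offsets[(label_id, i)] = (width_idx * w, height_idx * h)
    return offsets


offsets = paste_offsets(num_labels=5, examples_per_class=5)
# Example i=3 of class 1 ("colon_n") is pasted in row 1, column 3.
print(offsets[(1, 3)])  # (270, 90)
```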
Load Pretrained Base Model and Process Dataset
We will use the DunnBC22/vit-base-patch16-224-in21k_lung_and_colon_cancer model from the HuggingFace Hub. This model is a Vision Transformer (ViT), pretrained on ImageNet-21k and fine-tuned on this dataset.
We will also use its pretrained image processor to pre-process the dataset.
PRETRAINED_MODEL_TAG = (
"DunnBC22/vit-base-patch16-224-in21k_lung_and_colon_cancer"
)
BASE_MODEL = transformers.ViTForImageClassification.from_pretrained(
PRETRAINED_MODEL_TAG
)
FEATURE_EXTRACTOR = transformers.ViTImageProcessor.from_pretrained(
PRETRAINED_MODEL_TAG
)
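The model tag encodes the input geometry: ViT-Base with 16x16 patches at 224x224 resolution. Each 768x768 image is therefore resized before it is split into patches. The sketch below works through the patch arithmetic and the per-pixel normalization. The mean and standard deviation of 0.5 are an assumption: they are the `ViTImageProcessor` defaults. Check `FEATURE_EXTRACTOR.image_mean` and `FEATURE_EXTRACTOR.image_std` for the values this checkpoint actually uses.

```python
IMAGE_SIZE = 224  # input resolution from "224" in the model tag
PATCH_SIZE = 16  # patch size from "patch16" in the model tag

patches_per_side = IMAGE_SIZE // PATCH_SIZE
num_patches = patches_per_side**2
sequence_length = num_patches + 1  # plus one [CLS] token used for classification


def normalize(value: int, mean: float = 0.5, std: float = 0.5) -> float:
    """Rescale an 8-bit pixel value to [0, 1], then normalize (assumed ViT defaults)."""
    return (value / 255 - mean) / std


print(patches_per_side, num_patches, sequence_length)  # 14 196 197
print(normalize(0), normalize(255))  # -1.0 1.0
```

Under these assumptions, the `pixel_values` produced below lie in [-1, 1] rather than [0, 255].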
def preprocess_images(sample_batch: dict) -> dict:
    """Transform a batch of samples from the dataset's dictionary of PIL Images and labels into a dictionary of tensors.
    Args:
        sample_batch: batch of images and labels loaded from disk via HuggingFace Datasets
    Returns:
        Dictionary of processed pixel values and labels.
    """
    inputs = FEATURE_EXTRACTOR(list(sample_batch["image"]), return_tensors="pt")
    inputs["labels"] = sample_batch["label"]
    return inputs
prepared_dataset = ds.map(
preprocess_images, batched=True, load_from_cache_file=True
)
def data_collator(batch: list[dict]) -> dict:
    """Collate examples from the HuggingFace dataset into a batch; to be used by a torch DataLoader.
    Args:
        batch: List of examples from the HuggingFace dataset.
    Returns:
        Dictionary with the prepared batch of pixel value and label tensors.
    """
    return {
        "pixel_values": torch.stack(
            [torch.tensor(x["pixel_values"]) for x in batch]
        ),
        "labels": torch.tensor([x["labels"] for x in batch]),
    }
Initialize Dataloaders
Adjust the DataLoader parameters to suit the system used for training.
BATCH_SIZE = 256
NUM_WORKERS = 16
train_loader = torch.utils.data.DataLoader(
prepared_dataset["train"],
batch_size=BATCH_SIZE,
collate_fn=data_collator,
num_workers=NUM_WORKERS,
shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
prepared_dataset["test"],
batch_size=BATCH_SIZE,
collate_fn=data_collator,
num_workers=NUM_WORKERS,
shuffle=False,
)
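With the 80/20 split and a batch size of 256, you can work out the number of optimizer steps per epoch. This is a useful sanity check against the checkpoint names Lightning writes later in this notebook (`epoch=44-step=3555.ckpt` and `epoch=10-step=869.ckpt`). The sketch assumes one optimizer step per training batch, which holds here because no gradient accumulation is configured.

```python
import math

NUM_IMAGES = 25_000
TRAIN_FRACTION = 0.80
BATCH_SIZE = 256

num_train = round(NUM_IMAGES * TRAIN_FRACTION)
num_test = NUM_IMAGES - num_train
# DataLoader keeps the final partial batch by default (drop_last=False), so round up.
train_batches = math.ceil(num_train / BATCH_SIZE)
test_batches = math.ceil(num_test / BATCH_SIZE)
print(num_train, num_test, train_batches, test_batches)  # 20000 5000 79 20

# A checkpoint saved at the end of 0-based epoch e sits at global step (e + 1) * train_batches.
print((44 + 1) * train_batches, (10 + 1) * train_batches)  # 3555 869
```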
Test Pretrained Base Model
We will use PyTorch Lightning to manage training and validation. To do this, we will create a BaseModelLightningModule. We will not train with this Lightning module, but defining the training functions here makes it easier to highlight the changes needed later, when training the Stained Glass Transform.
We will also test the pretrained base model using this Lightning module.
from typing import Callable
from torch import nn
class BaseModelLightningModule(lightning.LightningModule):
    """PyTorch Lightning Module used to control training and testing of a Pretrained Vision Classification model.
    Attributes:
        lr: Learning Rate to be used during training.
        model: Model to be trained.
        loss_function: Loss function to be used for training.
    """

    def __init__(
        self,
        base_model: nn.Module,
        loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        lr: float,
        *,
        model_parameters: dict | None = None,
        loss_parameters: dict | None = None,
    ):
        super().__init__()
        self.lr = lr
        base_model = copy.deepcopy(base_model).to("cpu")
        self.model = self._prep_model(base_model, **(model_parameters or {}))
        self.loss_function = self._prep_loss_function(
            loss_function, **(loss_parameters or {})
        )

    @staticmethod
    def _prep_model(base_model: nn.Module):
        """Prepare the model for training.
        This method needs to be overridden (wrap the model) for training Stained Glass Transform.
        Args:
            base_model: Base Model.
        Returns:
            Prepared model.
        """
        return base_model

    def _prep_loss_function(
        self,
        loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    ):
        """Prepare the loss function for training.
        This method needs to be overridden (wrap the loss function) for training Stained Glass Transform.
        Args:
            loss_function: Loss function measuring base model performance.
        Returns:
            Prepared loss function.
        """
        return loss_function

    def forward(self, x: torch.Tensor):
        """Run the model forward.
        Args:
            x: Input tensor
        Returns:
            Model output.
        """
        return self.model(x)

    def evaluate(
        self, logits: torch.Tensor, batch_labels: torch.Tensor, stage: str
    ) -> dict[str, torch.Tensor]:
        """Evaluate the model performance (from output logits and labels) on various metrics.
        Args:
            logits: Model outputs as a logits tensor.
            batch_labels: Labels tensor.
            stage: The stage of the training (train/val/test); used for logging.
        Returns:
            Dictionary of metrics (keyed by their respective names).
        """
        accuracy = torch.mean(
            (torch.argmax(logits, dim=1) == batch_labels).float()
        )
        return {f"{stage}_accuracy": accuracy}

    def training_step(
        self, batch: dict[str, torch.Tensor], batch_idx: int
    ) -> torch.Tensor:
        """Perform a single training step.
        Args:
            batch: Single batch of input data.
            batch_idx: Batch index.
        Returns:
            Loss value.
        """
        return self._step(batch, "train")

    def validation_step(
        self, batch: dict[str, torch.Tensor], batch_idx: int
    ) -> torch.Tensor:
        """Perform a single validation step.
        Args:
            batch: Single batch of input data.
            batch_idx: Batch index.
        Returns:
            Loss value.
        """
        return self._step(batch, "val")

    def test_step(
        self, batch: dict[str, torch.Tensor], batch_idx: int
    ) -> torch.Tensor:
        """Perform a single testing step.
        Args:
            batch: Single batch of input data.
            batch_idx: Batch index.
        Returns:
            Loss value.
        """
        return self._step(batch, "test")

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Configure the optimizer to use for training.
        Returns:
            Configured optimizer.
        """
        optimizer = torch.optim.AdamW(
            self.parameters(), lr=self.lr, weight_decay=0.0
        )
        return optimizer

    @staticmethod
    def freeze_parameters(module: nn.Module) -> None:
        """Freeze all the parameters in a given module.
        Args:
            module: Module to freeze.
        """
        for param in module.parameters():
            param.requires_grad = False

    @staticmethod
    def unfreeze_parameters(module: nn.Module) -> None:
        """Unfreeze all the parameters in a given module.
        Args:
            module: Module to unfreeze.
        """
        for param in module.parameters():
            param.requires_grad = True

    def _step(
        self, batch: dict[str, torch.Tensor], stage: str
    ) -> torch.Tensor:
        """Perform a single train/test step.
        Note:
            This assumes that training and testing steps are the same, just applied to different data.
        Args:
            batch: Single batch of input data.
            stage: The stage of the training (train/val/test); used for logging.
        Returns:
            Loss value.
        """
        pixel_values, labels = batch["pixel_values"], batch["labels"]
        # Unwrap the logits from the HuggingFace Transformers output
        outputs = self(pixel_values).logits
        loss = self.loss_function(outputs, labels)
        evaluate_results = self.evaluate(outputs, labels, stage)
        self.log_dict({f"{stage}_loss": loss, **evaluate_results})
        return loss
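The class above follows a template-method design. `__init__` always calls `_prep_model` and `_prep_loss_function`, forwarding `model_parameters` and `loss_parameters` as keyword arguments. A subclass can therefore change behaviour by overriding only those two hooks. The framework-free sketch below isolates that pattern. The class names, the input-scaling "transform" and the loss reweighting are hypothetical stand-ins, not part of Lightning or Stained Glass Core.

```python
from __future__ import annotations

from typing import Any, Callable


class ToyBaseModule:
    """Toy stand-in for BaseModelLightningModule: __init__ calls two overridable hooks."""

    def __init__(
        self,
        base_model: Callable,
        loss_function: Callable,
        *,
        model_parameters: dict[str, Any] | None = None,
        loss_parameters: dict[str, Any] | None = None,
    ) -> None:
        # Hook arguments are forwarded as keyword arguments, as in the notebook's class.
        self.model = self._prep_model(base_model, **(model_parameters or {}))
        self.loss_function = self._prep_loss_function(
            loss_function, **(loss_parameters or {})
        )

    @staticmethod
    def _prep_model(base_model: Callable) -> Callable:
        return base_model  # base case: use the model unchanged

    def _prep_loss_function(self, loss_function: Callable) -> Callable:
        return loss_function  # base case: use the loss unchanged


class ToyWrappingModule(ToyBaseModule):
    """Toy subclass: only the hooks change; everything else is inherited."""

    @staticmethod
    def _prep_model(base_model: Callable, scale: float) -> Callable:
        # Hypothetical "transform": scale the input before the base model sees it.
        return lambda x: base_model(x * scale)

    def _prep_loss_function(self, loss_function: Callable, weight: float) -> Callable:
        # Hypothetical wrapper: reweight the original loss.
        return lambda out, tgt: weight * loss_function(out, tgt)


def square_error(out: float, tgt: float) -> float:
    return (out - tgt) ** 2


plain = ToyBaseModule(lambda x: x + 1, square_error)
wrapped = ToyWrappingModule(
    lambda x: x + 1,
    square_error,
    model_parameters={"scale": 2.0},
    loss_parameters={"weight": 0.5},
)
print(plain.model(3.0), wrapped.model(3.0))  # 4.0 7.0
print(plain.loss_function(7.0, 5.0), wrapped.loss_function(7.0, 5.0))  # 4.0 2.0
```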
# This is just a standard loss function used to train the base model.
# Cross Entropy is a reasonable choice for this classification task.
LOSS_FUNCTION = F.cross_entropy
# Create the Lightning Module
base_model_lightning_module = BaseModelLightningModule(
BASE_MODEL,
loss_function=LOSS_FUNCTION,
lr=3e-4,
)
# PyTorch Lightning testing doesn't work properly without first fitting a Trainer.
# We can do this with a `fit` run that uses 0 train/val batches.
base_model_trainer = lightning.Trainer(
accelerator="auto", limit_train_batches=0, limit_val_batches=0
)
base_model_trainer.fit(
model=base_model_lightning_module,
train_dataloaders=train_loader,
val_dataloaders=test_loader,
ckpt_path=None,
)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/jennifer/.conda/envs/sgc310/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
/home/jennifer/.conda/envs/sgc310/lib/python3.10/site-packages/lightning/pytorch/loops/utilities.py:73: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.
You are using a CUDA device ('NVIDIA RTX A6000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [3]
  | Name  | Type                      | Params
----------------------------------------------------
0 | model | ViTForImageClassification | 85.8 M
----------------------------------------------------
85.8 M    Trainable params
0         Non-trainable params
85.8 M    Total params
343.210   Total estimated model params size (MB)
# Test
base_model_test_results = base_model_trainer.test(
model=base_model_lightning_module, ckpt_path=None, dataloaders=test_loader
)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [3]
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric               ┃ DataLoader 0              ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ test_accuracy             │ 1.0                       │
│ test_loss                 │ 0.00038744747871533036    │
└───────────────────────────┴───────────────────────────┘
Train Stained Glass Transform
We will again use PyTorch Lightning, this time with a StainedGlassTrainerLightningModule that subclasses BaseModelLightningModule. It overrides _prep_model to wrap our base model in a NoisyModel (which applies a Stained Glass Transform for training). It also overrides _prep_loss_function to wrap our loss function using composite_cloak_loss_factory, which allows us to train the Stained Glass Transform.
from stainedglass_core import (
loss as sg_loss,
model as sg_model,
noise_layer as sg_noise_layer,
)
class StainedGlassTrainerLightningModule(BaseModelLightningModule):
    """PyTorch Lightning Module used to control training and testing of a Pretrained Vision Classification
    model with Stained Glass Transform.
    This subclass highlights all of the differences necessary for training a Stained Glass Transform compared
    to training a base model.
    The differences are:
        1. Wrap the base model in a `NoisyModel` wrapper.
        2. Wrap the loss function using `composite_cloak_loss_factory`.
    Attributes:
        lr: Learning Rate to be used during training.
        model: Model to be trained.
        loss_function: Loss function to be used for training.
    Args:
        base_model: Model to be trained.
        loss_function: Loss function to be used for training.
        lr: Learning Rate to be used during training.
        scale: The range of possible amplitudes for the stochastic portion of Stained Glass Transform.
        percent_to_mask: The percentage of features that will be masked out during the transformation.
        alpha: Hyperparameter that controls the balance between the model loss (provided by the user) and
            the Transform loss (provided by `stainedglass_core`) during training. Value between 0 and 1 or None.
            If alpha is None, an adaptive algorithm is used during each step of training.
    """

    def __init__(
        self,
        base_model: nn.Module,
        loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        lr: float,
        scale: tuple[float, float],
        percent_to_mask: float,
        alpha: float | None,
    ):
        model_parameters = {
            "percent_to_mask": percent_to_mask,
            "scale": scale,
        }
        loss_parameters = {"alpha": alpha}
        super().__init__(
            base_model=base_model,
            loss_function=loss_function,
            lr=lr,
            model_parameters=model_parameters,
            loss_parameters=loss_parameters,
        )

    @staticmethod
    def _prep_model(
        base_model: nn.Module,
        percent_to_mask: float,
        scale: tuple[float, float],
    ) -> sg_model.NoisyModel:
        """Prepare the model for training.
        Stained Glass Core supplies a wrapper around the model that applies the transformation, and makes training easier.
        To use it, just wrap the base model with the `NoisyModel` object (passing in the parameters for the Stained Glass Transform).
        Args:
            base_model: Base Model.
            percent_to_mask: The percentage of features that will be masked out during the transformation.
            scale: The range of possible amplitudes for the stochastic portion of Stained Glass Transform.
        Returns:
            Prepared model.
        """
        # `CloakNoiseLayerOneShot` is one available version of the Stained Glass Transform.
        return sg_model.NoisyModel(
            sg_noise_layer.CloakNoiseLayerOneShot,
            base_model,
            target_parameter="pixel_values",
            percent_to_mask=percent_to_mask,
            scale=scale,
        )

    def _prep_loss_function(
        self,
        loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        alpha: float | None,
    ):
        """Prepare the loss function for training.
        Stained Glass Core also supplies a loss function that wraps around an existing loss function that optimizes
        the strength of the Transform while also maintaining model performance.
        Args:
            loss_function: Loss function measuring base model performance.
            alpha: Hyperparameter that controls the balance between the model loss (provided by the user) and
                the Transform loss (provided by `stainedglass_core`) during training. Value between 0 and 1 or None.
                If alpha is None, an adaptive algorithm is used during each step of training.
        Returns:
            Prepared loss function.
        """
        # The factory returns a tuple; its first element is the composite loss function used here.
        return sg_loss.cloak.composite_cloak_loss_factory(
            self.model, loss_function, alpha
        )[0]

    def _step(self, batch: dict[str, torch.Tensor], stage: str) -> torch.Tensor:
        """Perform a single train/test step.
        Note:
            This assumes that training and testing steps are the same, just applied to different data.
        Args:
            batch: Single batch of input data.
            stage: The stage of the training (train/val/test); used for logging.
        Returns:
            Loss value.
        """
        pixel_values, labels = batch["pixel_values"], batch["labels"]
        outputs = self(pixel_values).logits
        loss = self.loss_function(outputs, labels)
        evaluate_results = self.evaluate(outputs, labels, stage)
        self.log_dict({f"{stage}_loss": loss, **evaluate_results})
        return loss
STAINED_GLASS_NUM_EPOCHS = 50
STAINED_GLASS_LR = 6e-3
# PERCENT_TO_MASK is the percentage of features that will be masked out during the transformation.
PERCENT_TO_MASK = 0.6
# SCALE is the possible range of amplitudes of the stochastic portion of Stained Glass Transform
SCALE = (1e-3, 1.0)
# ALPHA is a hyperparameter that controls the balance between the model loss (provided by the user)
# and the Transform loss (provided by `stainedglass_core`) during training.
# Values should be between 0 and 1. Smaller values prioritize the model loss.
# This parameter is passed to the loss wrapper during training.
ALPHA = 0.1
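The comments above describe `ALPHA` as a balance between the two losses, with smaller values favouring the model loss. One natural reading of that description is a convex combination, sketched below. This formula is an assumption for illustration only. The combination actually computed by `composite_cloak_loss_factory`, including its adaptive behaviour when `alpha` is `None`, may differ.

```python
def composite_loss(model_loss: float, transform_loss: float, alpha: float) -> float:
    """Assumed convex combination of the two loss terms (illustrative, not the library's code)."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f"alpha must be between 0 and 1, got {alpha}")
    return (1.0 - alpha) * model_loss + alpha * transform_loss


# Relative weight of the model loss versus the Transform loss for a few alphas.
for alpha in (0.1, 0.5, 0.9):
    ratio = (1 - alpha) / alpha
    print(f"alpha={alpha}: model loss weighted {ratio:.1f}x the Transform loss")
```

Under this reading, `ALPHA = 0.1` gives the model loss nine times the weight of the Transform loss.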
# Initialize the Stained Glass Lightning module
stained_glass_lightning_module = StainedGlassTrainerLightningModule(
BASE_MODEL,
lr=STAINED_GLASS_LR,
scale=SCALE,
percent_to_mask=PERCENT_TO_MASK,
loss_function=LOSS_FUNCTION,
alpha=ALPHA,
)
# We can freeze the base model weights when training Stained Glass Transform
stained_glass_lightning_module.freeze_parameters(
stained_glass_lightning_module.model.base_model
)
import lightning.pytorch.callbacks
# Train the Stained Glass Transform
stained_glass_trainer = lightning.Trainer(
max_epochs=STAINED_GLASS_NUM_EPOCHS,
accelerator="auto",
callbacks=[
lightning.pytorch.callbacks.EarlyStopping(
monitor="val_loss", mode="min"
)
],
)
stained_glass_trainer.fit(
model=stained_glass_lightning_module,
train_dataloaders=train_loader,
val_dataloaders=test_loader,
)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [3]
  | Name  | Type       | Params
-------------------------------------
0 | model | NoisyModel | 86.1 M
-------------------------------------
301 K     Trainable params
85.8 M    Non-trainable params
86.1 M    Total params
344.414   Total estimated model params size (MB)
# Test Stained Glass Transform
stainedglass_test_results = stained_glass_trainer.test(dataloaders=test_loader)
/home/jennifer/.conda/envs/sgc310/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py:145: `.test(ckpt_path=None)` was called without a model. The best model of the previous `fit` call will be used. You can pass `.test(ckpt_path='best')` to use the best model or `.test(ckpt_path='last')` to use the last model. If you pass a value, this warning will be silenced. Restoring states from the checkpoint path at /home/jennifer/code/stained-glass-core/docs/sdk/examples/lightning_logs/version_27/checkpoints/epoch=44-step=3555.ckpt LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [3] Loaded model weights from the checkpoint at /home/jennifer/code/stained-glass-core/docs/sdk/examples/lightning_logs/version_27/checkpoints/epoch=44-step=3555.ckpt
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric               ┃ DataLoader 0              ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ test_accuracy             │ 0.9775999784469604        │
│ test_loss                 │ 0.1900411993265152        │
└───────────────────────────┴───────────────────────────┘
Visualize the Stained Glass Transform
from collections.abc import Sequence
import matplotlib.axes
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms.functional as TF
def show(imgs: torch.Tensor | Sequence[torch.Tensor]) -> None:
    """Show a tensor or a sequence of tensors in a grid using `matplotlib`."""
    if not isinstance(imgs, Sequence):
        imgs = [imgs]
    _, axs = plt.subplots(ncols=len(imgs), squeeze=False, dpi=160)
    for i, img in enumerate(imgs):
        img = img.detach()
        img = TF.to_pil_image(img)
        ax: matplotlib.axes.Axes = axs[0, i]
        ax.imshow(np.asarray(img))
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
import torchvision
from stainedglass_core import transform as sg_transform
assert isinstance(stained_glass_lightning_module.model, sg_model.NoisyModel)
visualization_manager = (
sg_transform.vision.TransformedImageVisualizationManager(
stained_glass_lightning_module.model,
max_examples=8,
max_color_channels=0,
)
)
output = stained_glass_lightning_module.model(**next(iter(test_loader)))
activation_images = visualization_manager.prepare_activation_images()
assert "mask" in activation_images and "masked_std" in activation_images
show(
[
activation_images["activation"],
activation_images["mean"],
activation_images["mask"],
activation_images["std"],
activation_images["masked_std"],
]
)
(Optional) Fine-tune the base model using the Stained Glass Transform
After the Stained Glass Transform is trained, we can optionally fine-tune the base model on Stained Glass Transformed inputs. Often the base model does not need to be fine-tuned, because the Stained Glass Transform is optimized so that the model performs similarly with or without it. In some use cases, however, this fine-tuning is useful, for example for transfer learning or when the performance gap is deemed too wide.
Fine-tuning uses the same training procedure, except that the base model weights are unfrozen. The Stained Glass Transform can optionally be frozen. For example, when fine-tuning a Large Language Model on transformed customer data, the Transform runs on the customer's side and must therefore remain fixed.
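The two phases differ mainly in which parameters are trainable. The Lightning model summaries printed in this notebook report rounded counts: 301 K trainable while training the Transform, 85.8 M trainable while fine-tuning the base model, and 86.1 M in total. From these, the share of the model that each phase updates works out as follows.

```python
# Rounded parameter counts from the Lightning model summaries in this notebook.
TRANSFORM_PARAMS = 301e3  # trainable while training the Transform ("301 K")
BASE_MODEL_PARAMS = 85.8e6  # ViTForImageClassification ("85.8 M")
TOTAL_PARAMS = 86.1e6  # NoisyModel = Transform + base model ("86.1 M")

transform_share = TRANSFORM_PARAMS / TOTAL_PARAMS
base_share = BASE_MODEL_PARAMS / TOTAL_PARAMS
print(f"Transform training updates {transform_share:.2%} of the parameters")  # 0.35%
print(f"Base model fine-tuning updates {base_share:.2%} of the parameters")  # 99.65%
```

Training the Transform touches well under 1% of the parameters. Fine-tuning the base model touches almost all of them, which is why it uses a smaller learning rate and fewer epochs below.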
BASE_MODEL_FINE_TUNE_NUM_EPOCHS = 15
BASE_MODEL_FINE_TUNE_LR = 3e-5
# Initialize the Base Model Finetuning Lightning Module
base_model_finetune_lightning_module = StainedGlassTrainerLightningModule(
BASE_MODEL,
lr=BASE_MODEL_FINE_TUNE_LR,
scale=SCALE,
percent_to_mask=PERCENT_TO_MASK,
loss_function=LOSS_FUNCTION,
alpha=ALPHA,
)
# Copy the state dict over from the previous model
base_model_finetune_lightning_module.load_state_dict(
stained_glass_lightning_module.state_dict()
)
# We can freeze the stained glass weights when fine-tuning the Base Model
base_model_finetune_lightning_module.freeze_parameters(
base_model_finetune_lightning_module.model.noise_layer
)
# Finetune the base model
base_model_finetune_trainer = lightning.Trainer(
max_epochs=BASE_MODEL_FINE_TUNE_NUM_EPOCHS,
accelerator="auto",
callbacks=[
lightning.pytorch.callbacks.EarlyStopping(
monitor="val_loss", mode="min", patience=3
)
],
)
base_model_finetune_trainer.fit(
model=base_model_finetune_lightning_module,
train_dataloaders=train_loader,
val_dataloaders=test_loader,
)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [3]
  | Name  | Type       | Params
-------------------------------------
0 | model | NoisyModel | 86.1 M
-------------------------------------
85.8 M    Trainable params
301 K     Non-trainable params
86.1 M    Total params
344.414   Total estimated model params size (MB)
# Test the fine-tuned base model (with Stained Glass Transform).
base_model_finetune_test_results = base_model_finetune_trainer.test(
dataloaders=test_loader
)
/home/jennifer/.conda/envs/sgc310/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py:145: `.test(ckpt_path=None)` was called without a model. The best model of the previous `fit` call will be used. You can pass `.test(ckpt_path='best')` to use the best model or `.test(ckpt_path='last')` to use the last model. If you pass a value, this warning will be silenced. Restoring states from the checkpoint path at /home/jennifer/code/stained-glass-core/docs/sdk/examples/lightning_logs/version_28/checkpoints/epoch=10-step=869.ckpt LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [3] Loaded model weights from the checkpoint at /home/jennifer/code/stained-glass-core/docs/sdk/examples/lightning_logs/version_28/checkpoints/epoch=10-step=869.ckpt
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric               ┃ DataLoader 0              ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ test_accuracy             │ 0.9901999831199646        │
│ test_loss                 │ 0.15507125854492188       │
└───────────────────────────┴───────────────────────────┘
Inference
For inference, the Stained Glass Transform is applied on the Client Side (within the client's Zone of Trust). The transformed data is then sent to the model provider (Server Side), where the base model processes it.
Model Provider Preparation
After creating a Stained Glass Transform for their model, the model provider must distribute it to the client. The model provider does not need to share any model weights other than those of the Stained Glass Transform.
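In PyTorch, a parent module's state dict prefixes each child module's keys with the child's attribute name. For a NoisyModel, we assume the Transform's weights therefore sit under `noise_layer.` and the base model's under `base_model.`, matching the attribute names used in the code that follows. The notebook extracts those submodules directly. The hypothetical helper `split_state_dict` sketched here makes the same client/server boundary explicit on a flat state dict. It uses made-up keys and plain lists in place of tensors.

```python
def split_state_dict(
    state_dict: dict,
    client_prefix: str = "noise_layer.",
    server_prefix: str = "base_model.",
) -> tuple[dict, dict]:
    """Partition a NoisyModel-style state dict into client (Transform) and server (base model) parts.

    The prefixes are assumptions based on the attribute names used in this notebook.
    Unrecognized keys raise an error so nothing is shipped to the client by accident.
    """
    client, server = {}, {}
    for key, value in state_dict.items():
        if key.startswith(client_prefix):
            client[key[len(client_prefix):]] = value
        elif key.startswith(server_prefix):
            server[key[len(server_prefix):]] = value
        else:
            raise KeyError(f"Unexpected key: {key}")
    return client, server


# Hypothetical keys standing in for `stained_glass_lightning_module.model.state_dict()`.
example_state = {
    "noise_layer.mean.weight": [0.1],
    "noise_layer.std.weight": [0.2],
    "base_model.classifier.weight": [0.3],
}
client_weights, server_weights = split_state_dict(example_state)
print(sorted(client_weights), sorted(server_weights))
```

Raising on unknown keys is a deliberate fail-closed choice: only weights known to belong to the Transform ever end up in the client bundle.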
# Recall that the NoisyModel is the `model` attribute of the LightningModule we used for training.
# Extract the Stained Glass Transform (this is delivered to the client side)
stained_glass_transform = stained_glass_lightning_module.model.noise_layer
# Extract the Base Model (retained by the Model Provider)
# In this case, we will copy the weights from the `NoisyModel`
# Note, it is insufficient to simply copy the base model module from the `NoisyModel`; there is a
# hook applied which would apply the Stained Glass Transform again to the input of the model.
base_model = copy.deepcopy(BASE_MODEL)
base_model.load_state_dict(
stained_glass_lightning_module.model.base_model.state_dict()
)
<All keys matched successfully>
Client Side
On the client side, we apply the Stained Glass Transform. The client sends the transformed data; the original data never leaves the client's Zone of Trust.
num_samples = 8
# Transform a few examples
batch = next(iter(test_loader))["pixel_values"][:num_samples]
transformed_batch = stained_glass_transform(batch)
# Let's visualize the transform. Note, the client would only send the transformed data (bottom row).
# The raw images (top row) are there for reference only.
grid = torchvision.utils.make_grid(
torch.concat([batch, transformed_batch]), nrow=num_samples, padding=10
)
show(grid)
# The client would then send the `transformed_batch` to the server
Server Side
On the server side, the model provider receives the transformed data and processes it exactly as it would have processed untransformed data.
with torch.no_grad():
    results = base_model(transformed_batch).logits
results
tensor([[ 7.8200, -1.2489, -1.6888, -1.6508, -1.6102], [-1.6874, -1.7001, -0.1491, 7.3804, -2.1323], [ 7.4841, 0.3065, -1.4107, -2.1808, -2.3197], [ 7.7923, -0.7217, -1.2400, -1.7945, -2.2788], [ 7.7678, -0.4622, -1.4432, -1.9084, -2.0953], [ 7.1447, 1.0650, -1.4395, -2.4203, -2.4184], [ 7.5682, -0.8286, -1.1861, -1.9447, -1.9555], [-1.3828, -2.5845, 5.5616, -2.4500, -1.7984]])
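The server's output is raw logits. To turn them into predictions, take the argmax of each row (the same operation used in `evaluate`). If confidence scores are wanted, also apply a softmax. The plain-Python sketch below does this for the eight rows printed above. It assumes the model's class indices follow the dataset label order printed earlier. That assumption is consistent with the base model's test accuracy of 1.0 on those labels.

```python
import math

# Class names in the order printed by show_grid_of_examples earlier in this notebook.
LABELS = ["colon_aca", "colon_n", "lung_aca", "lung_n", "lung_scc"]

# The server-side logits printed above (one row per transformed image).
logits = [
    [7.8200, -1.2489, -1.6888, -1.6508, -1.6102],
    [-1.6874, -1.7001, -0.1491, 7.3804, -2.1323],
    [7.4841, 0.3065, -1.4107, -2.1808, -2.3197],
    [7.7923, -0.7217, -1.2400, -1.7945, -2.2788],
    [7.7678, -0.4622, -1.4432, -1.9084, -2.0953],
    [7.1447, 1.0650, -1.4395, -2.4203, -2.4184],
    [7.5682, -0.8286, -1.1861, -1.9447, -1.9555],
    [-1.3828, -2.5845, 5.5616, -2.4500, -1.7984],
]


def softmax(row: list[float]) -> list[float]:
    """Convert one row of logits into probabilities."""
    peak = max(row)  # subtract the max before exponentiating for numerical stability
    exps = [math.exp(value - peak) for value in row]
    total = sum(exps)
    return [e / total for e in exps]


predictions = []
for row in logits:
    probs = softmax(row)
    best = max(range(len(row)), key=lambda i: row[i])  # argmax, as in `evaluate`
    predictions.append(LABELS[best])
    print(f"{LABELS[best]:>9}  p={probs[best]:.4f}")
```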
Comparing Test Results
We will graph the test accuracy at each stage of training. Running the model with the Stained Glass Transform gives test accuracy comparable to running it on raw data: 0.978 after Transform training and 0.990 after fine-tuning the base model, versus 1.0 on raw data. More importantly, the Stained Glass Transform enables the model to be used with sensitive data that would otherwise be completely inaccessible.
labels = [
"Base Model\nPre-training",
"Transform\nTraining",
"Base Model\nFinetuning",
]
accuracies = [
base_model_test_results[0]["test_accuracy"],
stainedglass_test_results[0]["test_accuracy"],
base_model_finetune_test_results[0]["test_accuracy"],
]
accuracies_rounded = [round(acc, 3) for acc in accuracies]
fig, ax = plt.subplots()
ax.bar(labels, accuracies)
ax.set_title("Test Accuracy During Different Training Phases")
ax.set_ylabel("Test Accuracy")
ax.set_xlabel("Training Phase")
ax.set_ylim(0, 1.1)
rects = ax.patches
for rect, acc in zip(rects, accuracies_rounded):
    height = rect.get_height()
    ax.text(
        rect.get_x() + rect.get_width() / 2,
        height,
        acc,
        ha="center",
        va="bottom",
    )
plt.show()
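To put the bars in concrete terms, the accuracies can be converted into approximate counts of misclassified test images. This assumes that the logged `test_accuracy` equals the fraction of correct predictions over all 5,000 test images, so treat the counts as approximate.

```python
NUM_TEST_IMAGES = 5_000  # 20% of the 25,000 images

# test_accuracy values reported by trainer.test() in this notebook.
accuracies = {
    "Base model, raw data": 1.0,
    "Transform training": 0.9775999784469604,
    "Base model fine-tuning": 0.9901999831199646,
}

errors = {}
for phase, acc in accuracies.items():
    errors[phase] = round((1.0 - acc) * NUM_TEST_IMAGES)
    print(f"{phase}: accuracy {acc:.3f}, about {errors[phase]} misclassified images")
```

Under that assumption, the Transform costs roughly 112 of 5,000 test images before fine-tuning and roughly 49 after.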