Creating Stained Glass for Image Classification
The purpose of this notebook is to demonstrate how to use Stained Glass Engine to train a PyTorch model. This notebook is adapted from the official PyTorch CIFAR-10 tutorial, which can be found here.
Adding Stained Glass Engine to a training/testing loop only requires a few changes:
Description | Sample Base Code | Sample Code with Stained Glass Core |
---|---|---|
Wrap a PyTorch model with a Stained Glass Transform Model (including extra hyperparameters) | `model = base_model` | `noisy_model = NoisyModel(CloakNoiseLayerOneShot, base_model, scale=(0.001, 1.0), percent_to_mask=0.6)` |
Wrap a loss function with the Stained Glass Transform Loss (including extra hyperparameters) | `loss_func = base_loss` | `loss_func, get_components, _ = composite_cloak_loss_factory(noisy_model, base_loss, alpha=0.4)` |
Call the NoisyModel instead of the base model | `output = base_model(input)` | `output = noisy_model(input)` |
Call the wrapped loss function instead of the base loss function | `loss = base_loss(output, target)` | `loss = loss_func(output, target)` |
See Stained Glass Core in fifteen minutes for more information on these changes.
In addition to the training/testing changes above, this notebook also demonstrates how to save a Stained Glass Model checkpoint and how to visualize images after applying a trained Stained Glass Transform.
Noting changes to the original notebook
To emphasize the minimal changes needed, every cell in this notebook that has been modified or added is wrapped with this separator:
%%script true
# This is an original cell
model = Net()
###############################################
######### BEGIN STAINED GLASS CHANGES #########
###############################################
# This is an added/modified cell
model = stainedglass_core.model.NoisyModel(
stainedglass_core.noise_layer.CloakNoiseLayer1,
model,
target_parameter="input",
)
###############################################
########## END STAINED GLASS CHANGES ##########
###############################################
# This is an original cell
for i in dataloader:
    model(i)
During this tutorial, we will:
- Prepare and visualize the data
- Define and train the base model (without Stained Glass Engine)
- Wrap the base model with a Stained Glass Transform and train it
- Apply the Stained Glass Transform to the raw data
- Test the model performance on both the raw data and the transformed data.
%pip install -q tqdm matplotlib
%matplotlib inline
Training a Classifier
CIFAR-10 Dataset
We will use the CIFAR-10 dataset (loaded via torchvision) to train a classifier.
It has the classes: ‘airplane’, ‘automobile’, ‘bird’, ‘cat’, ‘deer’, ‘dog’, ‘frog’, ‘horse’, ‘ship’, ‘truck’. The images in CIFAR-10 are of size 3x32x32, i.e. 3-channel color images of 32x32 pixels in size.

Training an image classifier
We will do the following steps in order:
- Load and normalize the CIFAR-10 training and test datasets using torchvision
- Define a Convolutional Neural Network
- Define a loss function
- Train the network on the training data
- Test the network on the test data
Load and normalize CIFAR-10
Using torchvision, it’s extremely easy to load CIFAR-10.
import torch
import torchvision
import torchvision.transforms as transforms
BATCH_SIZE = 1024
NUM_WORKERS = 0
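The torchvision datasets return PIL images. ToTensor converts them to tensors with values in [0, 1], and Normalize (mean 0.5 and std 0.5 per channel) rescales them to the range [-1, 1]. The training transform additionally applies random crops, flips, rotations, affine shears, and color jitter for data augmentation.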
TRAIN_DATA_PATH = "./data"
TEST_DATA_PATH = "./data"
transform_test = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
)
transform_train = transforms.Compose(
[
transforms.RandomResizedCrop(32),
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(15),
transforms.RandomAffine(0, shear=10, scale=(0.8, 1.2)),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
)
trainset = torchvision.datasets.CIFAR10(
root=TRAIN_DATA_PATH, train=True, download=True, transform=transform_train
)
trainloader = torch.utils.data.DataLoader(
trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS
)
testset = torchvision.datasets.CIFAR10(
root=TEST_DATA_PATH, train=False, download=True, transform=transform_test
)
testloader = torch.utils.data.DataLoader(
testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
)
classes = (
"plane",
"car",
"bird",
"cat",
"deer",
"dog",
"frog",
"horse",
"ship",
"truck",
)
Files already downloaded and verified Files already downloaded and verified
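As a quick sanity check on the normalization arithmetic, here is a plain-Python sketch (no torch; the helper names `normalize` and `unnormalize` are illustrative, not torchvision APIs) of what `Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))` does to a single pixel after `ToTensor`, and how the `img / 2 + 0.5` step in `imshow` undoes it.

```python
import math

# Per-channel mean and std, matching transforms.Normalize in the cell above.
MEAN = (0.5, 0.5, 0.5)
STD = (0.5, 0.5, 0.5)


def normalize(pixel):
    """Apply (x - mean) / std channel by channel to an RGB pixel in [0, 1]."""
    return tuple((x - m) / s for x, m, s in zip(pixel, MEAN, STD))


def unnormalize(pixel):
    """Invert the normalization the same way imshow does: x / 2 + 0.5."""
    return tuple(x / 2 + 0.5 for x in pixel)


# ToTensor scales 8-bit values 0..255 to floats in [0, 1].
black = tuple(v / 255 for v in (0, 0, 0))
white = tuple(v / 255 for v in (255, 255, 255))
print(normalize(black))  # (-1.0, -1.0, -1.0)
print(normalize(white))  # (1.0, 1.0, 1.0)

# Round trip: unnormalize(normalize(p)) recovers p up to float rounding.
pixel = (0.2, 0.4, 0.8)
restored = unnormalize(normalize(pixel))
print(all(math.isclose(a, b) for a, b in zip(restored, pixel)))  # True
```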
Let us show some of the test images, for fun.
NUM_IMAGES_PER_ROW = 8
import matplotlib.pyplot as plt
import numpy as np
import torchvision.utils
def imshow(img: torch.Tensor) -> None:
    """Display a normalized image tensor using matplotlib.
    Args:
        img: Image tensor of shape (C, H, W), normalized to [-1, 1].
    """
    img = img / 2 + 0.5  # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()
# get the first batch of test images (the test loader is not shuffled)
dataiter = iter(testloader)
images, labels = next(dataiter)
# show images
imshow(
torchvision.utils.make_grid(
images[:NUM_IMAGES_PER_ROW], nrow=NUM_IMAGES_PER_ROW
)
)
# print labels
print(" ".join(f"{classes[labels[j]]:5s}" for j in range(NUM_IMAGES_PER_ROW)))
cat ship ship plane frog frog car frog
Define a Convolutional Neural Network
Define a convolutional neural network that takes 3-channel images. This network differs from the one in the official PyTorch tutorial: it stacks 5x5, 3x3, and 1x1 convolutions with max pooling and ends in a single linear classifier.
import torch.nn as nn
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class Net(nn.Module):
    """Simple Convolutional network for CIFAR-10 classification."""
    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 96, 5),
            nn.ReLU(),
            nn.Conv2d(96, 96, 1),
            nn.ReLU(),
            nn.MaxPool2d(3, 2),
            nn.Conv2d(96, 192, 5),
            nn.ReLU(),
            nn.Conv2d(192, 192, 1),
            nn.ReLU(),
            nn.MaxPool2d(3, 2),
            nn.Conv2d(192, 192, 3),
            nn.ReLU(),
            nn.Conv2d(192, 192, 1),
            nn.ReLU(),
            nn.Conv2d(192, 10, 1),
            nn.ReLU(),
        )
        # For a 3x32x32 input, the backbone outputs a 10x2x2 feature map,
        # which flattens to 40 features.
        self.classifier = nn.Linear(40, 10)
    def forward(self, x):
        x = self.backbone(x)
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = self.classifier(x)
        return x
net = Net()
net.to(DEVICE)
_ = net.train()
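Why does the classifier take 40 inputs? The plain-Python sketch below applies the standard output-size formula for `Conv2d` and `MaxPool2d` (no padding, dilation 1, floor rounding, which matches the layers above) to trace a 32x32 input through the backbone. The helper names and the `LAYERS` table are illustrative, not part of PyTorch.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of Conv2d/MaxPool2d with dilation 1 and floor rounding."""
    return (size + 2 * padding - kernel) // stride + 1


# (kernel, stride) of every spatial layer in Net.backbone, in order.
# The 1x1 convolutions leave the spatial size unchanged.
LAYERS = [
    (5, 1), (1, 1), (3, 2),  # Conv2d 5x5, Conv2d 1x1, MaxPool2d(3, 2)
    (5, 1), (1, 1), (3, 2),  # Conv2d 5x5, Conv2d 1x1, MaxPool2d(3, 2)
    (3, 1), (1, 1), (1, 1),  # Conv2d 3x3, Conv2d 1x1, Conv2d 1x1 -> 10 channels
]


def backbone_output_size(size=32):
    """Trace the height (= width) of a square input through the backbone."""
    for kernel, stride in LAYERS:
        size = conv_out(size, kernel, stride)
    return size


side = backbone_output_size(32)
out_channels = 10  # channels produced by the last Conv2d(192, 10, 1)
print(side, out_channels * side * side)  # 2 40
```

The sizes go 32, 28, 28, 13, 9, 9, 4, 2, 2, 2, so `torch.flatten(x, 1)` yields 10 × 2 × 2 = 40 features per image.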
Define a loss function and optimizer
Let's use a classification cross-entropy loss and the Adam optimizer with weight decay, plus an exponential learning-rate decay schedule.
LR_BASE_MODEL = 3e-3
import torch.optim as optim
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=LR_BASE_MODEL, weight_decay=1e-3)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.85)
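`ExponentialLR` multiplies the learning rate by `gamma` on every `scheduler.step()` call, so after `t` epochs the rate is the initial rate times `gamma**t`. The plain-Python sketch below (the function name is illustrative) shows how fast this decays: with `gamma=0.85` over the 50 base-model epochs the final rate is roughly 3e-4 of its initial value, and the Stained Glass schedule used later (`gamma=0.92` over 100 epochs) decays by a similar overall factor.

```python
def exponential_lr(base_lr, gamma, epoch):
    """Learning rate after `epoch` scheduler.step() calls under ExponentialLR."""
    return base_lr * gamma**epoch


BASE_LR = 3e-3  # LR_BASE_MODEL and LR_STAINED_GLASS are both 3e-3 in this notebook
for gamma, epochs in ((0.85, 50), (0.92, 100)):
    final_lr = exponential_lr(BASE_LR, gamma, epochs)
    print(f"gamma={gamma}: lr after {epochs} epochs = {final_lr:.1e}")
```

A decay this steep means most of the learning happens in the first few dozen epochs; a larger `gamma` keeps the rate useful for longer.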
Train the base model
NUM_EPOCHS_BASE_MODEL = 50
import tqdm.auto
for _ in (
    epoch_pbar := tqdm.auto.tqdm(range(NUM_EPOCHS_BASE_MODEL))
):  # loop over the dataset multiple times
    running_loss = 0.0
    num_batches = len(trainloader)
    for data in (batch_pbar := tqdm.auto.tqdm(trainloader, leave=False)):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
        # zero the parameter gradients
        optimizer.zero_grad()
        # forward + backward + optimize
        outputs = net(inputs)
        loss: torch.Tensor = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # detach so the running total does not keep every batch's autograd graph
        running_loss += loss.detach()
        batch_pbar.set_postfix({"loss": loss.item()})
    running_test_loss = 0.0
    for data in testloader:
        inputs, labels = data
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
        with torch.no_grad():
            outputs = net(inputs)
            loss: torch.Tensor = criterion(outputs, labels)
        running_test_loss += loss
    scheduler.step()
    epoch_pbar.set_postfix(
        {
            "loss": running_loss.item() / num_batches,
            "test_loss": running_test_loss.item() / len(testloader),
            "lr": scheduler.get_last_lr()[0],
        }
    )
# Save the model state dict
torch.save(net.state_dict(), "cifar10_base_model.pt")
###############################################
######### BEGIN STAINED GLASS CHANGES #########
###############################################
from stainedglass_core import (
loss as sg_loss,
model as sg_model,
noise_layer as sg_noise_layer,
)
noisy_model = sg_model.NoisyModel(
sg_noise_layer.PatchCloakNoiseLayer,
net,
target_parameter="x",
color_channels=3,
patch_size=16,
scale=(1e-3, 1.0),
percent_to_mask=0.0,
)
noisy_model.to(DEVICE)
# We want to freeze the weights of the base model and only train the
# Stained Glass Transform.
for param in noisy_model.base_model.parameters():
param.requires_grad = False
for param in noisy_model.noise_layer.parameters():
param.requires_grad = True
###############################################
########## END STAINED GLASS CHANGES ##########
###############################################
Reconstruct the optimizer using the wrapped model
LR_STAINED_GLASS = 3e-3
###############################################
######### BEGIN STAINED GLASS CHANGES #########
###############################################
# We have to make sure to pass in the parameters of the NoisyModel and not just
# the base model, otherwise the optimizer will not update the weights of the
# Stained Glass Transform.
optimizer = optim.Adam(
noisy_model.parameters(), lr=LR_STAINED_GLASS, weight_decay=0
)
###############################################
########## END STAINED GLASS CHANGES ##########
###############################################
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.92)
Wrap the loss function with a noise loss wrapper
###############################################
######### BEGIN STAINED GLASS CHANGES #########
###############################################
# Alpha closer to 0 will prioritize the model's performance, while
# alpha closer to 1 will prioritize the strength of the Stained Glass
# Transform.
ALPHA = 0.7
noisy_criterion, get_component_losses, _ = (
sg_loss.cloak.composite_cloak_loss_factory(
noisy_model, criterion, alpha=ALPHA, respect_std_mask=False
)
)
###############################################
########## END STAINED GLASS CHANGES ##########
###############################################
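The comment above describes `ALPHA` as a knob that trades task performance against the strength of the Stained Glass Transform. One common way to express such a trade-off is a convex combination of two loss terms. The plain-Python sketch below assumes that form purely for illustration: it is not taken from `stainedglass_core`, whose `composite_cloak_loss_factory` may combine its `task_loss` and `negative_log_mean_loss` components differently, and `hypothetical_composite_loss` is a made-up name.

```python
def hypothetical_composite_loss(task_loss, noise_loss, alpha):
    """ASSUMPTION: an illustrative convex combination, not stainedglass_core's formula.

    alpha = 0 uses only the task loss; alpha = 1 uses only the noise loss.
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f"alpha must be in [0, 1], got {alpha}")
    return (1.0 - alpha) * task_loss + alpha * noise_loss


task_loss, noise_loss = 1.2, -0.5  # made-up component values
for alpha in (0.0, 0.4, 0.7, 1.0):
    print(alpha, hypothetical_composite_loss(task_loss, noise_loss, alpha))
```

Under this assumed form, raising alpha shifts gradient weight from the classification term to the noise term, which matches the qualitative behaviour the comment describes.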
Training
NUM_EPOCHS_STAINEDGLASS = 100
for _ in (epoch_pbar := tqdm.auto.tqdm(range(NUM_EPOCHS_STAINEDGLASS))):
    running_losses = {
        "task_loss": torch.tensor(0.0, device=DEVICE),
        "negative_log_mean_loss": torch.tensor(0.0, device=DEVICE),
        "composite_loss": torch.tensor(0.0, device=DEVICE),
    }
    running_test_losses = {
        "task_loss": torch.tensor(0.0, device=DEVICE),
        "negative_log_mean_loss": torch.tensor(0.0, device=DEVICE),
        "composite_loss": torch.tensor(0.0, device=DEVICE),
    }
    alphas = {}
    num_batches = len(trainloader)
    for data in (batch_pbar := tqdm.auto.tqdm(trainloader, leave=False)):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
        # zero the parameter gradients
        optimizer.zero_grad()
        ###############################################
        ######### BEGIN STAINED GLASS CHANGES #########
        ###############################################
        # forward + backward + optimize
        # We need to pass the inputs through the wrapped `noisy_model` instead
        # of the `net` model.
        outputs = noisy_model(inputs)
        loss = noisy_criterion(outputs, labels)
        losses = get_component_losses()
        ###############################################
        ########## END STAINED GLASS CHANGES ##########
        ###############################################
        loss.backward()
        optimizer.step()
        for key, value in losses.items():
            if key in running_losses:
                # detach so the running totals do not keep autograd graphs alive
                running_losses[key] += value.detach()
        # any remaining component values (collected as `alphas`) are logged per batch
        alphas = {
            key: value.item()
            for key, value in losses.items()
            if key not in running_losses
        }
        batch_losses = {
            key: value.item()
            for key, value in losses.items()
            if key in running_losses
        }
        lr = {"lr": scheduler.get_last_lr()[0]}
        items_to_log = batch_losses | lr | alphas
        batch_pbar.set_postfix(items_to_log)
    for data in testloader:
        inputs, labels = data
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
        with torch.no_grad():
            outputs = noisy_model(inputs)
            loss = noisy_criterion(outputs, labels)
            losses = get_component_losses()
        for key, value in losses.items():
            if key in running_test_losses:
                running_test_losses[key] += value
    scheduler.step()
    average_losses = {
        key: value.item() / num_batches for key, value in running_losses.items()
    }
    average_test_losses = {
        f"test_{key}": value.item() / len(testloader)
        for key, value in running_test_losses.items()
    }
    lr = {"lr": scheduler.get_last_lr()[0]}
    items_to_log = average_losses | average_test_losses | lr | alphas
    epoch_pbar.set_postfix(items_to_log)
# Save the model state dict
torch.save(noisy_model.state_dict(), "cifar10_stained_glass.pt")
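The epoch-level logging in the Stained Glass training loop divides the summed training losses by `len(trainloader)` and the summed test losses by `len(testloader)`, prefixes test keys with `test_`, and merges everything with the `|` dict-union operator (Python 3.9+), where later operands win on duplicate keys. With `BATCH_SIZE = 1024` and the default `drop_last=False`, the 50,000 CIFAR-10 training images give 49 batches and the 10,000 test images give 10. A plain-Python sketch with made-up loss totals; `summarize_epoch` and the `alpha` key are hypothetical names.

```python
import math

BATCH_SIZE = 1024
NUM_TRAIN_BATCHES = math.ceil(50_000 / BATCH_SIZE)  # 49 (the last batch is partial)
NUM_TEST_BATCHES = math.ceil(10_000 / BATCH_SIZE)  # 10


def summarize_epoch(running_losses, running_test_losses, lr, alphas):
    """Average summed per-batch losses and merge them into one log dict."""
    average_losses = {k: v / NUM_TRAIN_BATCHES for k, v in running_losses.items()}
    average_test_losses = {
        f"test_{k}": v / NUM_TEST_BATCHES for k, v in running_test_losses.items()
    }
    return average_losses | average_test_losses | {"lr": lr} | alphas


# Made-up totals, for illustration only.
log = summarize_epoch(
    running_losses={"task_loss": 98.0, "composite_loss": 49.0},
    running_test_losses={"task_loss": 25.0, "composite_loss": 12.0},
    lr=3e-3,
    alphas={"alpha": 0.7},
)
print(log)
```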
Test the network on the test data
We have trained the base network for 50 epochs and then the Stained Glass Transform for another 100 epochs. Now we need to check whether the network has learned anything at all.
We will check this by predicting the class label that the neural network outputs and comparing it against the ground truth. If the prediction is correct, we count the sample as a correct prediction.
First, let us display some images from the test set to get familiar.
dataiter = iter(testloader)
images, labels = next(dataiter)
# show images
imshow(
torchvision.utils.make_grid(
images[:NUM_IMAGES_PER_ROW], nrow=NUM_IMAGES_PER_ROW
)
)
print(
"GroundTruth: ",
" ".join(f"{classes[labels[j]]:5s}" for j in range(NUM_IMAGES_PER_ROW)),
)
GroundTruth: cat ship ship plane frog frog car frog
Visualize the images with the Stained Glass Transform
Now that we've trained the Stained Glass Transform, let's apply the transform to the test images and visualize the results.
import stainedglass_core.utils
dataiter = iter(testloader)
images, labels = next(dataiter)
images, labels = images.to(DEVICE), labels.to(DEVICE)
###############################################
######### BEGIN STAINED GLASS CHANGES #########
###############################################
# For vision models, the `noise_layer` of the `NoisyModel` is the
# Stained Glass Transform.
stained_glass_transform = noisy_model.noise_layer
transformed_images = stained_glass_transform(images)
###############################################
########## END STAINED GLASS CHANGES ##########
###############################################
# Display images and labels
imshow(
torchvision.utils.make_grid(
transformed_images[:NUM_IMAGES_PER_ROW].detach().cpu(), nrow=NUM_IMAGES_PER_ROW
)
)
print(
"GroundTruth: ",
" ".join(f"{classes[labels[j]]:5s}" for j in range(NUM_IMAGES_PER_ROW)),
)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
GroundTruth: cat ship ship plane frog frog car frog
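The clipping message above comes from matplotlib: after `imshow` unnormalizes with `img / 2 + 0.5`, some transformed pixel values fall outside [0, 1], which means the Stained Glass Transform produced values outside the [-1, 1] range of the normalized inputs. Those pixels are saturated in the displayed image. A plain-Python sketch of the per-value mapping (the function name and the sample values are illustrative):

```python
def displayed_value(normalized):
    """Unnormalize like imshow (x / 2 + 0.5), then clip to [0, 1] like matplotlib."""
    unnormalized = normalized / 2 + 0.5
    return min(max(unnormalized, 0.0), 1.0)


# Values inside [-1, 1] land in [0, 1] and are not clipped;
# the made-up values outside that range are clipped to 0.0 or 1.0.
for value in (-1.0, 0.0, 1.0, -2.4, 1.8):
    print(value, "->", displayed_value(value))
```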
Test the network on transformed images
Okay, now let us see what the neural network thinks these examples above are:
# There is a hook that applies the Stained Glass Transform automatically to
# inputs of the model. We can remove this hook temporarily to pass the original
# images to the model.
with stainedglass_core.utils.torch.temporarily_remove_hooks(net):
    outputs_untransformed = noisy_model.base_model(images)
    outputs_transformed = noisy_model.base_model(transformed_images)
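The comment above says the `NoisyModel` relies on a hook that transforms the base model's inputs automatically, and `temporarily_remove_hooks` disables it inside the `with` block. The general idea of a context manager that detaches hooks and restores them afterwards can be sketched with a hypothetical toy module in plain Python. This is not the `stainedglass_core` implementation; `ToyModule`, `pre_hooks`, and `temporarily_remove_pre_hooks` are made-up names.

```python
from contextlib import contextmanager


class ToyModule:
    """Hypothetical stand-in for a model whose inputs pass through pre-hooks."""

    def __init__(self):
        self.pre_hooks = []

    def forward(self, x):
        return x

    def __call__(self, x):
        # Each pre-hook may rewrite the input before forward() sees it.
        for hook in self.pre_hooks:
            x = hook(x)
        return self.forward(x)


@contextmanager
def temporarily_remove_pre_hooks(module):
    """Detach all pre-hooks for the duration of the block, then restore them."""
    saved = module.pre_hooks
    module.pre_hooks = []
    try:
        yield module
    finally:
        module.pre_hooks = saved  # restored even if the block raises


model = ToyModule()
model.pre_hooks.append(lambda x: x + 100)  # stands in for the transform
print(model(1))  # 101
with temporarily_remove_pre_hooks(model):
    print(model(1))  # 1
print(model(1))  # 101
```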
outputs_untransformed
tensor([[-0.8914, -2.1960, -0.6631, ..., -1.7683, -1.2430, -1.8924], [ 3.7042, 4.6633, -3.4808, ..., -4.6254, 4.6014, 2.6636], [ 1.0646, 2.8372, -2.9984, ..., -3.6443, 1.9301, 1.7977], ..., [-1.3585, -1.2781, -1.5252, ..., -2.0195, -1.4466, -1.2950], [ 1.2942, -0.1603, -0.6354, ..., -2.2633, 1.1945, -0.4951], [ 0.3256, 0.4093, -2.5482, ..., -4.0406, 0.3767, -0.1939]], device='cuda:0')
The outputs are logits for the 10 classes. The higher the logit for a class, the more the network thinks that the image is of the particular class.
And we can take the class with the highest logit as the prediction for the class of the image:
_, predicted = torch.max(outputs_untransformed, 1)
print(
"Predicted: ",
" ".join(f"{classes[predicted[j]]:5s}" for j in range(NUM_IMAGES_PER_ROW)),
)
Predicted: horse car horse truck ship truck horse car
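`torch.max(outputs, 1)` returns both the maximum values and their indices along the class dimension; the indices are the predictions. The plain-Python sketch below, with made-up logits and illustrative helper names, shows that argmax picks the same class whether it is applied to the raw logits or to softmax probabilities, because softmax is monotonic.

```python
import math

CLASSES = ("plane", "car", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck")


def softmax(logits):
    """Convert logits to probabilities (subtracting the max for numerical stability)."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict(scores):
    """Index of the largest score, like the indices returned by torch.max(x, 1)."""
    return max(range(len(scores)), key=scores.__getitem__)


logits = [0.3, 2.1, -1.0, 0.5, -0.2, 0.0, 1.4, -0.7, 1.9, 0.8]  # made up
probs = softmax(logits)
print(CLASSES[predict(logits)])  # car
print(predict(probs) == predict(logits))  # True
```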
Let us look at how the network performs on the whole dataset.
total = 0
correct_untransformed = 0
correct_transformed = 0
# Since we're not training, we don't need to calculate the gradients for our
# outputs.
with (
    torch.no_grad(),
    ###############################################
    ######### BEGIN STAINED GLASS CHANGES #########
    ###############################################
    # We want to remove the hooks that automatically apply the
    # Stained Glass Transform to the inputs of the model.
    stainedglass_core.utils.torch.temporarily_remove_hooks(net),
    ###############################################
    ########## END STAINED GLASS CHANGES ##########
    ###############################################
):
    for data in tqdm.auto.tqdm(testloader):
        raw_images, labels = data
        raw_images, labels = raw_images.to(DEVICE), labels.to(DEVICE)
        # calculate outputs by running images through the network
        outputs_untransformed = noisy_model.base_model(raw_images)
        # the class with the highest energy is what we choose as prediction
        _, predicted_untransformed = torch.max(outputs_untransformed.data, 1)
        correct_untransformed += (
            (predicted_untransformed == labels).sum().item()
        )
        ###############################################
        ######### BEGIN STAINED GLASS CHANGES #########
        ###############################################
        # apply the Stained Glass Transform to the images
        transformed_images = stained_glass_transform(raw_images)
        # calculate outputs by running images through the network
        outputs_transformed = noisy_model.base_model(transformed_images)
        # the class with the highest energy is what we choose as prediction
        _, predicted_transformed = torch.max(outputs_transformed.data, 1)
        correct_transformed += (predicted_transformed == labels).sum().item()
        ###############################################
        ########## END STAINED GLASS CHANGES ##########
        ###############################################
        total += labels.size(0)
print(
    "Accuracy of the network on the (raw) test images: "
    f"{100 * correct_untransformed // total} %"
)
print(
    "Accuracy of the network on the (transformed) test images: "
    f"{100 * correct_transformed // total} %"
)
Accuracy of the network on the (raw) test images: 43 % Accuracy of the network on the (transformed) test images: 38 %
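Both accuracies above are computed with integer floor division (`100 * correct // total`), so "43 %" means at least 43 % and less than 44 %. The plain-Python sketch below contrasts this with true division; the correct-prediction count is hypothetical, and the helper names are illustrative.

```python
def floor_accuracy(correct, total):
    """Accuracy in whole percent, rounded down as in the cell above."""
    return 100 * correct // total


def exact_accuracy(correct, total):
    """Accuracy in percent with true division."""
    return 100 * correct / total


TOTAL = 10_000  # size of the CIFAR-10 test set
correct = 4_389  # hypothetical count of correct predictions
print(floor_accuracy(correct, TOTAL), exact_accuracy(correct, TOTAL))  # 43 43.89
# Random guessing over 10 balanced classes gets about 1 in 10 right.
print(floor_accuracy(TOTAL // 10, TOTAL))  # 10
```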
We see that the model performs much better than random chance (10% accuracy for the 10 balanced CIFAR-10 classes), both with and without the Stained Glass Transform.
The 5-point gap between accuracy on raw inputs and on transformed inputs can likely be narrowed by training the Stained Glass Transform for longer or by tuning its hyperparameters. Additionally, after the Stained Glass Transform is trained, the base model can be fine-tuned on the transformed inputs to further improve performance. Fine-tuning the base model is not covered in this tutorial.