
Stained Glass Core in 15 minutes (Quickstart guide)

Stained Glass Core is a PyTorch library that enables the creation and deployment of Stained Glass Transform. It integrates with existing PyTorch training and inference pipelines with minimal code changes. This guide highlights the key concepts required to use Stained Glass Core in your PyTorch project.

Creating a Stained Glass Transform

For a more comprehensive tutorial on integrating Stained Glass Transform with an existing training pipeline, see Creating a Stained Glass Transform for Image Classification or Creating a Stained Glass Transform for a Large Language Model.

Code changes

Integrating Stained Glass Core into an existing training workflow typically requires only four changes, for either computer vision models or language models. The table below summarizes them; some rows apply only to language models or only to computer vision models. For a more detailed explanation of each change, see the following sections.

| Description | Sample base code | Sample code with Stained Glass Core |
|---|---|---|
| Wrap a PyTorch model with a Stained Glass Transform Model (including extra hyperparameters) | `model = base_model` | `model = NoisyModel(noise_layer_class=CloakNoiseLayerOneShot, base_model=base_model, input_shape=(1, 3, 224, 224), scale=(.001, 1.0), percent_to_mask=0.6)` |
| Wrap a loss function with the Stained Glass Transform Loss (including extra hyperparameters); language models create their loss with `distillation_loss_factory` instead | `loss_func = base_loss` | `loss_func = model.noise_loss_wrapper(base_loss, alpha=0.4)` |
| For language models only: Wrap the tokenizer with a Stained Glass Tokenizer Wrapper | `tokenizer = base_tokenizer` | `tokenizer_wrapper = TokenizerWrapper(base_tokenizer, model_type=LlamaModel, include_labels=True)` |
| Unwrap model wrapper outputs | `output = model(input)` | `output = model(input).base_model_output` |
| For computer vision models only: Unwrap loss wrapper outputs | `loss = loss_func(output, target)` | `loss = loss_func(output, target)["composite_loss"]` |

Wrap a PyTorch model with a Stained Glass Transform Model

To create a Stained Glass Transform, the base model (the user-provided PyTorch model) must be wrapped with an appropriate NoisyModel (sub)class, which manages and applies the transform as well as the components needed to train it.

NoisyModel is parameterized by different types of Stained Glass Transform (referred to as NoiseLayer classes in the API), each with different tunable parameters. See each type's API reference for a complete list of its parameters. The types of Stained Glass Transform and their use cases are listed under Types of Stained Glass Transform below.

For Large Language Models, use NoiseMaskedNoisyTransformerModel as the NoisyModel subclass and TransformerCloak as the NoiseLayer class. These classes perform additional optimizations for training Stained Glass Transform for Large Language Models. See the API reference or the LLM Training Recommendation Guide for a complete list of their parameters.

The base model in the example below could be any model from the Hugging Face Transformers library.

import transformers
from stainedglass_core.model.noisy_transformer_masking_model import NoiseMaskedNoisyTransformerModel
from stainedglass_core.noise_layer import transformer_cloak  # module path assumed; check the API reference

base_model = transformers.AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.2")
noisy_model = NoiseMaskedNoisyTransformerModel(
    noise_layer_class=transformer_cloak.TransformerCloak,
    base_model=base_model,
    input_shape=(-1, 4096),  # (-1, hidden size); Mistral-7B has a hidden size of 4096
    target_layer="model.embed_tokens",  # the input embedding layer whose outputs are transformed
    config_path="mistralai/Mistral-7B-v0.2",
    scale=(1e-8, 1.0),
    transformer_type=transformers.MistralModel,
)

For Computer Vision models, the NoisyModel class itself is appropriate. In the example below, we will use NoisyModel with the CloakNoiseLayerOneShot Stained Glass Transform type. It has two parameters: scale and percent_to_mask. See the API reference for a complete list of types and their parameters.

The base model here can be any PyTorch Module. For this example, we will use a ResNet model from torchvision.

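import torchvision
from stainedglass_core.model import NoisyModel
from stainedglass_core.noise_layer import CloakNoiseLayerOneShot  # import path assumed; check the API reference

base_model = torchvision.models.resnet18(weights="IMAGENET1K_V1")
# Keyword arguments avoid any ambiguity about positional argument order.
noisy_model = NoisyModel(
    noise_layer_class=CloakNoiseLayerOneShot,
    base_model=base_model,
    scale=(.001, 1.0),
    percent_to_mask=0.6,
)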

Once the model has been wrapped, it can be used almost interchangeably with the base model. The only differences are:

  1. The forward pass now automatically applies the Stained Glass Transform to the input data.
  2. The outputs of the NoisyModel are wrapped in a NoisyModelOutput object, which contains the output of the base model and some additional information for training the transform.
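The two differences above can be illustrated with a toy wrapper in plain PyTorch. Everything in it is hypothetical (ToyNoiseLayer, ToyNoisyModel and ToyNoisyModelOutput are not library classes): it builds a noise layer from a class plus keyword hyperparameters, applies that layer to the input on every forward pass, and returns the base model's output inside a small dataclass. For brevity it calls the noise layer directly in forward, whereas the library applies it through a forward hook.

```python
from dataclasses import dataclass
from typing import Any

import torch
from torch import nn


class ToyNoiseLayer(nn.Module):
    """Hypothetical stand-in for a NoiseLayer: adds learnable, scaled Gaussian noise."""

    def __init__(self, num_features: int, scale: float = 0.1) -> None:
        super().__init__()
        self.scale = scale
        # Learnable per-feature log standard deviation (trained along with the transform).
        self.log_std = nn.Parameter(torch.zeros(num_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + torch.randn_like(x) * self.log_std.exp() * self.scale


@dataclass
class ToyNoisyModelOutput:
    """Hypothetical analogue of NoisyModelOutput."""

    base_model_output: Any
    transformed_input: torch.Tensor


class ToyNoisyModel(nn.Module):
    """Wraps a base model and builds its noise layer from a class plus hyperparameters."""

    def __init__(self, noise_layer_class: type[nn.Module], base_model: nn.Module, **noise_kwargs: Any) -> None:
        super().__init__()
        self.base_model = base_model
        self.noise_layer = noise_layer_class(**noise_kwargs)

    def forward(self, x: torch.Tensor) -> ToyNoisyModelOutput:
        transformed = self.noise_layer(x)  # the transform runs on every forward pass
        return ToyNoisyModelOutput(
            base_model_output=self.base_model(transformed),
            transformed_input=transformed,
        )


torch.manual_seed(0)
base_model = nn.Linear(8, 2)
noisy_model = ToyNoisyModel(ToyNoiseLayer, base_model, num_features=8, scale=0.1)
x = torch.randn(4, 8)
out = noisy_model(x)
logits = out.base_model_output  # unwrap, as with NoisyModelOutput.base_model_output
```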

Warning

The NoiseLayer is applied via a PyTorch forward hook, meaning that the transform is applied during every forward pass of the base model. To run a forward pass of the base model without the transform, you must disable this hook.
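The hook behavior can be reproduced with plain PyTorch. In this sketch, a hypothetical toy model exposes its embedding at the dotted path model.embed_tokens (the same path used as target_layer in the language model example above). get_submodule resolves that path, register_forward_hook installs a hook that perturbs the layer's output, and calling remove() on the returned handle restores the untransformed forward pass. How NoisyModel itself exposes or disables its hook is not shown here; consult the API reference for that.

```python
import torch
from torch import nn


class ToyDecoder(nn.Module):
    """Hypothetical model whose embedding lives at the dotted path "model.embed_tokens"."""

    def __init__(self, vocab_size: int = 100, hidden: int = 16) -> None:
        super().__init__()
        self.model = nn.Module()
        self.model.embed_tokens = nn.Embedding(vocab_size, hidden)
        self.lm_head = nn.Linear(hidden, vocab_size)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.lm_head(self.model.embed_tokens(input_ids))


def add_noise_hook(module, inputs, output):
    # Returning a value from a forward hook replaces the module's output.
    return output + 0.5 * torch.randn_like(output)


torch.manual_seed(0)
model = ToyDecoder()
ids = torch.randint(0, 100, (1, 5))
clean = model(ids)

target = model.get_submodule("model.embed_tokens")  # resolve the dotted layer name
handle = target.register_forward_hook(add_noise_hook)
noisy = model(ids)  # every forward pass now goes through the hook

handle.remove()  # disable the hook to get plain base-model behavior again
restored = model(ids)
```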

Wrap a loss function with the Stained Glass Transform Loss

To train a Stained Glass Transform, you must use a special loss function that maximizes the strength of the transform while minimizing its impact on the model's performance.

Instead of a user-provided loss function, training a Stained Glass Transform for Large Language Models measures the model's performance with a loss that compares the activations of the base model with and without transformed inputs.
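To make that comparison concrete, here is a minimal sketch in plain PyTorch; it is not the library's loss. It runs a stack of toy layers on clean embeddings and on slightly perturbed embeddings, then takes the mean-squared distance between the hidden states at one layer. The layer stack, the perturbation and the choice of MSE are all assumptions for illustration; layer_index plays a role loosely analogous to the distillation_layer_index argument used below.

```python
import torch
from torch import nn


def layer_activations(layers: nn.ModuleList, embeddings: torch.Tensor, layer_index: int) -> torch.Tensor:
    """Run embeddings through layers[0..layer_index] and return that layer's output."""
    hidden = embeddings
    for layer in layers[: layer_index + 1]:
        hidden = layer(hidden)
    return hidden


torch.manual_seed(0)
hidden_size, num_layers, layer_index = 16, 4, 2
layers = nn.ModuleList(
    nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.GELU()) for _ in range(num_layers)
)

clean_embeddings = torch.randn(2, 5, hidden_size)  # (batch, sequence, hidden)
transformed_embeddings = clean_embeddings + 0.1 * torch.randn_like(clean_embeddings)

with torch.no_grad():
    # Target: activations produced by the untransformed inputs.
    target = layer_activations(layers, clean_embeddings, layer_index)

# Activations produced by the transformed inputs, compared against the target.
student = layer_activations(layers, transformed_embeddings, layer_index)
distillation_distance = nn.functional.mse_loss(student, target)
```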

The loss function is created using the distillation_loss_factory function. This function returns a loss function that can be used to train the Stained Glass Transform, as well as functions to retrieve the losses and hyperparameters of the transform, useful for logging.

See the API reference or the LLM Training Recommendation Guide for a complete list of its parameters.

from stainedglass_core import loss as sg_loss

distillation_loss, get_losses, get_hyperparameters = (
    sg_loss.distillation.distillation_loss_factory(
        noisy_model,
        distillation_layer_index=12,
        alpha=0.54,
        std_log_ratio_loss_weight=0.01,
        input_embedding_similarity_loss_weight=0.75,
        distillation_layer_distance_loss_weight=12.0,
    )
)

...

loss_value = distillation_loss(loss_mask)

The returned loss function takes no model outputs or targets, but it does require a loss_mask, which can be obtained from the Tokenizer Wrapper.
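A loss mask typically restricts a per-token loss to the positions that should count, for example only the response tokens. The sketch below shows that reduction as a masked mean. It assumes a boolean mask of shape (batch, sequence length); how the distillation loss uses loss_mask internally is not specified here and may differ.

```python
import torch


def masked_mean(per_token_loss: torch.Tensor, loss_mask: torch.Tensor) -> torch.Tensor:
    """Average per_token_loss over the positions where loss_mask is True."""
    mask = loss_mask.to(per_token_loss.dtype)
    # clamp avoids division by zero when no position is selected
    return (per_token_loss * mask).sum() / mask.sum().clamp(min=1.0)


per_token_loss = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
loss_mask = torch.tensor([[False, False, True, True]])  # e.g. only response tokens count
print(masked_mean(per_token_loss, loss_mask))  # tensor(3.5000)
```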

In order to train a Stained Glass Transform for computer vision models, the Stained Glass Transform loss function wraps around a user-provided loss function using the noise_loss_wrapper method of the NoisyModel class. Unlike in the Large Language Model case, this loss depends on the outputs of the computer vision model.

This method returns a wrapped loss function that computes a combined loss, which includes both the loss from the base loss function and the loss from the Stained Glass Transform.

The alpha hyperparameter of the noise_loss_wrapper method controls the priority of the transform loss relative to the base loss. Low values of alpha (close to 0) prioritize the base loss, meaning the transform may learn more slowly, but model performance will be preserved. High values of alpha (close to 1) prioritize the transform loss, meaning the transform may learn faster, but model performance may degrade. This is a key parameter to tune when training the Stained Glass Transform.
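The description of alpha above is qualitative. As a purely illustrative assumption, the sketch below treats alpha as the weight of a convex combination of the two losses, which reproduces the stated extremes: alpha near 0 favors the base loss, alpha near 1 favors the transform loss. The library's actual composite loss may be computed differently.

```python
def composite_loss(base_loss: float, transform_loss: float, alpha: float) -> float:
    """Hypothetical convex combination: alpha=0 -> base loss only, alpha=1 -> transform loss only."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be in [0, 1]")
    return (1.0 - alpha) * base_loss + alpha * transform_loss


# Sweep alpha to see the weight shift from the base loss to the transform loss.
for alpha in (0.0, 0.4, 1.0):
    print(alpha, composite_loss(base_loss=2.0, transform_loss=5.0, alpha=alpha))
```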

Optionally, the alpha parameter can be set to None to use an adaptive search algorithm. This is often useful for finding a good starting point for alpha for simple models, but tuning alpha manually afterward will often lead to better results.

from torch import nn

base_loss_func = nn.CrossEntropyLoss()
loss_func = noisy_model.noise_loss_wrapper(base_loss_func, alpha=0.4)

For language models only: Wrap tokenizer with a Stained Glass Tokenizer Wrapper

Training language models also requires wrapping the tokenizer with a TokenizerWrapper. This wrapper uses the tokenizer and a template function to create the necessary inputs for the Stained Glass Transform.

The tokenizer wrapper requires the tokenizer used for the base model.

See the API reference or the LLM Training Recommendation Guide for a complete list of its parameters.

For chat prompts (prompt_type=PromptType.CHAT), the wrapped tokenizer takes a sequence of dictionaries containing "role" and "content" strings (compatible with Hugging Face Transformers chat templates and with OpenAI API chat prompts). It returns a dictionary containing the tokenized inputs for the base model as well as the noise_mask, which describes which tokens are to be transformed.

import transformers
from stainedglass_core.huggingface import tokenization_utils as sg_tokenization_utils


base_tokenizer = transformers.AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.2")
tokenizer_wrapper = sg_tokenization_utils.TokenizerWrapper(
    base_tokenizer,
    model_type=type(base_model),
    include_labels=True,
    ignore_prompt_loss=True,
    prompt_type=sg_tokenization_utils.PromptType.CHAT
)

...

inputs = tokenizer_wrapper(
    [
        {"role": "system", "content": "You are a helpful chat bot who answers questions."},
        {"role": "user", "content": "What is the capital of France?"},
    ]
)

For instruction prompts (prompt_type=PromptType.INSTRUCTION), the wrapped tokenizer takes a dictionary containing instruction, system_prompt, context, and response strings. It returns a dictionary containing the tokenized inputs for the base model as well as the noise_mask, which describes which tokens are to be transformed. It can optionally also return a loss_mask, which describes which tokens are used for the loss calculation.

import transformers
from stainedglass_core.huggingface import tokenization_utils as sg_tokenization_utils


base_tokenizer = transformers.AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.2")
tokenizer_wrapper = sg_tokenization_utils.TokenizerWrapper(
    base_tokenizer,
    model_type=type(base_model),
    include_labels=True,
    ignore_prompt_loss=True,
    prompt_type=sg_tokenization_utils.PromptType.INSTRUCTION
)

...

inputs = tokenizer_wrapper(
    {
        "instruction": "What is the capital of France?",
        "system_prompt": "You are a helpful chat bot who answers questions.",
        "context": "",
        "response": ""
    }
)
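To make the two masks concrete, the hypothetical toy_masks function below builds them for an instruction prompt, using whitespace-separated words instead of real tokens. It assumes three things: the segments are concatenated in the order system_prompt, context, instruction, response; every prompt segment is transformed (noise_mask); and only response tokens count toward the loss (loss_mask), in the spirit of ignore_prompt_loss=True. The real TokenizerWrapper uses the base model's tokenizer and prompt template, so its segment order and masks may differ.

```python
def toy_masks(segments: dict[str, str]) -> dict[str, list]:
    """Hypothetical: whitespace-tokenize an instruction prompt and build both masks."""
    order = ("system_prompt", "context", "instruction", "response")  # assumed order
    tokens, noise_mask, loss_mask = [], [], []
    for name in order:
        words = segments.get(name, "").split()
        is_response = name == "response"
        tokens.extend(words)
        noise_mask.extend([not is_response] * len(words))  # transform prompt tokens only
        loss_mask.extend([is_response] * len(words))  # score response tokens only
    return {"tokens": tokens, "noise_mask": noise_mask, "loss_mask": loss_mask}


masks = toy_masks(
    {
        "system_prompt": "Be concise.",
        "instruction": "Capital of France?",
        "context": "",
        "response": "Paris",
    }
)
```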

Unwrap model wrapper outputs

The output of a NoisyModel is a NoisyModelOutput object, which contains the output of the base model and some additional information for training the transform. To use the outputs of the base model, access the base_model_output attribute of the NoisyModelOutput.

outputs = noisy_model(input)
base_model_outputs = outputs.base_model_output

Note, however, that for computer vision models the wrapped loss function expects this wrapped output, so the base_model_output attribute generally does not need to be accessed when using the wrapped loss function.

outputs = noisy_model(input)
# Pass in the `NoisyModelOutput` directly to the wrapped loss function
loss = loss_func(outputs, target)["composite_loss"]

Since the model outputs are not used by the Large Language Model loss function, the note above does not apply.

For computer vision models only: Unwrap loss wrapper outputs

The output of the wrapped loss function is a dictionary containing various components of the loss, including the base model loss (calculated from the base loss function) and the loss from the Stained Glass Transform. The combined loss (stored under the composite_loss key) should be backpropagated via backward().

outputs = noisy_model(input)
# Extract the composite loss from the dictionary; the other losses
# may be useful for logging.
loss = loss_func(outputs, target)["composite_loss"]
loss.backward()

This step is unnecessary for Large Language Models: distillation_loss_factory returns a loss function that produces a single loss tensor, along with a separate function (get_losses) to retrieve the other loss components. Backpropagation is performed directly on that single loss tensor, with no dictionary to unpack.
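For computer vision, the dictionary-unpacking pattern can be sketched without the library. toy_noise_loss_wrapper below is hypothetical: it returns a dictionary whose "composite_loss" entry is backpropagated, alongside two illustrative component entries that are only logged. The transform term (a penalty that shrinks as the noise standard deviation grows) and the convex weighting by alpha are assumptions; the real noise_loss_wrapper computes its losses differently and may use other keys.

```python
import torch
from torch import nn


def toy_noise_loss_wrapper(base_loss_func, noise_std: torch.Tensor, alpha: float):
    """Hypothetical stand-in for noise_loss_wrapper that returns a dict of loss terms."""

    def loss_func(outputs: torch.Tensor, target: torch.Tensor) -> dict[str, torch.Tensor]:
        base = base_loss_func(outputs, target)
        # Illustrative transform term: lower when the noise is larger.
        transform = -noise_std.log().mean()
        return {
            "base_loss": base,
            "transform_loss": transform,
            "composite_loss": (1 - alpha) * base + alpha * transform,
        }

    return loss_func


torch.manual_seed(0)
model = nn.Linear(4, 3)
noise_std = nn.Parameter(torch.full((4,), 0.1))
loss_func = toy_noise_loss_wrapper(nn.CrossEntropyLoss(), noise_std, alpha=0.4)

inputs = torch.randn(8, 4)
outputs = model(inputs + noise_std * torch.randn_like(inputs))
losses = loss_func(outputs, torch.randint(0, 3, (8,)))

loss = losses["composite_loss"]  # backpropagate only the combined loss
loss.backward()
logged = {name: value.item() for name, value in losses.items()}  # all terms, for logging
```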

Inference with Stained Glass Transform

For inference, the Stained Glass Transform is applied on the client side (within the client's zone of trust). The transformed data is then sent to the model provider (server side), where the rest of the model processes it.

Because the model provider creates the Stained Glass Transform, it must first prepare the transform for distribution to the client. After distribution, only minimal changes are needed on the client side and often few or no changes on the server side, all of which are described below.
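The division of labor between client and server can be sketched with toy PyTorch modules. Everything here is hypothetical. The client holds an embedding table and a stand-in transform and converts token IDs into transformed embeddings locally, so only that tensor crosses the trust boundary. The server applies its remaining layers to whatever embeddings it receives. A real deployment would send the tensor over a network rather than through a function call.

```python
import torch
from torch import nn

torch.manual_seed(0)
vocab_size, hidden = 50, 8

# --- Client side (inside the client's zone of trust) ---
client_embeddings = nn.Embedding(vocab_size, hidden)
client_transform = nn.Linear(hidden, hidden)  # hypothetical stand-in for a trained transform


def client_prepare(token_ids: torch.Tensor) -> torch.Tensor:
    """Embed and transform locally; raw token IDs never leave the client."""
    with torch.no_grad():
        return client_transform(client_embeddings(token_ids))


# --- Server side (model provider) ---
server_layers = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, vocab_size))


def server_infer(inputs_embeds: torch.Tensor) -> torch.Tensor:
    """Run the rest of the model on the received embeddings."""
    with torch.no_grad():
        return server_layers(inputs_embeds)


payload = client_prepare(torch.tensor([[3, 14, 15, 9]]))  # only this tensor is sent
logits = server_infer(payload)
```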

Model provider preparation

After creating a Stained Glass Transform, the model provider must distribute it to the client. This distribution includes the Stained Glass Transform weights and, for language models, may also include a tokenizer, a template function, and input embedding weights for the base model.

For language models:

# Create a Stained Glass Transform Text Client object, then save it.
# `embedding_parameters_names` is assumed to be defined elsewhere (the names of
# the input embedding parameters to include).
client = StainedGlassTransformForText(
    model=noisy_model,
    tokenizer_wrapper=tokenizer_wrapper,
    parameter_names=embedding_parameters_names,
)

# This file will be distributed to the client.
client.save_pretrained("stainedglass_transform.pth")

For computer vision models:

import torch

# Extract the Stained Glass Transform from the `NoisyModel`.
noise_layer = noisy_model.noise_layer

# This file will be distributed to the client.
torch.save(noise_layer.state_dict(), "noise_layer.pth")

Client Side

The client first instantiates a Stained Glass Transform object, loading the weights from the model provider.

Before sending inputs to the model provider, the client must apply the Stained Glass Transform to them.

For language models:

# Load the Stained Glass Transform from the file distributed by the model
# provider. The client does not need to know the hyperparameters of the
# transform in this case.
sgt = StainedGlassTransformForText.from_pretrained(
    "stainedglass_transform.pth"
)

# Apply the Stained Glass Transform to the input data
transformed_input_embeddings = sgt(
    {
        "instruction": "What is the capital of France?",
        "system_prompt": "",
        "context": "",
        "response": "",
    }
)

For computer vision models:

import torch

# Load the Stained Glass Transform from the file distributed by the model provider.
# The client must know the type and hyperparameters of the transform; they must
# match the values used during training.
stained_glass_transform = CloakNoiseLayerOneShot(percent_to_mask=0.6)
stained_glass_transform.load_state_dict(torch.load("noise_layer.pth"))

# Apply the Stained Glass Transform to the input data
transformed_input = stained_glass_transform(input)

The transformed input is then sent to the model provider for inference.
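A state dict stores only tensors, not constructor arguments. This is why the computer vision client must rebuild the transform with the right type and hyperparameters before loading the weights; plain Python attributes, such as a float hyperparameter, are not part of a state dict at all. The sketch below shows the round trip with a hypothetical ToyTransform and an in-memory buffer in place of noise_layer.pth. It also shows that a mismatched configuration fails to load.

```python
import io

import torch
from torch import nn


class ToyTransform(nn.Module):
    """Hypothetical transform whose parameter shape depends on a hyperparameter."""

    def __init__(self, num_features: int) -> None:
        super().__init__()
        self.log_std = nn.Parameter(torch.randn(num_features))


# Model provider: save the trained weights.
provider_transform = ToyTransform(num_features=16)
buffer = io.BytesIO()
torch.save(provider_transform.state_dict(), buffer)

# Client: rebuild with the same hyperparameters, then load.
buffer.seek(0)
state = torch.load(buffer)
client_transform = ToyTransform(num_features=16)
client_transform.load_state_dict(state)

# A client that guesses the wrong configuration cannot load the weights.
try:
    ToyTransform(num_features=8).load_state_dict(state)
    mismatch_detected = False
except RuntimeError:  # raised on a size mismatch
    mismatch_detected = True
```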

Server Side

The model provider receives the transformed input.

For language models, the transformed inputs received from the Stained Glass Transform are input embeddings, not text.

Tip

If the model provider can already accept input embeddings directly for inference/generation, then no changes are needed. Otherwise, the model provider must add some way to consume these input embeddings.

Info

Hugging Face models can accept input embeddings directly for inference/generation, using the inputs_embeds argument to forward or generate.
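If a provider's model accepts only token IDs, one option is to mirror the Hugging Face convention: accept either input_ids or inputs_embeds, and skip the embedding lookup when embeddings are supplied. ToyServerModel below is a hypothetical plain-PyTorch sketch of that pattern, not a Hugging Face class.

```python
from typing import Optional

import torch
from torch import nn


class ToyServerModel(nn.Module):
    """Accepts either input_ids or inputs_embeds, mirroring the Hugging Face convention."""

    def __init__(self, vocab_size: int = 50, hidden: int = 8) -> None:
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden)
        self.body = nn.Linear(hidden, vocab_size)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("Pass exactly one of input_ids or inputs_embeds")
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        # Transformed embeddings from the client skip the lookup entirely.
        return self.body(inputs_embeds)


model = ToyServerModel()
ids = torch.tensor([[1, 2, 3]])
from_ids = model(input_ids=ids)
from_embeds = model(inputs_embeds=model.embed_tokens(ids))
```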

For computer vision models, the transformed inputs are tensors in the same format as the image data, so the model provider can use them directly for inference as if they were images.

Types of Stained Glass Transform

Stained Glass Transform comes in many types, each with different tunable parameters. See the API reference for hyperparameters of each type.

| Use case | Recommended Stained Glass Transform type |
|---|---|
| Language models | TransformerCloak |
| Computer vision models | CloakNoiseLayerOneShot or PatchCloakNoiseLayer |

Integration with Hugging Face Trainers

Many of the code changes described above are abstracted away in our custom Hugging Face Trainer. If you use a Hugging Face Trainer, see the HuggingFace Trainer Integration tutorial for more information.