Stained Glass Core in 15 minutes (Quickstart guide)

Stained Glass Core is a PyTorch library that enables the creation and deployment of Stained Glass Transform. It integrates into existing PyTorch training and inference pipelines with minimal code changes. This guide highlights the key concepts required to use Stained Glass Core in your PyTorch project.

Creating a Stained Glass Transform

For a more comprehensive tutorial on integrating Stained Glass Transform with an existing training pipeline, see Creating a Stained Glass Transform for Image Classification or Creating a Stained Glass Transform for a Large Language Model.

Code changes

Integrating Stained Glass Core into an existing training workflow typically requires only two simple changes (three in the case of language models), described briefly in the table below. For a more detailed explanation of each change, see the following sections.

| Description | Sample Base Code | Sample Code with Stained Glass Core |
| --- | --- | --- |
| Wrap a PyTorch model with a Stained Glass Transform Model (including extra hyperparameters) | `model = base_model` | `model = NoisyModel(base_model, noise_layer_class=CloakNoiseLayerOneShot, scale=(0.001, 1.0), percent_to_mask=0.6)` |
| Wrap a loss function with the Stained Glass Transform Loss (including extra hyperparameters) | `loss_func = base_loss` | `loss_func = composite_patched_negative_log_mean_factory(model, base_loss, alpha=0.4)` |
| For language models only: Wrap the tokenizer with a Stained Glass Tokenizer Wrapper | `tokenizer = base_tokenizer` | `tokenizer_wrapper = TokenizerWrapper(base_tokenizer, model_type=LlamaModel, include_labels=True)` |

Wrap a PyTorch model with a Stained Glass Transform Model

In order to generate a Stained Glass Transform, the base model (the user-provided PyTorch model) must be wrapped with an appropriate NoisyModel (sub)class, which manages and applies the transform along with the components necessary to train it.

NoisyModel is parameterized by different types of Stained Glass Transform (which are referred to as NoiseLayer classes in the API), each with different tunable parameters. See each type's API reference for a complete list of their parameters. See below for the types of Stained Glass Transform and their use cases.

For Large Language Models, use NoiseMaskedNoisyTransformerModel as the NoisyModel subclass and TransformerCloak as the NoiseLayer class. These classes perform additional optimizations for training Stained Glass Transform for Large Language Models. See the API reference or LLM Training Recommendation Guide for a complete list of their parameters.

The base model in the example below could be any model from the Hugging Face Transformers library.

import transformers

from stainedglass_core.model.noisy_transformer_masking_model import NoiseMaskedNoisyTransformerModel
from stainedglass_core.noise_layer import transformer_cloak  # noise layer module path assumed; check the API reference for your version

base_model = transformers.AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.2")
noisy_model = NoiseMaskedNoisyTransformerModel(
    noise_layer_class=transformer_cloak.TransformerCloak,
    base_model=base_model,
    target_layer="model.embed_tokens",
    config_path="mistralai/Mistral-7B-v0.2",
    scale=(1e-8, 1.0),
    transformer_type=transformers.MistralModel,
)

For Computer Vision models, the NoisyModel class itself is appropriate. In the example below, we use NoisyModel with the CloakNoiseLayerOneShot Stained Glass Transform type and two of its parameters: scale and percent_to_mask. See the API reference for a complete list of transform types and their parameters.

The base model here can be any PyTorch Module. For this example, we will use a ResNet model from torchvision.

import torchvision

from stainedglass_core.model import NoisyModel
from stainedglass_core.noise_layer import CloakNoiseLayerOneShot  # noise layer module path assumed; check the API reference for your version

base_model = torchvision.models.resnet18(weights="IMAGENET1K_V1")
noisy_model = NoisyModel(
    CloakNoiseLayerOneShot,
    base_model,
    scale=(0.001, 1.0),
    percent_to_mask=0.6,
)

Once the model has been wrapped, it is functionally interchangeable with the base model. The wrapped model automatically applies the Stained Glass Transform to the input data.
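
For example, the wrapped ResNet can be called exactly like the base model (a minimal sketch; the batch shape is illustrative):

import torch

# A dummy batch of ImageNet-sized images; the Stained Glass Transform is
# applied to the inputs inside the wrapped model's forward pass.
images = torch.randn(8, 3, 224, 224)
logits = noisy_model(images)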

Wrap a loss function with the Stained Glass Transform Loss

In order to train a Stained Glass Transform, a special loss function must be used which maximizes the strength of the transform while minimizing the impact on the model's performance.

Training a Stained Glass Transform for Large Language Models uses a loss function that compares the base model's activations on the original and transformed inputs; this comparison serves in lieu of a user-provided loss function to measure the model's performance.

The loss function is created using the distillation_loss_factory function. This factory returns a loss function that can be used to train the Stained Glass Transform, along with functions that retrieve the individual losses and the hyperparameters of the transform, which are useful for logging.

See the API reference or LLM Training Recommendation Guide for a complete list of its parameters.

from stainedglass_core import loss as sg_loss

distillation_loss, get_losses, get_hyperparameters = (
    sg_loss.distillation.distillation_loss_factory(
        noisy_model,
        distillation_layer_index=12,
        alpha=0.54,
        std_log_ratio_loss_weight=0.01,
        input_embedding_similarity_loss_weight=0.75,
        distillation_layer_cosine_distance_loss_weight=12.0,
    )
)

...

# Later, inside the training loop, after a forward pass through the wrapped model:
loss_value = distillation_loss(loss_mask)

The created loss function takes no model outputs or targets; it requires only a loss_mask, which can be obtained from the Tokenizer Wrapper (described below).
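
Putting the pieces together, a single training step might look like the following sketch. The dictionary keys and the wrapped model's forward signature are assumptions based on the tokenizer wrapper described in the next section:

# Tokenize a chat prompt; the wrapped tokenizer returns the tokenized model
# inputs plus the noise_mask and loss_mask (see the next section).
inputs = tokenizer_wrapper(
    [{"role": "user", "content": "What is the capital of France?"}]
)
loss_mask = inputs.pop("loss_mask")

# Forward pass through the wrapped model; the transform is applied internally.
noisy_model(**inputs)

# Compute the combined distillation loss and backpropagate.
loss_value = distillation_loss(loss_mask)
loss_value.backward()

# get_losses and get_hyperparameters can be called here for logging.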

To train a Stained Glass Transform for computer vision models, the Stained Glass Transform loss function wraps a user-provided loss function via the hook_loss_wrapper. Unlike in the Large Language Model case, this loss depends on the outputs of the computer vision model.

This function returns a wrapped loss function that calculates a combined loss: the loss from the base loss function plus the loss from the Stained Glass Transform. Some pre-defined loss wrappers are available in stainedglass_core.loss.cloak.

The tunable hyperparameter alpha of the composite loss controls the priority of the transform loss over the base loss. Low values of alpha (close to 0) prioritize the base loss: the transform may learn more slowly, but model performance is preserved. High values of alpha (close to 1) prioritize the transform loss: the transform may learn faster, but model performance can degrade. This is a key parameter to tune when training the Stained Glass Transform.

from torch import nn

from stainedglass_core import loss as sg_loss

base_loss_func = nn.CrossEntropyLoss()
loss_func = sg_loss.cloak.composite_cloak_loss_factory(noisy_model, base_loss_func, alpha=0.4)
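
The wrapped loss function is then used like the base loss in an ordinary training step. A minimal sketch, assuming the wrapped loss takes the same arguments as the base loss and using a dummy classification batch:

import torch

images = torch.randn(8, 3, 224, 224)    # dummy image batch
targets = torch.randint(0, 1000, (8,))  # dummy class labels

logits = noisy_model(images)             # transform applied internally
loss_value = loss_func(logits, targets)  # combined base + transform loss
loss_value.backward()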

For language models only: Wrap tokenizer with a Stained Glass Tokenizer Wrapper

Training language models also requires wrapping the tokenizer with a TokenizerWrapper. This wrapper uses the tokenizer and a template function to create the necessary inputs for the Stained Glass Transform.

The tokenizer wrapper requires the tokenizer used for the base model.

See the API reference or LLM Training Recommendation Guide for a complete list of its parameters.

When called, the wrapped tokenizer takes a sequence of dictionaries containing "role" and "content" strings (compatible with chat representations in Hugging Face Transformers chat templates or with OpenAI API chat prompts) and returns a dictionary containing the tokenized inputs for the base model as well as the noise_mask, which describes which tokens are to be transformed.

import transformers
from stainedglass_core.huggingface import tokenization_utils as sg_tokenization_utils


base_tokenizer = transformers.AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.2")
tokenizer_wrapper = sg_tokenization_utils.TokenizerWrapper(
    base_tokenizer,
    model_type=type(base_model),
    include_labels=True,
    ignore_prompt_loss=True,
    prompt_type=sg_tokenization_utils.PromptType.CHAT
)

...

inputs = tokenizer_wrapper(
    [
        {"role": "system", "content": "You are a helpful chat bot who answer questions."},
        {"role": "user", "content": "What is the capital of France?"},
    ]
)
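
The returned dictionary can then be consumed directly. A small sketch of reading the masks from the output (the exact keys may depend on the wrapper's configuration):

noise_mask = inputs["noise_mask"]  # which tokens will be transformed
loss_mask = inputs["loss_mask"]    # which tokens contribute to the loss, if returned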

When called in instruction-prompting mode, the wrapped tokenizer instead takes a dictionary containing instruction, system_prompt, context, and response strings and returns a dictionary containing the tokenized inputs for the base model as well as the noise_mask, which describes which tokens are to be transformed. It can optionally also return a loss_mask, which describes which tokens are used in the loss calculation.

import transformers
from stainedglass_core.huggingface import tokenization_utils as sg_tokenization_utils


base_tokenizer = transformers.AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.2")
tokenizer_wrapper = sg_tokenization_utils.TokenizerWrapper(
    base_tokenizer,
    model_type=type(base_model),
    include_labels=True,
    ignore_prompt_loss=True,
    prompt_type=sg_tokenization_utils.PromptType.INSTRUCTION
)

...

inputs = tokenizer_wrapper(
    {
        "instruction": "What is the capital of France?",
        "system_prompt": "You are a helpful chat bot who answer questions.",
        "context": "",
        "response": ""
    }
)

Inference with Stained Glass Transform

For inference, the Stained Glass Transform is applied on the client side (within their zone of trust). The transformed data is then sent to the model provider (server side), where the rest of the model uses that transformed data.

Because the model provider creates the Stained Glass Transform, they must first prepare it for distribution to the client. After this distribution, there are minimal changes needed on the client side and often little-to-no changes needed on the server side, all of which are described below.

Model provider preparation

After creating a Stained Glass Transform, the model provider must distribute it to the client. This distribution includes the Stained Glass Transform weights and, for language models, may also include a tokenizer, a template function, and the input embedding weights of the base model.

import torch

# For language models: create a Stained Glass Transform Text Client object,
# then save it. (StainedGlassTransformForText is provided by Stained Glass
# Core; see the API reference for its import path.)
client = StainedGlassTransformForText(
    model=noisy_model,
    tokenizer_wrapper=tokenizer_wrapper,
    parameter_names=embedding_parameters_names,  # names of the base model's input embedding parameters, defined elsewhere
)

# This file will be distributed to the client.
client.save_pretrained("stainedglass_transform.pth")

# For computer vision models: extract the Stained Glass Transform from the `NoisyModel`.
noise_layer = noisy_model.noise_layer

# This file will be distributed to the client.
torch.save(noise_layer.state_dict(), "noise_layer.pth")

Client Side

The client first instantiates a Stained Glass Transform object, loading the weights from the model provider.

Before sending the inputs to the model provider, the client must first apply that Stained Glass Transform to those inputs.

# For language models: load the Stained Glass Transform from the file
# distributed by the model provider. The client does not need to know the
# hyperparameters of the transform in this case.
sgt = StainedGlassTransformForText.from_pretrained("stainedglass_transform.pth")

# Apply the Stained Glass Transform to the input data.
transformed_input_embeddings = sgt(
    {
        "instruction": "What is the capital of France?",
        "system_prompt": "",
        "context": "",
        "response": "",
    }
)

# For computer vision models: load the Stained Glass Transform from the file
# distributed by the model provider. The client must know the type and
# hyperparameters of the transform.
stained_glass_transform = CloakNoiseLayerOneShot(percent_to_mask=0.4)
stained_glass_transform.load_state_dict(torch.load("noise_layer.pth"))

# Apply the Stained Glass Transform to the input data.
transformed_input = stained_glass_transform(input_images)

The transformed input is then sent to the model provider for inference.

Server Side

The model provider receives the transformed input.

For language models, the transformed inputs received from the Stained Glass Transform are input embeddings, not text.

Tip

If the model provider can already accept input embeddings directly for inference/generation, then no changes are needed. Otherwise, the model provider must add some way to consume these input embeddings.

Info

Hugging Face models can accept input embeddings directly for inference/generation, using the inputs_embeds argument to forward or generate.
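
For example, a server built on Hugging Face Transformers can generate directly from the received embeddings (a minimal sketch; max_new_tokens is illustrative):

# Generate from the transformed input embeddings received from the client.
generated_ids = base_model.generate(
    inputs_embeds=transformed_input_embeddings,
    max_new_tokens=64,
)
text = base_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)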

For computer vision models, the transformed inputs are image-shaped tensors. The model provider can use them directly for inference, as if they were images.
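
A minimal sketch of server-side inference for this case, reusing the wrapped ResNet example from above:

import torch

# The transformed tensor is consumed exactly like an ordinary image batch.
with torch.no_grad():
    logits = base_model(transformed_input)
predictions = logits.argmax(dim=-1)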

Types of Stained Glass Transform

Stained Glass Transform comes in many types, each with different tunable parameters. See the API reference for hyperparameters of each type.

| Use case | Recommended Stained Glass Transform type |
| --- | --- |
| Language Models | TransformerCloak |
| Computer Vision Models | CloakNoiseLayerOneShot or PatchCloakNoiseLayer |