Deploying Stained Glass Transform (Text)¶
This tutorial demonstrates how to deploy a Stained Glass Transform for client applications.
This tutorial has three main sections:
- Model provider preparation (after training): We prepare a deployable Stained Glass Transform from a previously trained NoisyModel.
- Client transforms inputs: The client uses the Stained Glass Transform to transform inputs before sending them to the model provider.
- Model provider generation: The model provider uses the transformed inputs from the client to generate text using the base model.
Pre-requisites:
- The Stained Glass Transform should already be trained. To follow this tutorial exactly, the NoisyModel weights (state dict) should be saved in a file named noisy_model_state_dict.pth.
For more information about using a Stained Glass Transform, please refer to the API Reference.
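If you are coming from the training tutorial, the checkpoint referenced above can be produced with a standard torch.save call on the trained model. A minimal sketch, assuming a trained noisy_model from the training tutorial:
import torch

# Assumes `noisy_model` is the trained NoisyModel from the training tutorial.
torch.save(noisy_model.state_dict(), "noisy_model_state_dict.pth")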
Model provider preparation¶
Instantiating the wrapped model and tokenizer¶
Warning
This section (instantiating the base model, NoisyModel, and wrapped tokenizer) is exactly the same as during training. If you wish to prepare a Stained Glass Transform for inference in the same script/pipeline as training, you should skip this section and just use the pre-existing wrapped tokenizer and NoisyModel.
This section is only included in this notebook so that it can run separately from the training notebook.
We must first instantiate the base model and tokenizer that we used during training, and then wrap them in a NoisyModel and a TokenizerWrapper, respectively.
import torch
import transformers
from stainedglass_core import transform as sg_transform
from stainedglass_core.huggingface.tokenization_utils import (
    tokenizer_wrapper as sg_tokenizer_wrapper,
)
from stainedglass_core.model import noisy_transformer_masking_model
from stainedglass_core.noise_layer import transformer_cloak
[2024-06-28 15:05:43,080] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BASE_MODEL_PATH = "/models/huggingface/mistralai/Mistral-7B-Instruct-v0.2"
TRANSFORM_LAYER_CHECKPOINT = "noisy_model_state_dict.pth"
def base_model_factory() -> transformers.MistralForCausalLM:
    """Create a new instance of the base model.

    We will instantiate the base model multiple times in this tutorial, so
    wrapping it in a factory function helps clean up the code.

    Returns:
        The base model.
    """
    return transformers.AutoModelForCausalLM.from_pretrained(BASE_MODEL_PATH)
def tokenizer_factory() -> transformers.LlamaTokenizer:
    """Create a new instance of the tokenizer.

    We will instantiate the tokenizer multiple times in this tutorial, so
    wrapping it in a factory function helps clean up the code.

    Returns:
        The tokenizer.
    """
    return transformers.AutoTokenizer.from_pretrained(BASE_MODEL_PATH)
def load_noisy_model_state_dict(
    noisy_model: noisy_transformer_masking_model.NoiseMaskedNoisyTransformerModel,
) -> None:
    """Load the state dict of the noisy model from a checkpoint.

    Args:
        noisy_model: The noisy model into which to load the state dict.
    """
    noisy_model.load_state_dict(
        torch.load(TRANSFORM_LAYER_CHECKPOINT, map_location=DEVICE),
        strict=False,
    )
base_model = base_model_factory()
embedding_size = base_model.config.hidden_size
noisy_model = noisy_transformer_masking_model.NoiseMaskedNoisyTransformerModel(
    noise_layer_class=transformer_cloak.TransformerCloak,
    base_model=base_model,
    transformer_type=transformers.MistralModel,
    scale=(1e-8, 1.0),
    config_path=BASE_MODEL_PATH,
    target_layer="model.embed_tokens",
    directly_learn_stds=True,
    use_causal_mask=True,
    rho_init=0,
)
# Load the state dict
load_noisy_model_state_dict(noisy_model)
tokenizer = tokenizer_factory()
tokenizer_wrapper = sg_tokenizer_wrapper.TokenizerWrapper(
    tokenizer=tokenizer,
    model_type=type(base_model),
    include_labels=False,
    ignore_prompt_loss=False,
    prompt_type=sg_tokenizer_wrapper.PromptType.INSTRUCTION,
)
Creating the Stained Glass Transform¶
The StainedGlassTransformForText object provides a simple interface for passing in text data and getting back the transformed input embeddings. It can also save itself to disk so that the client (not the model provider) can load it in their own environment. This is useful for deployment, as the client can load the Stained Glass Transform without needing to know the details of how it was trained.
The StainedGlassTransformForText object is created by passing in the wrapped model and wrapped tokenizer used during training.
Because the Stained Glass Transform is applied to the input embeddings of the model, it automatically tokenizes the input and passes it through the base model's embedding layer before applying the transform. To avoid saving the entire base model, the Stained Glass Transform object automatically infers, during construction, the minimal set of base model parameters needed to generate the input embeddings. These parameters are saved in the Stained Glass Transform object and are used to generate the embeddings during inference. No other parameters from the base model are shared with the client.
stainedglass_transform = sg_transform.StainedGlassTransformForText(
    model=noisy_model, tokenizer_wrapper=tokenizer_wrapper
)
# These are the parameters that the Stained Glass Transform will include:
stainedglass_transform.parameter_names_relative_to_client
['truncated_module.module.base_model.model.embed_tokens.weight', 'truncated_module.module.noise_layer.mean_estimator.module.transformer.layers.0.self_attn.q_proj.weight', 'truncated_module.module.noise_layer.mean_estimator.module.transformer.layers.0.self_attn.k_proj.weight', 'truncated_module.module.noise_layer.mean_estimator.module.transformer.layers.0.self_attn.v_proj.weight', 'truncated_module.module.noise_layer.mean_estimator.module.transformer.layers.0.self_attn.o_proj.weight', 'truncated_module.module.noise_layer.mean_estimator.module.transformer.layers.0.mlp.gate_proj.weight', 'truncated_module.module.noise_layer.mean_estimator.module.transformer.layers.0.mlp.up_proj.weight', 'truncated_module.module.noise_layer.mean_estimator.module.transformer.layers.0.mlp.down_proj.weight', 'truncated_module.module.noise_layer.mean_estimator.module.transformer.layers.0.input_layernorm.weight', 'truncated_module.module.noise_layer.mean_estimator.module.transformer.layers.0.post_attention_layernorm.weight', 'truncated_module.module.noise_layer.mean_estimator.module.linear.weight', 'truncated_module.module.noise_layer.mean_estimator.module.linear.bias', 'truncated_module.module.noise_layer.std_estimator.module.transformer.layers.0.self_attn.q_proj.weight', 'truncated_module.module.noise_layer.std_estimator.module.transformer.layers.0.self_attn.k_proj.weight', 'truncated_module.module.noise_layer.std_estimator.module.transformer.layers.0.self_attn.v_proj.weight', 'truncated_module.module.noise_layer.std_estimator.module.transformer.layers.0.self_attn.o_proj.weight', 'truncated_module.module.noise_layer.std_estimator.module.transformer.layers.0.mlp.gate_proj.weight', 'truncated_module.module.noise_layer.std_estimator.module.transformer.layers.0.mlp.up_proj.weight', 'truncated_module.module.noise_layer.std_estimator.module.transformer.layers.0.mlp.down_proj.weight', 'truncated_module.module.noise_layer.std_estimator.module.transformer.layers.0.input_layernorm.weight', 'truncated_module.module.noise_layer.std_estimator.module.transformer.layers.0.post_attention_layernorm.weight', 'truncated_module.module.noise_layer.std_estimator.module.linear.weight', 'truncated_module.module.noise_layer.std_estimator.module.linear.bias']
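As a quick sanity check, you can confirm that every parameter shared with the client belongs either to the embedding layer or to the noise layer. This assertion is purely illustrative, not part of the library API:
shared_parameter_names = stainedglass_transform.parameter_names_relative_to_client
assert all(
    "embed_tokens" in name or "noise_layer" in name
    for name in shared_parameter_names
)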
Testing inference with the Stained Glass Transform¶
Obtaining the transformed embeddings for a given input is as simple as calling the Stained Glass Transform object. It will internally handle tokenization, embedding generation, and transformation.
transformed_input_embeddings = stainedglass_transform(
    {
        "instruction": "What is the capital of France?",
        "system_prompt": "",
        "context": "",
        "response": "",
    }
)
transformed_input_embeddings
tensor([[[-4.3640e-03, -1.0633e-04, -5.6152e-03, ..., -5.0545e-05, -1.1520e-03, 1.5926e-04], [-4.3640e-03, -1.8954e-05, -1.7853e-03, ..., 3.5858e-04, 4.0588e-03, 4.5204e-04], [ 1.4496e-04, 5.0354e-04, -2.3499e-03, ..., -2.5024e-03, 3.2349e-03, -2.8229e-03], ..., [ 1.4496e-04, 5.0354e-04, -2.3499e-03, ..., -2.5024e-03, 3.2349e-03, -2.8229e-03], [-4.1504e-03, -1.7548e-03, 3.7231e-03, ..., -1.2589e-04, -9.2697e-04, 3.2196e-03], [-1.7471e-03, 1.0300e-03, 3.7432e-05, ..., 1.1826e-03, 3.7003e-04, 3.2425e-04]]], grad_fn=<AddBackward0>)
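Note that the returned tensor carries a grad_fn because the transform ran with gradients enabled. Gradients are not needed for deployment, so it is reasonable to detach the result (or call the transform under torch.no_grad()) before sending it anywhere; the embeddings_to_send name below is just for illustration:
# Drop the autograd graph before shipping the embeddings anywhere.
embeddings_to_send = transformed_input_embeddings.detach()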
Saving the Stained Glass Transform to disk¶
The Stained Glass Transform object can be saved to disk using the save_pretrained method.
Warning
This method uses pickle internally, which is known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Never load data that could have come from an untrusted source in an unsafe mode, or that could have been tampered with. Only load data you trust.
# For this tutorial, we will save the Stained Glass Transform in a temporary file.
import tempfile
temporary_file = tempfile.NamedTemporaryFile()
FILE_PATH = temporary_file.name
stainedglass_transform.save_pretrained(FILE_PATH)
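Because the saved artifact is pickled, it is worth giving the client a way to verify its integrity before loading. One simple option (illustrative, not part of the library API) is to publish a checksum alongside the file:
import hashlib

with open(FILE_PATH, "rb") as f:
    artifact_checksum = hashlib.sha256(f.read()).hexdigest()

# Share `artifact_checksum` with the client out-of-band so they can verify
# the file before loading it.
print(artifact_checksum)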
Clean up the model provider's copy of the NoisyModel¶
We do this to manage memory in the limited environment of this notebook. It also emphasizes that the client can load the Stained Glass Transform from disk without any need for the model provider's NoisyModel.
del noisy_model
del base_model
del stainedglass_transform
del tokenizer_wrapper
del tokenizer
Client¶
Loading the Stained Glass Transform from disk¶
Warning
This method uses pickle internally, which is known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Never load data that could have come from an untrusted source in an unsafe mode, or that could have been tampered with. Only load data you trust.
from stainedglass_core import transform as sg_transform
loaded_stainedglass_transform = (
    sg_transform.StainedGlassTransformForText.from_pretrained(FILE_PATH).eval()
)
Inference with the Stained Glass Transform¶
Generating the transformed embeddings for a given input is as simple as calling the Stained Glass Transform object. It will internally handle tokenization, embedding generation, and transformation.
transformed_input_embeddings = loaded_stainedglass_transform(
    {
        "instruction": "What is the capital of France?",
        "system_prompt": "",
        "context": "",
        "response": "",
    }
)
transformed_input_embeddings
tensor([[[-4.3640e-03, -1.0633e-04, -5.6152e-03, ..., -5.0545e-05, -1.1520e-03, 1.5926e-04], [-4.3640e-03, -1.8954e-05, -1.7853e-03, ..., 3.5858e-04, 4.0588e-03, 4.5204e-04], [ 1.4496e-04, 5.0354e-04, -2.3499e-03, ..., -2.5024e-03, 3.2349e-03, -2.8229e-03], ..., [ 1.4496e-04, 5.0354e-04, -2.3499e-03, ..., -2.5024e-03, 3.2349e-03, -2.8229e-03], [-4.1504e-03, -1.7548e-03, 3.7231e-03, ..., -1.2589e-04, -9.2697e-04, 3.2196e-03], [-1.7471e-03, 1.0300e-03, 3.7432e-05, ..., 1.1826e-03, 3.7003e-04, 3.2425e-04]]], grad_fn=<AddBackward0>)
These transformed input embeddings are then sent from the client to the model provider for generation, for example over an API. For the purposes of this notebook, we simply keep the transformed input embeddings in a variable, which the model provider's base model will read below.
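As an illustration of what that transfer might look like, the client could serialize the tensor into a buffer and POST it to the model provider. The endpoint below is hypothetical; a real deployment would use whatever API the model provider exposes:
import io

import requests
import torch

buffer = io.BytesIO()
torch.save(transformed_input_embeddings.detach().cpu(), buffer)
# Hypothetical endpoint; replace with the model provider's actual API.
requests.post("https://model-provider.example/generate", data=buffer.getvalue())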
Server¶
Instantiate the base model and tokenizer¶
base_model = base_model_factory().to(DEVICE).eval()
tokenizer = tokenizer_factory()
Perform generation using transformed inputs¶
Note
The model hosted by the provider must have some method of consuming input embeddings directly. For example, Hugging Face models can directly accept inputs_embeds as an argument to the generate method.
The transformed embeddings received from the client can be passed directly to the base model for generation, using the same mechanism the model uses to consume untransformed input embeddings.
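In a real deployment, the server would first reconstruct the tensor from the raw bytes of the client's request (the counterpart of the client's hypothetical POST above). As a stand-in for an actual request, the sketch below round-trips the tensor through an in-memory buffer:
import io

request_body = io.BytesIO()  # stand-in for the bytes received from the client
torch.save(transformed_input_embeddings.detach().cpu(), request_body)
# torch.load unpickles its input; only load payloads from trusted clients
# (see the pickle warning above).
transformed_input_embeddings = torch.load(io.BytesIO(request_body.getvalue()))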
# Hugging Face models can accept `inputs_embeds` as input in lieu of
# `input_ids`.
tokens_shape = transformed_input_embeddings.shape[:-1]
attention_mask = torch.ones(*tokens_shape, dtype=torch.bool)
output_tokens = base_model.generate(
    inputs_embeds=transformed_input_embeddings.to(DEVICE),
    attention_mask=attention_mask.to(DEVICE),
    max_new_tokens=64,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    renormalize_logits=True,
    temperature=0.6,
    top_k=5000,
    top_p=0.9,
    repetition_penalty=1.0,
    do_sample=True,
)
tokenizer.batch_decode(output_tokens, skip_special_tokens=True)
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
['2. The capital city of France is Paris. Paris is one of the most famous cities in the world and is known for its art, fashion, gastronomy, and culture. It is also home to many iconic landmarks such as the Eiffel Tower, the Louvre Museum, and Notre-']
Clean up temporary files¶
temporary_file.close()