Creating a Stained Glass Transform for a Large Language Model¶
This tutorial demonstrates how to train a Stained Glass Transform for an LLM using stainedglass_core
and PyTorch Lightning.
Overview¶
In this notebook, we demonstrate using stainedglass_core
and PyTorch Lightning to train a Stained Glass Transform (SGT) for a causal large language model. Specifically, we use MistralForCausalLM
, but any decoder-only Hugging Face transformers
model (e.g. LlamaForCausalLM
, GemmaForCausalLM
, CohereForCausalLM
, Starcoder2ForCausalLM
, etc.) could be used instead.
Info
This notebook is structured as a single LightningModule
class split across multiple cells, interspersed with explanations. You will see StainedGlassDistillationLightningModule
inherit from itself to facilitate this. If copying and pasting this code elsewhere, it is acceptable to combine all the methods into a single class definition.
SGT works by transforming input embeddings such that they are not naively reconstructable, meaning each transformed embedding's nearest untransformed embedding is not the embedding it was generated from. Importantly, the transformed embeddings are still interpretable by the causal LM for which the SGT is trained. One of the main metrics used to evaluate the strength of an SGT is the percentage of embeddings that it transforms that are not naively reconstructable, colloquially called "obfuscation score".
As an example:
input_text | reconstructed_input_text | obfuscation_score | generated_text_from_transformed_embeddings |
---|---|---|---|
[INST] You are a helpful assistant, who always provide explanation. Think like you are answering to a five year old. Write a dialog based on this summary: Griff and Tea will meet in 10 minutes to have a walk in the park. [/INST] | [INST] You are a helpful assistant, who always provide explanation. Think like you are answering to a five year old. Write a dialog based on this summary: Griff and Tea will meet in 10 minutes to have a walk in the park. [/INST] | 0% | Sure thing, little buddy! Griff and Tea are going to meet each other in just a short while. You know how when mommy or daddy meets a friend at the park? It's kind of like that. Griff and Tea are friends, and they like to go for walks together in the park. So, they're going to meet up there soon, and then they'll have a fun time exploring the park together. Is that clear? Do you want to draw a picture of Griff and Tea meeting in the park? |
[INST] You are a helpful assistant, who always provide explanation. Think like you are answering to a five year old. Write a dialog based on this summary: Griff and Tea will meet in 10 minutes to have a walk in the park. [/INST] | [INST]},NdEx();ⵉ того { }";};>?[< to },),{ NdEx')">){]);>?[<";');( },());]); });>?[<]); };� */},}; }, );});',];};}); });'); // */ [/INST] | 100% | Sure thing, little buddy! In a little while, like when the sand in the hourglass is all the way down and we've counted ten apples on our fingers, Griff and Tea will come together for a fun adventure. They're going to take a walk in the park, just like we do when we go for a stroll with mommy or daddy. They'll see pretty flowers, hear the birds singing, and maybe even spot some squirrels or rabbits. So, let's get excited for their special park walk! |
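To make "naively reconstructable" concrete, below is a minimal, hypothetical sketch of such a score in plain PyTorch, assuming access to the model's token embedding matrix and a noise_mask marking the transformed positions. The function name and the choice of Euclidean distance are illustrative assumptions; the notebook itself relies on noisy_model.reconstruct_ids_from_embeddings and sg_metrics.percentage_changed_ids for this.

import torch


def naive_obfuscation_score(
    transformed_embeddings: torch.Tensor,  # (sequence_length, hidden_size)
    input_ids: torch.Tensor,  # (sequence_length,)
    embedding_matrix: torch.Tensor,  # (vocab_size, hidden_size)
    noise_mask: torch.Tensor,  # (sequence_length,), True where the SGT was applied
) -> torch.Tensor:
    """Fraction of transformed positions whose nearest token embedding is not the original token."""
    # Nearest untransformed embedding for each transformed embedding (Euclidean distance here).
    distances = torch.cdist(transformed_embeddings, embedding_matrix)
    reconstructed_ids = distances.argmin(dim=-1)
    # A position counts as obfuscated if naive reconstruction recovers a different token.
    changed = (reconstructed_ids != input_ids) & noise_mask
    return changed.sum() / noise_mask.sum()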
To train an SGT for a causal LM, we use a distillation approach that computes losses between the transformed and untransformed activations. This teaches the SGT how to apply its transformation based on how it affects the base model's activations. The composite distillation loss is a weighted sum of the following losses:
- std_log_ratio_loss: Maximizes the SGT's standard deviations (its transformation strength).
- input_embedding_similarity_loss: Minimizes the similarity of the transformed input embeddings to the untransformed input embeddings.
- distillation_layer_distance_loss: Maximizes the similarity of the transformed distillation layer embeddings to the untransformed distillation layer embeddings.
from __future__ import annotations
import os
import warnings
from collections.abc import Mapping
from typing import Any, Callable, Final
import datasets
import lightning
import lightning.pytorch.loggers
import lightning.pytorch.utilities.combined_loader
import numpy as np
import psutil
import torch
import torch.distributed
import torch.optim
import torch.utils.data
import torch.utils.tensorboard
import transformers
import transformers.modeling_outputs
from stainedglass_core import (
huggingface,
loss as sg_loss,
metrics as sg_metrics,
utils as sg_utils,
)
from stainedglass_core.huggingface import (
tokenization_utils as sg_tokenization_utils,
)
from stainedglass_core.huggingface.data import data_collator as sg_data_collator
from stainedglass_core.integrations import (
lightning as sg_lightning,
torchmetrics as sg_torchmetrics,
)
from stainedglass_core.model import noisy_transformer_masking_model
from stainedglass_core.noise_layer import transformer_cloak
Model and Loss Function Initialization¶
First, we must create the base causal LM. We are not finetuning it, so we should put it in eval mode and set requires_grad to False on all of its parameters to disable storing gradients.

To create an SGT, provide the base model to the NoisyModel subclass (NoiseMaskedNoisyTransformerModel for causal LMs) along with the noise_layer_class (TransformerCloak for causal LMs). The only target_layer that is currently supported is the base model's input embedding layer. Lastly, provide all TransformerCloak constructor arguments. See Recommendations for Training a Stained Glass Transform for a Causal Language Model for an in-depth explanation of SGT hyperparameters.
The weights of the distillation losses are configured by the parameters to distillation_loss_factory (an illustrative sketch of how they might combine follows this list):
- alpha: The interpolation factor between the distillation loss (maximizing model similarity) and the SGT loss (maximizing transformation strength). Should be in the range [0, 1], where 0 corresponds to higher model similarity and 1 corresponds to higher transformation strength.
- std_log_ratio_loss_weight: The weight of the loss component that maximizes the SGT's standard deviations (transformation strength).
- input_embedding_similarity_loss_weight: The weight of the loss component that minimizes the similarity of the transformed input embeddings to the untransformed input embeddings.
- distillation_layer_distance_loss_weight: The weight of the loss component that maximizes the similarity of the transformed distillation layer embeddings to the untransformed distillation layer embeddings.
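As a purely illustrative sketch of how these knobs interact, the composite loss can be thought of roughly as follows, using the weight values configured later in this notebook as defaults. The grouping of terms here is an assumption; the exact formula is internal to distillation_loss_factory and may differ.

import torch


def illustrative_composite_loss(
    std_log_ratio_loss: torch.Tensor,
    input_embedding_similarity_loss: torch.Tensor,
    distillation_layer_distance_loss: torch.Tensor,
    alpha: float = 0.54,
    std_log_ratio_loss_weight: float = 0.01,
    input_embedding_similarity_loss_weight: float = 0.75,
    distillation_layer_distance_loss_weight: float = 12.0,
) -> torch.Tensor:
    """Assumed structure only: a weighted sum interpolated by alpha, not the library's exact formula."""
    # Terms that push toward a stronger transformation.
    transformation_strength = (
        std_log_ratio_loss_weight * std_log_ratio_loss
        + input_embedding_similarity_loss_weight * input_embedding_similarity_loss
    )
    # Term that pushes the transformed activations toward the untransformed ones.
    model_similarity = (
        distillation_layer_distance_loss_weight * distillation_layer_distance_loss
    )
    # alpha = 0 favors model similarity; alpha = 1 favors transformation strength.
    return alpha * transformation_strength + (1 - alpha) * model_similarity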
Two other notable factory methods we call while setting up the model:
- get_transformed_output_factory: Returns a function that returns the transformed embeddings from the most recent forward pass, which can be used to measure obfuscation scores.
- get_applied_transform_components_factory: Returns a function that returns the means and standard deviations from the most recent forward pass, which can be used to observe how their distributions change as training progresses.
PRETRAINED_MODEL_NAME_OR_PATH: Final[str] = (
"/models/mistralai/Mistral-7B-Instruct-v0.2"
)
TRUNCATED_LAYER_INDEX: Final[int] = 12
class StainedGlassDistillationLightningModule(lightning.LightningModule):
def __init__(self, max_generation_examples: int) -> None:
super().__init__()
self.max_generation_examples = max_generation_examples
self.generated_text_table = sg_torchmetrics.TableMetric()
def configure_model(self) -> None:
"""Construct the base model, the SGT, and the loss function."""
base_model = transformers.MistralForCausalLM.from_pretrained(
PRETRAINED_MODEL_NAME_OR_PATH,
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16,
)
assert isinstance(base_model, transformers.MistralForCausalLM)
self.noisy_model = (
noisy_transformer_masking_model.NoiseMaskedNoisyTransformerModel(
noise_layer_class=transformer_cloak.TransformerCloak,
base_model=base_model,
target_layer="model.embed_tokens",
truncated_layer_index=TRUNCATED_LAYER_INDEX,
scale=(1e-8, 1.0),
shallow=1.0,
mean_dropout=0.1,
std_dropout=0.1,
config_path=PRETRAINED_MODEL_NAME_OR_PATH,
use_causal_mask=True,
transformer_type=transformers.MistralModel,
directly_learn_stds=True,
rho_init=0.0,
seed=0,
)
)
self.distillation_loss, self.get_losses, self.get_hyperparameters = (
sg_loss.distillation.distillation_loss_factory(
self.noisy_model,
distillation_layer_index=TRUNCATED_LAYER_INDEX,
alpha=0.54,
std_log_ratio_loss_weight=0.01,
input_embedding_similarity_loss_weight=0.75,
distillation_layer_cosine_distance_loss_weight=12.0,
)
)
self.get_transformed_embeddings = (
self.noisy_model.noise_layer.get_transformed_output_factory()
)
self.get_applied_transform_components = self.noisy_model.noise_layer.get_applied_transform_components_factory()
Optimizer Initialization¶
Disable weight decay for SGT estimator linear layers
When configuring the SGT optimizer, it is important to disable weight_decay
on the final linear layers of the std_estimator
and mean_estimator
modules. Non-zero weight_decay
may significantly slow convergence or prevent learning altogether.
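In plain PyTorch terms, this means placing those parameters in their own optimizer parameter group with weight_decay set to 0.0. The sketch below is a hypothetical, hand-rolled equivalent of what the ParamGroupBuilder configuration in the next cell accomplishes; the name-matching heuristic is an assumption for illustration.

import torch


def build_param_groups(noisy_model: torch.nn.Module) -> list[dict]:
    """Hypothetical equivalent of the ParamGroupBuilder configuration below."""
    no_decay_params, default_params = [], []
    for name, param in noisy_model.named_parameters():
        if not param.requires_grad:
            continue  # skip frozen parameters, e.g. the base model
        if "_estimator" in name and "linear" in name:
            no_decay_params.append(param)  # final estimator linear layers: no weight decay
        else:
            default_params.append(param)  # everything else uses the optimizer default
    return [
        {"params": no_decay_params, "weight_decay": 0.0},
        {"params": default_params},
    ]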
class StainedGlassDistillationLightningModule(
StainedGlassDistillationLightningModule
):
def configure_optimizers(self) -> torch.optim.Optimizer:
"""Construct the optimizer."""
param_group_builder = sg_utils.torch.optim.ParamGroupBuilder(
param_groups={
"noise_layer.*_estimator.module.linear": {
"weight_decay": 0.0,
}
},
freeze=sg_utils.torch.optim.Freeze(["base_model"]),
)
return torch.optim.AdamW(
params=param_group_builder(self.noisy_model),
lr=3e-5,
amsgrad=False,
betas=(0.9, 0.95),
eps=1e-5,
weight_decay=0.1,
)
Model Truncation¶
Only base model parameters up to the distillation_layer_index (configured by distillation_loss_factory) contribute to the distillation loss. Before each training epoch, we can use truncate_and_offload to remove all decoder layers after the distillation layer from the base model and offload them to the CPU. This both reduces training memory usage and improves model runtime performance, which in practice allows us to use higher batch sizes and train faster.

These decoder layers must be restored to perform generation, which we can do using restore_and_load before each validation epoch.
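Conceptually, for a Mistral-style model, truncation amounts to something like the rough sketch below, assuming base_model is the underlying MistralForCausalLM created earlier. This is an illustrative assumption, not the implementation: truncate_and_offload and restore_and_load also handle restoring the original layers and their devices for you.

import torch

# Rough illustration of the idea behind truncate_and_offload.
decoder_layers = base_model.model.layers
for layer in decoder_layers[TRUNCATED_LAYER_INDEX + 1 :]:
    layer.to("cpu")  # offload layers that do not affect the distillation loss
base_model.model.layers = torch.nn.ModuleList(
    decoder_layers[: TRUNCATED_LAYER_INDEX + 1]
)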
class StainedGlassDistillationLightningModule(
StainedGlassDistillationLightningModule
):
def on_train_epoch_start(self) -> None:
"""Truncate and offload decoder layers after `distillation_layer_index`."""
self.noisy_model.truncate_and_offload()
def on_validation_epoch_start(self) -> None:
"""Restore and reload the decoder layers to perform generation."""
self.noisy_model.restore_and_load()
Distillation Training¶
The settings necessary for computing the distillation loss can be temporarily enabled via the distillation_context context manager. Attempting to compute the distillation loss outside of this context will fail.

The noise_mask is built by the SGT tokenization pipeline (more on this later) and should be provided directly to the base model. The noise_mask selects the tokens in input_ids whose embeddings will be transformed by the SGT. Only special prompt token sequences at specific positions, such as any padding, the bos token, the eos token, and [INST] or [/INST] in the case of Mistral, remain untransformed. The noise_mask is required by TransformerCloak, and calling the base model without providing a noise_mask will fail.

The loss_mask is also built by the SGT tokenization pipeline and is the only argument to the distillation_loss. It selects the tokens over which to maximize the similarity of the transformed distillation layer embeddings to the untransformed distillation layer embeddings.

We also demonstrate using get_transformed_embeddings to compute the obfuscation score of each batch item.
class StainedGlassDistillationLightningModule(
StainedGlassDistillationLightningModule
):
def forward(
self, **kwargs: Any
) -> transformers.modeling_outputs.CausalLMOutputWithPast:
"""Set the distillation context and perform a base model forward pass."""
with self.noisy_model.distillation_context():
return self.noisy_model(**kwargs)
def training_step(
self,
batch: sg_data_collator.TransformLayerTrainInputWithAttentionMask[
torch.Tensor
],
batch_idx: int,
) -> torch.Tensor:
"""Compute the distillation loss and calculate batch-level metrics."""
_ = self(
input_ids=batch["input_ids"],
attention_mask=batch["attention_mask"],
use_cache=True,
noise_mask=batch["noise_mask"],
)
loss = self.distillation_loss(batch["loss_mask"])
transformed_embeddings = self.get_transformed_embeddings()
reconstructed_ids = self.noisy_model.reconstruct_ids_from_embeddings(
transformed_embeddings
)
obfuscation_scores = sg_metrics.percentage_changed_ids(
batch["input_ids"],
reconstructed_ids,
noise_mask=batch["noise_mask"],
)
metrics: dict[str, float] = {
"obfuscation_score": obfuscation_scores.mean().item(),
}
... # compute other batch-level training metrics
if self.logger is not None and self.trainer.is_global_zero:
step = int(
self.current_epoch * self.trainer.num_training_batches
+ batch_idx
)
self.logger.log_metrics(metrics, step=step)
return loss
def on_train_epoch_end(self) -> None:
"""Compute the epoch-level metrics."""
metrics: dict[str, float] = {}
... # compute epoch-level training metrics
if self.logger is not None and self.trainer.is_global_zero:
self.logger.log_metrics(metrics, step=self.current_epoch)
Validation and Generation Evaluation¶
Validation is performed in two steps, with two differently formatted inputs:
- Calculation of the validation distillation loss on data formatted like the training examples (input_ids = prompt + labels).
- Generation using both untransformed and transformed embeddings on just the prompt (input_ids = prompt).

Because generation is particularly slow, we recommend limiting the number of examples on which you run generation using max_generation_examples. The TableMetric synchronizes the generation results across all worker processes, and once it reaches max_generation_examples in length, we proceed with only the first step of validation.
TOKENIZER: Final[transformers.PreTrainedTokenizerBase] = (
transformers.AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)
)
TOKENIZER.pad_token = TOKENIZER.eos_token
TOKENIZER.padding_side = "left"
MAX_LENGTH: Final[int] = 4096 # length of input_ids + generated text
class StainedGlassDistillationLightningModule(
StainedGlassDistillationLightningModule
):
def validation_step(
self,
batch: tuple[
sg_data_collator.TransformLayerTrainInputWithAttentionMask[
torch.Tensor
],
sg_data_collator.TransformLayerTestInputWithAttentionMask[
torch.Tensor
],
],
batch_idx: int,
) -> None:
"""Compute the distillation loss, perform untransformed and transformed
generation, and calculate batch-level metrics.
"""
val_batch, generation_batch = batch
metrics: dict[str, float] = {}
... # do validation and compute batch-level metrics
if len(self.generated_text_table) > self.max_generation_examples:
if self.logger is not None and self.trainer.is_global_zero:
step = int(
self.current_epoch * self.trainer.num_val_batches[0]
+ batch_idx
)
self.logger.log_metrics(metrics, step=step)
return
input_ids = generation_batch["input_ids"]
attention_mask = generation_batch["attention_mask"]
noise_mask = generation_batch["noise_mask"]
labels = generation_batch["labels"]
input_text = TOKENIZER.batch_decode(input_ids, skip_special_tokens=True)
labels_text = TOKENIZER.batch_decode(labels, skip_special_tokens=True)
sequence_length = input_ids.shape[-1]
generation_config = (
huggingface.generation.StainedGlassGenerationConfig.from_tokenizer(
TOKENIZER, max_length=MAX_LENGTH
)
)
with torch.random.fork_rng(
devices=[self.device] if self.device.type != "cpu" else [],
device_type=self.device.type,
):
if self.noisy_model.noise_layer._generators is not None:
torch.manual_seed(self.noisy_model.noise_layer.initial_seed())
generated_ids = self.noisy_model.base_model.generate(
inputs=input_ids,
generation_config=generation_config,
attention_mask=attention_mask,
use_cache=True,
)
generated_text = TOKENIZER.batch_decode(
generated_ids[:, sequence_length:], skip_special_tokens=True
)
generated_ids_from_transformed_embeddings, transformed_embeddings = (
self.noisy_model.generate(
inputs=input_ids,
generation_config=generation_config,
attention_mask=attention_mask,
use_cache=True,
noise_mask=noise_mask,
return_transformed_embeddings=True,
)
)
generated_text_from_transformed_embeddings = TOKENIZER.batch_decode(
generated_ids_from_transformed_embeddings[:, sequence_length:],
skip_special_tokens=True,
)
reconstructed_input_ids = (
self.noisy_model.reconstruct_ids_from_embeddings(
transformed_embeddings
)
)
reconstructed_input_text = TOKENIZER.batch_decode(
reconstructed_input_ids, skip_special_tokens=True
)
percentage_changed_input_ids = sg_metrics.percentage_changed_ids(
input_ids, reconstructed_input_ids, noise_mask
)
obfuscation_scores = [
f"{score.item() * 100:0.1f}%"
for score in percentage_changed_input_ids
]
self.generated_text_table.update(
{
"input_text": input_text,
"reconstructed_input_text": reconstructed_input_text,
"obfuscation_score": obfuscation_scores,
"labels_text": labels_text,
"generated_text": generated_text,
"generated_text_from_transformed_embeddings": generated_text_from_transformed_embeddings,
}
)
if self.logger is not None and self.trainer.is_global_zero:
step = int(
self.current_epoch * self.trainer.num_val_batches[1] + batch_idx
)
self.logger.log_metrics(metrics, step=step)
def on_validation_epoch_end(self) -> None:
"""Calculate epoch-level metrics."""
metrics: dict[str, Any] = {}
generated_text = self.generated_text_table.compute() # noqa: F841
self.generated_text_table.reset()
... # compute epoch-level validation metrics
if self.trainer.is_global_zero and self.logger is not None:
self.logger.log_metrics(metrics, step=self.current_epoch)
Tokenizer Preparation¶
As mentioned in Distillation Training, SGT uses its own tokenization pipeline to build both the noise_mask and loss_mask for each dataset element. The input_ids generated by this pipeline are identical to those produced by transformers.PreTrainedTokenizerBase.apply_chat_template.
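For reference, a comparable chat-template call for a single-turn Mistral prompt looks like the following sketch (the message content is borrowed from the example table above purely for illustration):

# The SGT pipeline's input_ids match the standard chat-template tokenization of the same prompt.
messages = [
    {
        "role": "user",
        "content": "Write a dialog based on this summary: Griff and Tea will meet in 10 minutes to have a walk in the park.",
    }
]
chat_template_input_ids = TOKENIZER.apply_chat_template(messages, return_tensors="pt")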
The main elements of this pipeline to be aware of are:
- The SchemaMapper: a callable that remaps columns in a Hugging Face dataset to a standard Schema (a dictionary with the keys "instruction", "context", "response", and "system_prompt"). It solves the problem of different datasets calling the same concept different things, like how the tatsu-lab/alpaca dataset calls its system prompt text, whereas the Open-Orca/OpenOrca dataset calls it system_prompt (see the plain-Python sketch after this list).
- The TokenizerMapper: a tokenizer/model-specific callable that is responsible for tokenizing the individual elements of the prompt. Existing mappers for Mistral and Llama are in stainedglass_core.huggingface.tokenization_utils.mistral and stainedglass_core.huggingface.tokenization_utils.llama.
- The TokenizerWrapper: a callable that expects Schema dictionaries as input and outputs tensor inputs to the model. If include_labels is False, input_ids is formatted as input_ids = prompt + labels for training, and if include_labels is True, input_ids is formatted as input_ids = prompt for generation.
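As a plain-Python illustration of the schema remapping (a hypothetical helper; the InstructionSchemaMapper configured in the next cell does this for you), a tatsu-lab/alpaca row would be remapped roughly like this:

def remap_alpaca_row(row: dict[str, str]) -> dict[str, str]:
    """Hypothetical equivalent of the InstructionSchemaMapper configuration below."""
    return {
        "instruction": row["instruction"],
        "context": row["input"],
        "response": row["output"],
        "system_prompt": row["text"],
    }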
Using stainedglass_core.utils.functional.sequential, we can chain the SchemaMapper and TokenizerWrapper to form a function that can be applied directly to the desired Hugging Face dataset using datasets.Dataset.map.
SCHEMA_MAPPER = sg_tokenization_utils.universal.InstructionSchemaMapper(
instruction_key="instruction",
response_key="output",
context_key="input",
system_prompt_key="text",
)
TOKENIZER_MAPPER = (
sg_tokenization_utils.mistral.MistralInstructionTokenizerMapper(TOKENIZER)
)
TRAIN_TOKENIZATION_FN = sg_utils.functional.sequential(
SCHEMA_MAPPER,
sg_tokenization_utils.TokenizerWrapper(
tokenizer=None,
model_type=None,
include_labels=False,
ignore_prompt_loss=True,
tokenizer_mapper=TOKENIZER_MAPPER,
prompt_type=sg_tokenization_utils.PromptType.INSTRUCTION,
),
)
TEST_TOKENIZATION_FN = sg_utils.functional.sequential(
SCHEMA_MAPPER,
sg_tokenization_utils.TokenizerWrapper(
tokenizer=None,
model_type=None,
include_labels=True,
ignore_prompt_loss=True,
tokenizer_mapper=TOKENIZER_MAPPER,
prompt_type=sg_tokenization_utils.PromptType.INSTRUCTION,
),
)
Dataset Preparation¶
There are four main steps to dataset preparation:
- Loading a Hugging Face dataset and splitting it into training and test datasets (if it doesn't already exist as two splits).
- Tokenizing the dataset using the SGT tokenization pipeline.
- Filtering out examples that are longer than our desired length.
The first three steps are done once by the rank zero process in lightning.pytorch.core.LightningModule.prepare_data. Caching is highly recommended, especially the larger your dataset is. For instance, using 112 tokenizer workers on a machine with 112 physical cores, a version of Open-Orca/OpenOrca with 4,225,723 examples can take about 40 minutes to tokenize!

Once cached and saved, each worker process performs the fourth step in setup:
- Load and store the prepared datasets as instance attributes.
PHYSICAL_CPU_COUNT: Final[int] = psutil.cpu_count(logical=False)
SEED: Final[int] = 42
VALIDATION_SPLIT_RATIO: Final[float | None] = 0.05
LOAD_FROM_DISK: Final[bool] = False
DATASET_NAME: Final[str] = "tatsu-lab/alpaca"
MODEL_CACHE_DIR_NAME: Final[str] = "mistral"
EXPECTED_COLUMNS: set[str] = {
"input_ids",
"attention_mask",
"noise_mask",
"loss_mask",
"labels",
}
def get_cache_dir(dataset: datasets.Dataset) -> str:
"""Get the directory of the first cache file of the loaded dataset."""
if not dataset.cache_files:
raise ValueError("The loaded dataset has no cache files.")
return os.path.dirname(dataset.cache_files[0]["filename"])
def max_length_filter(
sample: (
sg_tokenization_utils.universal.TransformLayerTestMapper.TransformLayerTestInput[
torch.Tensor
]
| sg_tokenization_utils.universal.TransformLayerTrainMapper.TransformLayerTrainInput[
torch.Tensor
]
),
) -> bool:
"""Filter `input_ids` or `input_ids` + `labels` if greater than the max
context length.
"""
if "labels" in sample:
return (
sample["input_ids"].shape[-1] + sample["labels"].shape[-1]
<= MAX_LENGTH
)
return sample["input_ids"].shape[-1] <= MAX_LENGTH
def load_dataset() -> datasets.DatasetDict:
"""Load the dataset and split it into training and validation sets."""
if LOAD_FROM_DISK:
dataset = datasets.load_from_disk(
DATASET_NAME,
)
else:
dataset = datasets.load_dataset(
DATASET_NAME,
num_proc=PHYSICAL_CPU_COUNT,
)
assert not isinstance(
dataset, (datasets.IterableDataset, datasets.IterableDatasetDict)
)
if VALIDATION_SPLIT_RATIO is not None:
if isinstance(dataset, datasets.DatasetDict):
assert dataset.keys() == {"train"}, (
"VALIDATION_SPLIT_RATIO was set, but there are already multiple splits in the loaded dataset."
)
dataset = dataset["train"]
cache_dir = os.path.join(get_cache_dir(dataset), "cache")
os.makedirs(cache_dir, exist_ok=True)
train_indices_cache_file_name = os.path.join(
cache_dir,
f"seed_{SEED}_train_split_{1 - VALIDATION_SPLIT_RATIO:.3f}.cache",
)
test_indices_cache_file_name = os.path.join(
cache_dir,
f"seed_{SEED}_val_split_{VALIDATION_SPLIT_RATIO:.3f}.cache",
)
dataset = dataset.train_test_split(
test_size=VALIDATION_SPLIT_RATIO,
load_from_cache_file=True,
shuffle=True,
generator=np.random.default_rng(seed=SEED),
train_indices_cache_file_name=train_indices_cache_file_name,
test_indices_cache_file_name=test_indices_cache_file_name,
)
assert isinstance(dataset, datasets.DatasetDict)
dataset.set_format("torch")
return dataset
def prepare_dataset_split(
dataset: datasets.Dataset,
split: str,
tokenization_fn: Callable[[Mapping[str, str]], Any],
) -> datasets.Dataset:
"""Prepare a dataset split by tokenizing and filtering the samples."""
cache_dir = os.path.join(
get_cache_dir(dataset), "cache", MODEL_CACHE_DIR_NAME
)
os.makedirs(cache_dir, exist_ok=True)
map_prefix = os.path.join(cache_dir, f"{split}-SGT-tokenized")
filter_prefix = map_prefix + f"-max_length-{MAX_LENGTH}"
dataset = dataset.map(
tokenization_fn,
cache_file_name=map_prefix + ".map",
load_from_cache_file=True,
batched=False,
num_proc=PHYSICAL_CPU_COUNT,
).filter(
max_length_filter,
cache_file_name=filter_prefix + ".filter",
load_from_cache_file=True,
batched=False,
num_proc=PHYSICAL_CPU_COUNT,
)
columns_to_drop = [
column
for column in dataset.column_names
if column not in EXPECTED_COLUMNS
]
return dataset.remove_columns(columns_to_drop)
class StainedGlassDistillationLightningModule(
StainedGlassDistillationLightningModule
):
def prepare_data(self) -> None:
"""Download or load and prepare the dataset splits.
This function is only called on the local rank 0 process, so we won't
attempt to download or map the dataset multiple times here.
"""
warnings.filterwarnings(
action="ignore",
module="torch.storage",
category=FutureWarning,
message="You are using `torch.load` with `weights_only=False.*`",
)
dataset = load_dataset()
training_dataset = prepare_dataset_split( # noqa: F841
dataset["train"], "train", TRAIN_TOKENIZATION_FN
)
validation_dataset = prepare_dataset_split( # noqa: F841
dataset["test"], "val", TRAIN_TOKENIZATION_FN
)
generation_dataset = prepare_dataset_split( # noqa: F841
dataset["test"], "generation", TEST_TOKENIZATION_FN
)
def setup(self, stage: str) -> None:
"""Load the prepared datasets as instance attributes."""
dataset = load_dataset()
if stage == "fit":
self.train_dataset = prepare_dataset_split(
dataset["train"], "train", TRAIN_TOKENIZATION_FN
)
self.val_dataset = prepare_dataset_split(
dataset["test"], "val", TRAIN_TOKENIZATION_FN
)
self.generation_dataset = prepare_dataset_split(
dataset["test"], "generation", TEST_TOKENIZATION_FN
)
if stage == "test":
...
if stage == "predict":
...
DataLoader Preparation¶
To load batched data, we use DataCollatorForStainedGlassSeq2Seq as the collate_fn of our data loaders. If you are running on a system with Tensor Cores, you can set pad_to_multiple_of to 8 to take advantage of them. It is also recommended to set batch_size to a multiple of 8, though for many datasets and on many systems this will not be possible due to memory constraints. You can read more about NVIDIA's recommendations in Satisfying Tensor Core Shape Constraints.

Enabling shuffle is recommended for training and discouraged for validation.

The optimal value of num_workers depends on your batch_size and dataset. It does not change with world size, since each process spawns its own set of data loader workers. You can determine the optimal value by timing iteration speed over an empty loop of differently configured data loaders, as sketched below.
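A minimal, hypothetical way to do this is to iterate once over otherwise identical data loaders and compare wall-clock times, for example:

import time

import torch.utils.data


def time_dataloader(dataset, collate_fn, num_workers: int, batch_size: int = 2) -> float:
    """Hypothetical helper: measure one full pass over a data loader, discarding the batches."""
    loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=collate_fn,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
    )
    start = time.perf_counter()
    for _ in loader:
        pass  # only data loading and collation are being timed
    return time.perf_counter() - start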
Of note, we utilize CombinedLoader to iterate over both the validation and generation datasets together. Each batch from the CombinedLoader is, in turn, a tuple of the batch items of each component data loader, which enables the two-step validation seen in Validation and Generation Evaluation.
class StainedGlassDistillationLightningModule(
StainedGlassDistillationLightningModule
):
def train_dataloader(
self,
) -> torch.utils.data.DataLoader[
sg_data_collator.TransformLayerTrainInputWithAttentionMask[torch.Tensor]
]:
"""Construct a data loader for the training dataset."""
return torch.utils.data.DataLoader(
self.train_dataset, # pyright: ignore[reportArgumentType]
collate_fn=sg_data_collator.DataCollatorForStainedGlassSeq2Seq(
tokenizer=TOKENIZER,
pad_to_multiple_of=8,
),
batch_size=2,
shuffle=True,
num_workers=4,
pin_memory=True,
)
def val_dataloader(
self,
) -> lightning.pytorch.utilities.combined_loader.CombinedLoader:
"""Construct a data loader to load from both the validation and
generation datasets.
"""
return lightning.pytorch.utilities.combined_loader.CombinedLoader(
(
torch.utils.data.DataLoader(
self.val_dataset, # pyright: ignore[reportArgumentType]
collate_fn=sg_data_collator.DataCollatorForStainedGlassSeq2Seq(
tokenizer=TOKENIZER,
pad_to_multiple_of=8,
),
batch_size=2,
shuffle=False,
num_workers=4,
pin_memory=True,
),
torch.utils.data.DataLoader(
self.generation_dataset, # pyright: ignore[reportArgumentType]
collate_fn=sg_data_collator.DataCollatorForStainedGlassSeq2Seq(
tokenizer=TOKENIZER,
pad_to_multiple_of=8,
),
batch_size=2,
shuffle=False,
num_workers=4,
pin_memory=True,
),
),
mode="max_size",
)
Fitting the Trainer¶
Empirically, we've had success with an effective batch size of ~576, where the effective batch size is the product of the per-device batch_size, accumulate_grad_batches, and the number of devices (2 × 36 × 8 = 576 in the configuration below).

We recommend setting precision to "bf16-true". This reduces memory overhead and improves runtime performance, allowing you to utilize higher batch sizes and train more quickly.
Wrap trainer.strategy._precision_plugin with ReducedPrecisionFilter if performing reduced precision training
It is required to wrap your trainer's strategy._precision_plugin with ReducedPrecisionFilter to prevent automatic downcasting of the SGT by PyTorch Lightning. Reduced precision training of the SGT is numerically unstable and may fail to converge or even to begin training, so the SGT has been designed to fail if any of its parameters are in reduced precision.
Setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True can help prevent CUDA memory errors by allowing the CUDA memory allocator to create expandable segments that better handle the varying allocation sizes common with variable-length sequences in text datasets. It is not available on all systems, but it does not hurt to set it either way. Read more about Optimizing memory usage with PYTORCH_CUDA_ALLOC_CONF.
Using set_float32_matmul_precision runs float32 matrix multiplications in lower precision. For our purposes, the SGT is the only part of the model containing any float32 matrices, and we have observed a considerable speedup using 'high' compared to 'highest', without impact to SGT convergence. We have also observed 'high' to scale more efficiently as sequence length and batch size increase. Once again, it is not available on all systems, but it doesn't hurt to set it.
If you want to check out any TensorBoard logs produced by your training run, run tensorboard --logdir path/to/your/save/dir in your console.

All that is left to do is call fit on the trainer!
PROJECT: Final[str] = "llm-transform-training-notebook"
SAVE_DIR: Final[str] = "saved/"
LOG_STEP: Final[int] = 100
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
torch.set_float32_matmul_precision("high")
NUM_DEVICES: Final[int] = len(
os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")
)
ACCELERATOR: Final[str] = "cuda"
tb_logger = lightning.pytorch.loggers.TensorBoardLogger(
SAVE_DIR,
name=PROJECT,
default_hp_metric=False,
)
trainer = lightning.Trainer(
max_epochs=1,
devices=NUM_DEVICES,
accumulate_grad_batches=36,
accelerator=ACCELERATOR,
strategy="auto",
precision="bf16-true",
log_every_n_steps=LOG_STEP,
default_root_dir=SAVE_DIR,
logger=[tb_logger],
)
if trainer.strategy._precision_plugin is not None:
trainer.strategy._precision_plugin = sg_lightning.ReducedPrecisionFilter(
trainer.strategy._precision_plugin,
)
with trainer.init_module():
distillation_module = StainedGlassDistillationLightningModule(
max_generation_examples=16,
)
trainer.fit(model=distillation_module)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/8 (repeated for global ranks 1-7)
[W814 05:39:45.189627421 CUDAAllocatorConfig.h:28] Warning: expandable_segments not supported on this platform (function operator()) (similar warnings repeated on each rank)
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 8 processes
----------------------------------------------------------------------------------------------------
Loading checkpoint shards: 100%|██████████| 3/3 (repeated on each rank)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7] (repeated for local ranks 1-7)

  | Name                 | Type                             | Params | Mode
----------------------------------------------------------------------------------
0 | generated_text_table | TableMetric                      | 0      | train
1 | noisy_model          | NoiseMaskedNoisyTransformerModel | 3.6 B  | train
----------------------------------------------------------------------------------
469 M      Trainable params
3.1 B      Non-trainable params
3.6 B      Total params
14,269.546 Total estimated model params size (MB)
51         Modules in train mode
186        Modules in eval mode
Epoch 0: 100%|██████████| 3088/3088 [08:25<00:00, 6.11it/s, v_num=9]
`Trainer.fit` stopped: `max_epochs=1` reached.
Epoch 0: 100%|██████████| 3088/3088 [08:42<00:00, 5.91it/s, v_num=9]
Next steps¶
After training, you can prepare the Stained Glass Transform for deployment by reading Deploying Stained Glass Transform (Text).
For a more complete version of this notebook that includes logging, more metrics, generation, and other advanced features helpful for training a Stained Glass Transform, see Creating Stained Glass for LLMs Recipe (with logging).