Creating Stained Glass for LLMs Recipe (with logging)
This notebook provides a full recipe for creating a Stained Glass Transform for LLMs, with additional logging to track the evolution of metrics, parameters, generated text, and obfuscation scores during training. For a simpler introduction to creating a Stained Glass Transform, see the more streamlined tutorial version of this notebook.
In [ ]:
from __future__ import annotations
import functools
import math
import os
from collections.abc import Mapping, Sized
from typing import Any, Callable, Final, Literal
import datasets
import lightning
import lightning.pytorch.utilities.combined_loader
import numpy as np
import psutil
import torch
import torch.distributed
import torch.utils.data
import torch.utils.tensorboard
import torchmetrics
import torchmetrics.text.rouge
import transformers
import transformers.modeling_outputs
import wandb.wandb_run
from torch import nn
from stainedglass_core import (
huggingface,
loss as sg_loss,
metrics as sg_metrics,
utils as sg_utils,
)
from stainedglass_core.huggingface import (
tokenization_utils as sg_tokenization_utils,
)
from stainedglass_core.huggingface.data import data_collator as sg_data_collator
from stainedglass_core.integrations import (
lightning as sg_lightning,
torchmetrics as sg_torchmetrics,
)
from stainedglass_core.model import noisy_transformer_masking_model
from stainedglass_core.noise_layer import transformer_cloak
In [ ]:
PHYSICAL_CPU_COUNT: Final[int] = psutil.cpu_count(logical=False)
SEED: Final[int] = 42
MODEL_TYPE: Final[type[transformers.PreTrainedModel]] = (
transformers.MistralForCausalLM
)
PRETRAINED_MODEL_NAME_OR_PATH: Final[str] = (
"/models/mistralai/Mistral-7B-Instruct-v0.2"
)
MAX_LENGTH: Final[int] = 4096 # length of input_ids + generated text
TRUNCATED_LAYER_INDEX: Final[int] = 12
MODEL_TOKENIZER_MAPPER_TYPE: Final[
type[sg_tokenization_utils.universal.TokenizerMapper]
] = sg_tokenization_utils.mistral.MistralInstructionTokenizerMapper
VALIDATION_SPLIT_RATIO: Final[float | None] = 0.005
LOAD_FROM_DISK: Final[bool] = False
DATASET_NAME: Final[str] = "tatsu-lab/alpaca"
MODEL_CACHE_DIR_NAME: Final[str] = "mistral"
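If you point this recipe at a different checkpoint, it is worth sanity-checking these constants before committing to a long run. The snippet below is a minimal, optional check, not part of the original recipe, that the checkpoint path resolves and that MAX_LENGTH fits inside the model's context window; it assumes the config exposes max_position_embeddings, as Mistral-style configs do.

# Optional sanity check (an addition, not part of the original recipe): verify the
# checkpoint config loads and that MAX_LENGTH fits within the model's context window.
_config = transformers.AutoConfig.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)
assert MAX_LENGTH <= _config.max_position_embeddings, (
    f"MAX_LENGTH={MAX_LENGTH} exceeds the model's context window "
    f"({_config.max_position_embeddings})."
)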
In [ ]:
SCHEMA_MAPPER: Final[
Callable[
[Mapping[str, str]],
sg_tokenization_utils.universal.InstructionSchemaMapper.Schema,
]
] = sg_tokenization_utils.universal.InstructionSchemaMapper(
instruction_key="instruction",
response_key="output",
context_key="input",
system_prompt_key="text",
)
EXPECTED_COLUMNS: set[str] = {
"input_ids",
"attention_mask",
"noise_mask",
"loss_mask",
"labels",
}
TOKENIZER: Final[transformers.PreTrainedTokenizerBase] = (
transformers.AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)
)
TOKENIZER.pad_token = TOKENIZER.eos_token
TOKENIZER.padding_side = "left"
TOKENIZER_MAPPER: Final[sg_tokenization_utils.universal.TokenizerMapper] = (
MODEL_TOKENIZER_MAPPER_TYPE(TOKENIZER)
)
TRAIN_TOKENIZATION_FN = sg_utils.functional.sequential(
SCHEMA_MAPPER,
sg_tokenization_utils.TokenizerWrapper(
tokenizer=None,
model_type=None,
include_labels=False,
ignore_prompt_loss=True,
tokenizer_mapper=TOKENIZER_MAPPER,
prompt_type=sg_tokenization_utils.PromptType.INSTRUCTION,
),
)
TEST_TOKENIZATION_FN = sg_utils.functional.sequential(
SCHEMA_MAPPER,
sg_tokenization_utils.TokenizerWrapper(
tokenizer=None,
model_type=None,
include_labels=True,
ignore_prompt_loss=True,
tokenizer_mapper=TOKENIZER_MAPPER,
prompt_type=sg_tokenization_utils.PromptType.INSTRUCTION,
),
)
def get_cache_dir(dataset: datasets.Dataset) -> str:
"""Get the directory of the first cache file of the loaded dataset."""
if not dataset.cache_files:
raise ValueError("The loaded dataset has no cache files.")
return os.path.dirname(dataset.cache_files[0]["filename"])
def max_length_filter(
sample: (
sg_tokenization_utils.universal.TransformLayerTestMapper.TransformLayerTestInput[
torch.Tensor
]
| sg_tokenization_utils.universal.TransformLayerTrainMapper.TransformLayerTrainInput[
torch.Tensor
]
),
) -> bool:
"""Filter `input_ids` or `input_ids` + `labels` greater than the max length."""
if "labels" in sample:
return (
sample["input_ids"].shape[-1] + sample["labels"].shape[-1]
<= MAX_LENGTH
)
return sample["input_ids"].shape[-1] <= MAX_LENGTH
def load_dataset() -> datasets.DatasetDict:
"""Load the dataset and split it into training and validation sets."""
if LOAD_FROM_DISK:
dataset = datasets.load_from_disk(
DATASET_NAME,
)
else:
dataset = datasets.load_dataset(
DATASET_NAME,
num_proc=PHYSICAL_CPU_COUNT,
)
assert not isinstance(
dataset, (datasets.IterableDataset, datasets.IterableDatasetDict)
)
if VALIDATION_SPLIT_RATIO is not None:
if isinstance(dataset, datasets.DatasetDict):
assert dataset.keys() == {"train"}, (
"VALIDATION_SPLIT_RATIO was set, but there are already multiple splits in the loaded dataset."
)
dataset = dataset["train"]
cache_dir = os.path.join(get_cache_dir(dataset), "cache")
os.makedirs(
cache_dir, exist_ok=True
) # TODO: datasets should probably handle this for us https://github.com/huggingface/datasets/pull/7096
train_indices_cache_file_name = os.path.join(
cache_dir,
f"seed_{SEED}_train_split_{1 - VALIDATION_SPLIT_RATIO:.3f}.cache",
)
test_indices_cache_file_name = os.path.join(
cache_dir,
f"seed_{SEED}_val_split_{VALIDATION_SPLIT_RATIO:.3f}.cache",
)
dataset = dataset.train_test_split(
test_size=VALIDATION_SPLIT_RATIO,
load_from_cache_file=True,
shuffle=True,
generator=np.random.default_rng(seed=SEED),
train_indices_cache_file_name=train_indices_cache_file_name,
test_indices_cache_file_name=test_indices_cache_file_name,
)
dataset.set_format("torch")
return dataset
def prepare_dataset_split(
dataset: datasets.Dataset,
split: str,
tokenization_fn: Callable[[Mapping[str, str]], Any],
) -> datasets.Dataset:
"""Prepare a dataset split by tokenizing and filtering the samples."""
cache_dir = os.path.join(
get_cache_dir(dataset), "cache", MODEL_CACHE_DIR_NAME
)
os.makedirs(
cache_dir, exist_ok=True
) # TODO: datasets should probably handle this for us https://github.com/huggingface/datasets/pull/7096
map_prefix = os.path.join(cache_dir, f"{split}-SGT-tokenized")
filter_prefix = map_prefix + f"-max_length-{MAX_LENGTH}"
dataset = dataset.map(
tokenization_fn,
cache_file_name=map_prefix + ".map",
load_from_cache_file=True,
batched=False,
num_proc=PHYSICAL_CPU_COUNT,
).filter(
max_length_filter,
cache_file_name=filter_prefix + ".filter",
load_from_cache_file=True,
batched=False,
num_proc=PHYSICAL_CPU_COUNT,
)
columns_to_drop = [
column
for column in dataset.column_names
if column not in EXPECTED_COLUMNS
]
return dataset.remove_columns(columns_to_drop)
class TrainingStepMetrics(nn.Module):
"""Common training metrics to be collected during [`StainedGlassDistillationLightningModule._training_step`][] in both training and
validation.
"""
def __init__(self) -> None:
super().__init__()
self.mean_losses = nn.ModuleDict(
{
"distillation_layer_cosine_distance_loss": torchmetrics.MeanMetric(),
"distillation_layer_l2_distance_loss": torchmetrics.MeanMetric(),
"normalized_input_embedding_cosine_similarity_loss": torchmetrics.MeanMetric(),
"std_log_ratio_loss": torchmetrics.MeanMetric(),
"composite_loss": torchmetrics.MeanMetric(),
}
)
self.perplexity = torchmetrics.text.Perplexity(
ignore_index=TOKENIZER.pad_token_id
)
self.obfuscation_scores = nn.ModuleDict(
{
"mean": torchmetrics.MeanMetric(),
"min": torchmetrics.MinMetric(),
"max": torchmetrics.MaxMetric(),
}
)
class ValidationStepMetrics(TrainingStepMetrics):
"""Additional metrics only to be collected during [`StainedGlassDistillationLightningModule.validation_step`][]."""
def __init__(self) -> None:
super().__init__()
self.rouge = torchmetrics.text.rouge.ROUGEScore(
rouge_keys=("rouge1", "rouge2", "rougeL")
)
self.obfuscation_quantiles = sg_torchmetrics.QuantileMetric(
q=torch.linspace(0.1, 0.9, steps=9)
)
self.obfuscation_scores_cat = torchmetrics.CatMetric()
self.generated_text_table = sg_torchmetrics.TableMetric()
class StainedGlassDistillationLightningModule(lightning.LightningModule):
"""[`lightning.LightningModule`][] for training an LLM Stained Glass Transform via distillation."""
def __init__(
self,
max_generation_examples: int = 10,
obfuscation_log_step: int | None = None,
noise_component_histogram_log_step: int | None = None,
) -> None:
"""Initialize a `StainedGlassDistillationLightningModule`.
Args:
max_generation_examples: The cutoff for the number of validation examples for which to run generation. The table will have the
nearest multiple of the batch size over this number of rows.
obfuscation_log_step: How often to check the obfuscation score.
noise_component_histogram_log_step: How often to log the noise component histograms (means and standard deviations) during
training.
"""
super().__init__()
self.max_generation_examples = max_generation_examples
self.obfuscation_log_step = obfuscation_log_step
self.noise_component_histogram_log_step = (
noise_component_histogram_log_step
)
self.save_hyperparameters()
self.base_model_generation_cache: dict[int, list[str]] = {}
"""A mapping of `batch_idx` to first-epoch base model generation results that we can use to speed up the generation step of
validation if the generation dataloader is not shuffled. Alternatively, we could do this as a dataset pre-processing step to get
more accurate labels for training.
"""
with sg_utils.torch.dtypes.default_dtype(torch.float32):
self._train_metrics = TrainingStepMetrics()
self._val_metrics = ValidationStepMetrics()
self.train_val_metrics: dict[str, TrainingStepMetrics] = {
"train": self._train_metrics,
"valid": self._val_metrics,
}
@functools.cached_property
def tb_writer(self) -> torch.utils.tensorboard.SummaryWriter | None: # pyright: ignore[reportPrivateImportUsage]
assert self.loggers is not None
for logger in self.loggers:
if isinstance(logger, lightning.pytorch.loggers.TensorBoardLogger):
return logger.experiment
return None
@functools.cached_property
def wandb_logger(self) -> lightning.pytorch.loggers.WandbLogger | None:
assert self.loggers is not None
for logger in self.loggers:
if isinstance(logger, lightning.pytorch.loggers.WandbLogger):
return logger
return None
@functools.cached_property
def wandb_run(self) -> wandb.wandb_run.Run | None:
assert self.loggers is not None
for logger in self.loggers:
if isinstance(logger, lightning.pytorch.loggers.WandbLogger):
return logger.experiment
return None
def configure_model(self) -> None:
"""Configure the models and loss function."""
base_model = transformers.MistralForCausalLM.from_pretrained(
PRETRAINED_MODEL_NAME_OR_PATH,
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16,
)
assert isinstance(base_model, transformers.MistralForCausalLM)
for param in base_model.parameters():
param.requires_grad = False
base_model = base_model.eval()
self.noisy_model = (
noisy_transformer_masking_model.NoiseMaskedNoisyTransformerModel(
noise_layer_class=transformer_cloak.TransformerCloak,
base_model=base_model,
target_layer="model.embed_tokens",
truncated_layer_index=TRUNCATED_LAYER_INDEX,
scale=(1e-8, 1.0),
shallow=1.0,
mean_dropout=0.1,
std_dropout=0.1,
config_path=PRETRAINED_MODEL_NAME_OR_PATH,
use_causal_mask=True,
transformer_type=transformers.MistralModel,
directly_learn_stds=True,
rho_init=0.0,
seed=SEED,
)
)
self.distillation_loss, self.get_losses, self.get_hyperparameters = (
sg_loss.distillation.distillation_loss_factory(
self.noisy_model,
distillation_layer_index=TRUNCATED_LAYER_INDEX,
alpha=0.54,
std_log_ratio_loss_weight=0.01,
input_embedding_similarity_loss_weight=0.75,
distillation_layer_cosine_distance_loss_weight=12.0,
)
)
self.get_transformed_embeddings = (
self.noisy_model.noise_layer.get_transformed_output_factory()
)
self.get_applied_transform_components = self.noisy_model.noise_layer.get_applied_transform_components_factory()
def configure_optimizers(self) -> torch.optim.Optimizer:
"""Configure the model optimizer."""
no_weight_decay_param_group = {
"params": (
[
param
for param in self.noisy_model.noise_layer.std_estimator.module.linear.parameters()
if param.requires_grad
]
+ [
param
for param in self.noisy_model.noise_layer.mean_estimator.module.linear.parameters()
if param.requires_grad
]
),
"weight_decay": 0.0,
}
no_weight_decay_params = set(no_weight_decay_param_group["params"])
default_param_group = {
"params": [
param
for param in self.noisy_model.noise_layer.parameters()
if param.requires_grad and param not in no_weight_decay_params
],
}
return torch.optim.AdamW(
params=[no_weight_decay_param_group, default_param_group],
lr=3e-5,
amsgrad=False,
betas=(0.9, 0.95),
eps=1e-5,
weight_decay=0.1,
)
def prepare_data(self) -> None:
"""Download and prepare the dataset splits.
        This function is only called on the local rank 0 process, so we won't attempt to download or map the dataset multiple times here.
"""
dataset = load_dataset()
training_dataset = prepare_dataset_split( # noqa: F841
dataset["train"], "train", TRAIN_TOKENIZATION_FN
)
validation_dataset = prepare_dataset_split( # noqa: F841
dataset["test"], "val", TRAIN_TOKENIZATION_FN
)
generation_dataset = prepare_dataset_split( # noqa: F841
dataset["test"], "generation", TEST_TOKENIZATION_FN
)
def setup(self, stage: str) -> None:
"""Set up the datasets."""
dataset = load_dataset()
if stage == "fit":
self.train_dataset = prepare_dataset_split(
dataset["train"], "train", TRAIN_TOKENIZATION_FN
)
self.val_dataset = prepare_dataset_split(
dataset["test"], "val", TRAIN_TOKENIZATION_FN
)
self.generation_dataset = prepare_dataset_split(
dataset["test"], "generation", TEST_TOKENIZATION_FN
)
if stage == "test":
...
if stage == "predict":
...
def train_dataloader(
self,
) -> torch.utils.data.DataLoader[
sg_data_collator.TransformLayerTrainInputWithAttentionMask[torch.Tensor]
]:
return torch.utils.data.DataLoader(
self.train_dataset, # pyright: ignore[reportArgumentType]
collate_fn=sg_data_collator.DataCollatorForStainedGlassSeq2Seq(
tokenizer=TOKENIZER, pad_to_multiple_of=8
),
batch_size=2,
shuffle=True,
num_workers=4,
pin_memory=True,
)
def val_dataloader(
self,
) -> lightning.pytorch.utilities.combined_loader.CombinedLoader:
return lightning.pytorch.utilities.combined_loader.CombinedLoader(
(
torch.utils.data.DataLoader(
self.val_dataset, # pyright: ignore[reportArgumentType]
collate_fn=sg_data_collator.DataCollatorForStainedGlassSeq2Seq(
tokenizer=TOKENIZER, pad_to_multiple_of=8
),
batch_size=2,
shuffle=False,
num_workers=4,
pin_memory=True,
),
torch.utils.data.DataLoader(
self.generation_dataset, # pyright: ignore[reportArgumentType]
collate_fn=sg_data_collator.DataCollatorForStainedGlassSeq2Seq(
tokenizer=TOKENIZER, pad_to_multiple_of=8
),
batch_size=2,
shuffle=False,
num_workers=4,
pin_memory=True,
),
),
mode="max_size",
)
@functools.cached_property
def is_generation_dataloader_shuffled(self) -> bool:
"""Whether the generation dataloader is shuffled.
Used to determine if we can cache and reuse base model generations from the first validation epoch.
"""
val_dataloaders = self.trainer.val_dataloaders
assert isinstance(val_dataloaders, tuple)
_, generation_dataloader = val_dataloaders
assert isinstance(generation_dataloader, torch.utils.data.DataLoader)
assert isinstance(
generation_dataloader.sampler, torch.utils.data.Sampler
)
return not (
isinstance(
generation_dataloader.sampler,
torch.utils.data.SequentialSampler,
)
or (
isinstance(
generation_dataloader.sampler,
torch.utils.data.DistributedSampler,
)
and generation_dataloader.sampler.shuffle
)
)
def forward(
self, **kwargs: Any
) -> transformers.modeling_outputs.CausalLMOutputWithPast:
with self.noisy_model.distillation_context():
return self.noisy_model(**kwargs)
def on_train_epoch_start(self) -> None:
"""Ensure that the model is truncated prior to starting training."""
self.noisy_model.truncate_and_offload()
def _training_step(
self,
batch: sg_data_collator.TransformLayerTrainInputWithAttentionMask[
torch.Tensor
],
batch_idx: int,
dataloader: torch.utils.data.DataLoader[
sg_data_collator.TransformLayerTrainInputWithAttentionMask[
torch.Tensor
]
],
num_batches: float,
mode: Literal["train", "valid"],
metrics: dict[str, float],
) -> torch.Tensor:
input_ids = batch["input_ids"]
noise_mask = batch["noise_mask"]
output = self(
input_ids=input_ids,
attention_mask=batch["attention_mask"],
use_cache=True,
noise_mask=noise_mask,
)
loss = self.distillation_loss(batch["loss_mask"])
batch_size = dataloader.batch_size
assert batch_size is not None
dataset = dataloader.dataset
assert isinstance(dataset, Sized)
        num_examples_per_epoch = (num_batches - 1) * batch_size + (
            batch_size if dataloader.drop_last else len(dataset) % batch_size
        )
dataset_size = (
len(dataset) - len(dataset) % batch_size
if dataloader.drop_last
else len(dataset)
)
current_batch_size = (
len(dataset) % batch_size
if (batch_idx + 1) == len(dataloader) and not dataloader.drop_last
else batch_size
)
num_examples = self.trainer.world_size * (
self.current_epoch * num_examples_per_epoch
+ batch_idx * batch_size
+ current_batch_size
)
# TODO: handle intra-epoch validation
metrics[f"num_examples/{mode}/batch"] = (
num_examples # allows for head-to-head comparisons between datasets
)
metrics[f"percent_dataset/{mode}/batch"] = (
num_examples / dataset_size
        )  # allows for head-to-head comparisons on the same dataset with different fractional sizes (e.g. limit_train_batches)
metrics[f"epoch/{mode}/batch"] = (
self.current_epoch + (batch_idx + 1) / num_batches
) # tracks our progress through the currently configured dataloader
obfuscation_log_step = self.obfuscation_log_step or int(
math.sqrt(num_batches)
)
if (
batch_idx % obfuscation_log_step == 0
or batch_idx == num_batches - 1
):
transformed_embeddings = self.get_transformed_embeddings()
reconstructed_input_ids = (
self.noisy_model.reconstruct_ids_from_embeddings(
transformed_embeddings
)
)
percentage_changed_input_ids = sg_metrics.percentage_changed_ids(
input_ids, reconstructed_input_ids, noise_mask
)
if mode == "valid":
self._val_metrics.obfuscation_scores_cat.update(
percentage_changed_input_ids
)
self._val_metrics.obfuscation_quantiles.update(
percentage_changed_input_ids
)
metrics[f"obfuscation/{mode}/batch"] = (
percentage_changed_input_ids.mean().item()
)
for metric in self.train_val_metrics[
mode
].obfuscation_scores.values():
metric.update(percentage_changed_input_ids)
losses = self.get_losses()
for loss_name, loss in losses.items():
self.train_val_metrics[mode].mean_losses[loss_name].update(loss)
metrics[f"{loss_name}/{mode}/batch"] = loss.item()
if not self.noisy_model.is_truncated_and_offloaded:
logits = output["logits"][..., :-1, :].contiguous()
labels = input_ids[..., 1:].contiguous()
metrics[f"perplexity/{mode}/batch"] = (
self.train_val_metrics[mode].perplexity(logits, labels).item()
)
applied_transform_components = self.get_applied_transform_components()
noise_component_histogram_log_step = (
self.noise_component_histogram_log_step
or int(math.sqrt(num_batches))
)
if (
self.tb_writer is not None
and self.trainer.is_global_zero
and (
batch_idx % noise_component_histogram_log_step == 0
or batch_idx == num_batches - 1
)
):
for name, values in applied_transform_components.items():
self.tb_writer.add_histogram(
f"{name}_histogram/{mode}/batch",
values=values,
global_step=int(
self.current_epoch * num_batches + batch_idx
),
bins=512, # pyright: ignore[reportArgumentType]
)
return loss
def training_step(
self,
batch: sg_data_collator.TransformLayerTrainInputWithAttentionMask[
torch.Tensor
],
batch_idx: int,
) -> torch.Tensor:
"""Compute the training distillation loss."""
train_dataloader = self.trainer.train_dataloader
assert train_dataloader is not None
num_batches = self.trainer.num_training_batches
mode = "train"
step = int(self.current_epoch * num_batches + batch_idx)
metrics: dict[str, float] = {}
loss = self._training_step(
batch,
batch_idx,
train_dataloader,
num_batches=num_batches,
mode=mode,
metrics=metrics,
)
if self.logger is not None and self.trainer.is_global_zero:
self.logger.log_metrics(metrics, step=step)
return loss
def _on_train_epoch_end(
self,
mode: Literal["train", "valid"],
metrics: dict[str, float],
) -> None:
"""Compute epoch-level training metrics."""
metrics[f"epoch/{mode}/epoch"] = self.current_epoch
for loss_name, metric in self.train_val_metrics[
mode
].mean_losses.items():
metrics[f"{loss_name}/{mode}/epoch"] = metric.compute().item()
metric.reset()
for metric_name, metric in self.train_val_metrics[
mode
].obfuscation_scores.items():
metrics[f"{metric_name}_obfuscation/{mode}/epoch"] = (
metric.compute().item()
)
metric.reset()
if self.train_val_metrics[mode].perplexity.update_called:
metrics[f"perplexity/{mode}/epoch"] = (
self.train_val_metrics[mode].perplexity.compute().item()
)
self.train_val_metrics[mode].perplexity.reset()
def on_train_epoch_end(self) -> None:
"""Compute and log the epoch-level training metrics."""
metrics: dict[str, Any] = {}
self._on_train_epoch_end(mode="train", metrics=metrics)
if self.logger is not None and self.trainer.is_global_zero:
self.logger.log_metrics(metrics, step=self.current_epoch)
def on_validation_epoch_start(self) -> None:
"""Ensure that the model is fully-loaded prior to starting validation."""
self.noisy_model.restore_and_load()
def validation_step(
self,
batch: tuple[
sg_data_collator.TransformLayerTrainInputWithAttentionMask[
torch.Tensor
],
sg_data_collator.TransformLayerTestInputWithAttentionMask[
torch.Tensor
],
],
batch_idx: int,
) -> None:
"""Compute the validation distillation loss and perform generation."""
val_batch, generation_batch = batch
mode = "valid"
assert self.trainer.val_dataloaders is not None
val_dataloader = self.trainer.val_dataloaders[0]
num_batches = self.trainer.num_val_batches[0]
        step = int(self.current_epoch * num_batches + batch_idx)
metrics: dict[str, float] = {}
_ = self._training_step(
val_batch,
batch_idx,
val_dataloader,
num_batches=num_batches,
mode=mode,
metrics=metrics,
)
if (
len(self._val_metrics.generated_text_table)
> self.max_generation_examples
):
if self.logger is not None and self.trainer.is_global_zero:
self.logger.log_metrics(metrics, step=step)
return
input_ids = generation_batch["input_ids"]
attention_mask = generation_batch["attention_mask"]
noise_mask = generation_batch["noise_mask"]
labels = generation_batch["labels"]
input_text = TOKENIZER.batch_decode(input_ids, skip_special_tokens=True)
labels_text = TOKENIZER.batch_decode(labels, skip_special_tokens=True)
sequence_length = input_ids.shape[-1]
generation_config = (
huggingface.generation.StainedGlassGenerationConfig.from_tokenizer(
TOKENIZER, max_length=MAX_LENGTH
)
)
if (
not self.is_generation_dataloader_shuffled
and batch_idx in self.base_model_generation_cache
):
generated_text = self.base_model_generation_cache[batch_idx]
else:
with torch.random.fork_rng(
devices=[self.device] if self.device.type != "cpu" else [],
device_type=self.device.type,
):
if self.noisy_model.noise_layer._generators is not None:
torch.manual_seed(
self.noisy_model.noise_layer.initial_seed()
)
generated_ids = self.noisy_model.base_model.generate(
inputs=input_ids,
generation_config=generation_config,
attention_mask=attention_mask,
use_cache=True,
)
generated_text = TOKENIZER.batch_decode(
generated_ids[:, sequence_length:], skip_special_tokens=True
)
self.base_model_generation_cache[batch_idx] = generated_text
generated_ids_from_transformed_embeddings, transformed_embeddings = (
self.noisy_model.generate(
inputs=input_ids,
generation_config=generation_config,
attention_mask=attention_mask,
use_cache=True,
noise_mask=noise_mask,
return_transformed_embeddings=True,
)
)
generated_text_from_transformed_embeddings = TOKENIZER.batch_decode(
generated_ids_from_transformed_embeddings[:, sequence_length:],
skip_special_tokens=True,
)
reconstructed_input_ids = (
self.noisy_model.reconstruct_ids_from_embeddings(
transformed_embeddings
)
)
reconstructed_input_text = TOKENIZER.batch_decode(
reconstructed_input_ids, skip_special_tokens=True
)
percentage_changed_input_ids = sg_metrics.percentage_changed_ids(
input_ids, reconstructed_input_ids, noise_mask
)
rouge_scores: dict[str, torch.Tensor] = self._val_metrics.rouge(
generated_text_from_transformed_embeddings, generated_text
)
metrics.update(
{
f"{metric_name}/{mode}/batch": value.item()
for metric_name, value in rouge_scores.items()
if metric_name.endswith("_fmeasure")
}
)
self._val_metrics.generated_text_table.update(
{
"input_text": input_text,
"reconstructed_input_text": reconstructed_input_text,
"obfuscation_score": [
f"{score.item() * 100:0.1f}%"
for score in percentage_changed_input_ids
],
"labels_text": labels_text,
"generated_text": generated_text,
"generated_text_from_transformed_embeddings": generated_text_from_transformed_embeddings,
}
)
if self.logger is not None and self.trainer.is_global_zero:
self.logger.log_metrics(metrics, step=step)
def on_validation_epoch_end(self) -> None:
"""Compute and log the epoch-level validation metrics."""
step = self.current_epoch
mode = "valid"
metrics: dict[str, Any] = {}
self._on_train_epoch_end(mode=mode, metrics=metrics)
observed_obfuscation_quantiles = (
self._val_metrics.obfuscation_quantiles.compute()
)
metrics.update(
{
f"obfuscation_quantile/{q * 100:.0f}%/{mode}/epoch": value
for q, value in observed_obfuscation_quantiles.items()
}
)
self._val_metrics.obfuscation_quantiles.reset()
observed_obfuscation_scores = (
self._val_metrics.obfuscation_scores_cat.compute()
)
self._val_metrics.obfuscation_scores_cat.reset()
generated_text_table = self._val_metrics.generated_text_table.compute()
self._val_metrics.generated_text_table.reset()
if self.trainer.is_global_zero:
if self.tb_writer is not None:
self.tb_writer.add_histogram(
f"obfuscation_histogram/{mode}/epoch",
values=observed_obfuscation_scores.to("cpu", torch.float32),
global_step=step,
bins=torch.linspace(0.0, 1.0, 513), # pyright: ignore[reportArgumentType]
)
self.tb_writer.add_text(
f"generated_text/{mode}/epoch",
text_string=sg_utils.torch.tensorboard.to_markdown_table(
generated_text_table
),
global_step=step,
)
if self.wandb_logger is not None:
self.wandb_logger.log_text(
f"generated_text/{self.current_epoch}/{mode}/epoch",
columns=list(generated_text_table.keys()),
data=list(zip(*generated_text_table.values())),
)
hyperparameters = self.get_hyperparameters()
if self.wandb_run is not None:
self.wandb_run.config.update(hyperparameters)
if self.logger is not None:
self.logger.log_metrics(metrics, step=step)
# TODO: log_hyperparams doesn't currently accept a step and so the valid/epoch metrics aren't logged correctly
# Fixed by: https://github.com/Lightning-AI/pytorch-lightning/pull/20176
# self.logger.log_hyperparams(hyperparameters, metrics=metrics, step=step)
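Before launching a full run, it can help to confirm that the tokenization pipeline emits the columns the data collator expects. The snippet below is a small, optional sketch, not part of the original recipe: it hand-writes one record following the tatsu-lab/alpaca schema used by SCHEMA_MAPPER and assumes TRAIN_TOKENIZATION_FN can be applied to a single example mapping (as it is inside dataset.map with batched=False), then reports which of the EXPECTED_COLUMNS are present.

# Optional sketch (an addition, not part of the original recipe): tokenize one
# Alpaca-style record and report which of the expected columns are produced.
_sample = {
    "instruction": "Summarize the following text.",
    "input": "Stained Glass Transforms obfuscate prompts before they leave the client.",
    "output": "They protect prompt privacy.",
    "text": "",  # mapped to the system prompt by SCHEMA_MAPPER
}
_tokenized = TRAIN_TOKENIZATION_FN(_sample)
print({column: column in _tokenized for column in sorted(EXPECTED_COLUMNS)})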
In [ ]:
import lightning.pytorch.loggers
import wandb
import stainedglass_core
In [ ]:
WANDB_PROJECT: Final[str] = "llm-transform-training-notebook"
SAVE_DIR: Final[str] = "saved/"
LOG_STEP: Final[int] = 100
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
torch.set_float32_matmul_precision("high")
ACCELERATOR: Final[str] = "cuda"
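The Trainer configured in the next cell uses the CUDA accelerator with bf16-true precision, and the base model is loaded with flash_attention_2, all of which require a suitable GPU. As a small optional guard, not part of the original recipe, you can assert those capabilities up front rather than failing partway through setup.

# Optional guard (an addition, not part of the original recipe): fail fast if no CUDA
# device with bfloat16 support is available, since the settings above assume one.
assert torch.cuda.is_available(), "This recipe expects a CUDA-capable GPU."
assert torch.cuda.is_bf16_supported(), "bf16-true precision requires bfloat16 support."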
In [ ]:
tb_logger = lightning.pytorch.loggers.TensorBoardLogger(
SAVE_DIR,
name="mistral_distillation",
default_hp_metric=False,
)
wandb.login()
wandb_logger = lightning.pytorch.loggers.WandbLogger(
project=WANDB_PROJECT,
save_dir=tb_logger.log_dir,
log_model=True,
magic=False,
sync_tensorboard=True,
tags=[
f"stainedglass_core=={stainedglass_core.__version__}",
],
)
# wandb.init must be called before any tensorboard writers are created in order to sync tensorboard logs to wandb:
# https://github.com/wandb/wandb/issues/1782#issuecomment-779161203
_ = wandb_logger.experiment
trainer = lightning.Trainer(
max_epochs=1,
accumulate_grad_batches=36,
accelerator=ACCELERATOR,
strategy="auto",
precision="bf16-true",
log_every_n_steps=LOG_STEP,
default_root_dir=SAVE_DIR,
logger=[
tb_logger,
wandb_logger,
],
)
if trainer.strategy._precision_plugin is not None:
trainer.strategy._precision_plugin = sg_lightning.ReducedPrecisionFilter(
trainer.strategy._precision_plugin,
full_precision_module_types=(sg_torchmetrics.QuantileMetric,),
)
with trainer.init_module():
distillation_module = StainedGlassDistillationLightningModule(
max_generation_examples=320,
obfuscation_log_step=LOG_STEP,
noise_component_histogram_log_step=LOG_STEP,
)
trainer.fit(model=distillation_module)
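After trainer.fit completes, you will usually want to persist the learned transform. The exact export and packaging API depends on your stainedglass_core version, so the snippet below is only a minimal, generic sketch using a plain PyTorch state dict; consult the Stained Glass Core documentation for the supported save workflow.

# Minimal, generic sketch (an assumption, not the library's packaging API): save the
# learned noise-layer weights with plain PyTorch so the trained transform is not lost.
checkpoint_path = os.path.join(tb_logger.log_dir, "stained_glass_noise_layer.pt")
torch.save(distillation_module.noisy_model.noise_layer.state_dict(), checkpoint_path)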