Creating Stained Glass for LLMs Recipe (with logging)
This notebook provides a full recipe for creating a Stained Glass Transform for LLMs, with additional logging to track the evolution of metrics, parameters, generated text, and obfuscation scores during training. For a simpler introduction to creating a Stained Glass Transform, see the more streamlined tutorial version of this notebook.
In [ ]:
from __future__ import annotations
import functools
import math
import os
from collections.abc import Mapping, Sized
from typing import Any, Callable, Final, Literal
import datasets
import lightning
import lightning.pytorch.utilities.combined_loader
import numpy as np
import psutil
import torch
import torch.distributed
import torch.utils.data
import torch.utils.tensorboard
import torchmetrics
import torchmetrics.text.rouge
import transformers
import transformers.modeling_outputs
import wandb.wandb_run
from torch import nn
from stainedglass_core import (
    huggingface,
    loss as sg_loss,
    metrics as sg_metrics,
    utils as sg_utils,
)
from stainedglass_core.huggingface import (
    tokenization_utils as sg_tokenization_utils,
)
from stainedglass_core.huggingface.data import data_collator as sg_data_collator
from stainedglass_core.huggingface.tokenization_utils import noise_tokenizer
from stainedglass_core.integrations import (
    lightning as sg_lightning,
    torchmetrics as sg_torchmetrics,
)
from stainedglass_core.model import noisy_transformer_masking_model
from stainedglass_core.noise_layer import transformer_cloak
In [ ]:
PHYSICAL_CPU_COUNT: Final[int] = psutil.cpu_count(logical=False)
SEED: Final[int] = 42
MODEL_TYPE: Final[type[transformers.PreTrainedModel]] = (
    transformers.MistralForCausalLM
)
PRETRAINED_MODEL_NAME_OR_PATH: Final[str] = (
    "/models/mistralai/Mistral-7B-Instruct-v0.2"
)
MAX_LENGTH: Final[int] = 4096  # length of input_ids + generated text
TRUNCATED_LAYER_INDEX: Final[int] = 12
VALIDATION_SPLIT_RATIO: Final[float | None] = 0.005
LOAD_FROM_DISK: Final[bool] = False
DATASET_NAME: Final[str] = "tatsu-lab/alpaca"
MODEL_CACHE_DIR_NAME: Final[str] = "mistral"
In [ ]:
EXPECTED_COLUMNS: set[str] = {
    "input_ids",
    "attention_mask",
    "noise_mask",
    "loss_mask",
}
TRAIN_SCHEMA_MAPPER: Final[sg_tokenization_utils.universal.ChatSchemaMapper] = (
    sg_tokenization_utils.universal.ChatSchemaMapper(
        system_prompt_key="instruction",
        instruction_key="input",
        response_key="output",
    )
)
TRAIN_TOKENIZER: Final[transformers.PreTrainedTokenizerBase] = (
    transformers.AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)
)
TRAIN_TOKENIZER.pad_token = TRAIN_TOKENIZER.eos_token
TRAIN_TOKENIZER.padding_side = "left"
noise_tokenizer_train = noise_tokenizer.NoiseTokenizer(TRAIN_TOKENIZER)
noise_tokenizer_train_fn = functools.partial(
    noise_tokenizer_train.apply_chat_template, ignore_prompt_loss=True
)
TRAIN_TOKENIZATION_FN = sg_utils.functional.sequential(
    TRAIN_SCHEMA_MAPPER, noise_tokenizer_train_fn
)
TEST_SCHEMA_MAPPER: Final[sg_tokenization_utils.universal.ChatSchemaMapper] = (
    sg_tokenization_utils.universal.ChatSchemaMapper(
        system_prompt_key="instruction",
        instruction_key="input",
        response_key=None,
    )
)
TEST_TOKENIZER: Final[transformers.PreTrainedTokenizerBase] = (
    transformers.AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)
)
TEST_TOKENIZER.pad_token = TEST_TOKENIZER.eos_token
TEST_TOKENIZER.padding_side = "left"
noise_tokenizer_test = noise_tokenizer.NoiseTokenizer(TEST_TOKENIZER)
noise_tokenizer_test_fn = functools.partial(
    noise_tokenizer_test.apply_chat_template,
    ignore_prompt_loss=True,
    add_generation_prompt=True,
)
TEST_TOKENIZATION_FN = sg_utils.functional.sequential(
    TEST_SCHEMA_MAPPER, noise_tokenizer_test_fn
)
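# Illustrative only (kept commented so this cell stays definition-only): applying
# TRAIN_TOKENIZATION_FN to a raw Alpaca-style record is expected to produce a
# NoiseEncoding containing the EXPECTED_COLUMNS defined above, e.g.:
#   encoding = TRAIN_TOKENIZATION_FN(
#       {"instruction": "...", "input": "...", "output": "..."}
#   )
#   # encoding includes input_ids, attention_mask, noise_mask, and loss_mask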
def get_cache_dir(dataset: datasets.Dataset) -> str:
    """Get the directory of the first cache file of the loaded dataset."""
    if not dataset.cache_files:
        raise ValueError("The loaded dataset has no cache files.")
    return os.path.dirname(dataset.cache_files[0]["filename"])
def max_length_filter(
    sample: noise_tokenizer.NoiseEncoding,
) -> bool:
    """Keep only samples whose `input_ids` are no longer than the max length."""
    return sample["input_ids"].shape[-1] <= MAX_LENGTH
def load_dataset() -> datasets.DatasetDict:
    """Load the dataset and split it into training and validation sets."""
    if LOAD_FROM_DISK:
        dataset = datasets.load_from_disk(
            DATASET_NAME,
        )
    else:
        dataset = datasets.load_dataset(
            DATASET_NAME,
            num_proc=PHYSICAL_CPU_COUNT,
        )
    assert not isinstance(
        dataset, (datasets.IterableDataset, datasets.IterableDatasetDict)
    )
    if VALIDATION_SPLIT_RATIO is not None:
        if isinstance(dataset, datasets.DatasetDict):
            assert dataset.keys() == {"train"}, (
                "VALIDATION_SPLIT_RATIO was set, but there are already multiple splits in the loaded dataset."
            )
            dataset = dataset["train"]
        cache_dir = os.path.join(get_cache_dir(dataset), "cache")
        os.makedirs(
            cache_dir, exist_ok=True
        )  # TODO: datasets should probably handle this for us https://github.com/huggingface/datasets/pull/7096
        train_indices_cache_file_name = os.path.join(
            cache_dir,
            f"seed_{SEED}_train_split_{1 - VALIDATION_SPLIT_RATIO:.3f}.cache",
        )
        test_indices_cache_file_name = os.path.join(
            cache_dir,
            f"seed_{SEED}_val_split_{VALIDATION_SPLIT_RATIO:.3f}.cache",
        )
        dataset = dataset.train_test_split(
            test_size=VALIDATION_SPLIT_RATIO,
            load_from_cache_file=True,
            shuffle=True,
            generator=np.random.default_rng(seed=SEED),
            train_indices_cache_file_name=train_indices_cache_file_name,
            test_indices_cache_file_name=test_indices_cache_file_name,
        )
    dataset.set_format("torch")
    return dataset
def prepare_dataset_split(
    dataset: datasets.Dataset,
    split: str,
    tokenization_fn: Callable[[Mapping[str, str]], Any],
) -> datasets.Dataset:
    """Prepare a dataset split by tokenizing and filtering the samples."""
    cache_dir = os.path.join(
        get_cache_dir(dataset), "cache", MODEL_CACHE_DIR_NAME
    )
    os.makedirs(
        cache_dir, exist_ok=True
    )  # TODO: datasets should probably handle this for us https://github.com/huggingface/datasets/pull/7096
    map_prefix = os.path.join(cache_dir, f"{split}-SGT-tokenized")
    filter_prefix = map_prefix + f"-max_length-{MAX_LENGTH}"
    dataset = dataset.map(
        tokenization_fn,
        cache_file_name=map_prefix + ".map",
        load_from_cache_file=True,
        batched=False,
        num_proc=PHYSICAL_CPU_COUNT,
    ).filter(
        max_length_filter,
        cache_file_name=filter_prefix + ".filter",
        load_from_cache_file=True,
        batched=False,
        num_proc=PHYSICAL_CPU_COUNT,
    )
    columns_to_drop = [
        column
        for column in dataset.column_names
        if column not in EXPECTED_COLUMNS
    ]
    return dataset.remove_columns(columns_to_drop)
class TrainingStepMetrics(nn.Module):
    """Common training metrics to be collected during [`StainedGlassDistillationLightningModule._training_step`][] in both training and
    validation.
    """
    def __init__(self) -> None:
        super().__init__()
        self.mean_losses = nn.ModuleDict(
            {
                "distillation_layer_cosine_distance_loss": torchmetrics.MeanMetric(),
                "distillation_layer_l2_distance_loss": torchmetrics.MeanMetric(),
                "normalized_input_embedding_cosine_similarity_loss": torchmetrics.MeanMetric(),
                "std_log_ratio_loss": torchmetrics.MeanMetric(),
                "composite_loss": torchmetrics.MeanMetric(),
            }
        )
        self.perplexity = torchmetrics.text.Perplexity(
            ignore_index=TRAIN_TOKENIZER.pad_token_id
        )
        self.obfuscation_scores = nn.ModuleDict(
            {
                "mean": torchmetrics.MeanMetric(),
                "min": torchmetrics.MinMetric(),
                "max": torchmetrics.MaxMetric(),
            }
        )
class ValidationStepMetrics(TrainingStepMetrics):
    """Additional metrics only to be collected during [`StainedGlassDistillationLightningModule.validation_step`][]."""
    def __init__(self) -> None:
        super().__init__()
        self.rouge = torchmetrics.text.rouge.ROUGEScore(
            rouge_keys=("rouge1", "rouge2", "rougeL")
        )
        self.obfuscation_quantiles = sg_torchmetrics.QuantileMetric(
            q=torch.linspace(0.1, 0.9, steps=9)
        )
        self.obfuscation_scores_cat = torchmetrics.CatMetric()
        self.generated_text_table = sg_torchmetrics.TableMetric()
        self.config = transformers.AutoConfig.from_pretrained(
            PRETRAINED_MODEL_NAME_OR_PATH
        )
class StainedGlassDistillationLightningModule(lightning.LightningModule):
    """[`lightning.LightningModule`][] for training an LLM Stained Glass Transform via distillation."""
    def __init__(
        self,
        max_generation_examples: int = 10,
        obfuscation_log_step: int | None = None,
        noise_component_histogram_log_step: int | None = None,
    ) -> None:
        """Initialize a `StainedGlassDistillationLightningModule`.
        Args:
            max_generation_examples: The cutoff for the number of validation examples for which to run generation. The table will have the
                nearest multiple of the batch size over this number of rows.
            obfuscation_log_step: How often to check the obfuscation score.
            noise_component_histogram_log_step: How often to log the noise component histograms (means and standard deviations) during
                training.
        """
        super().__init__()
        self.max_generation_examples = max_generation_examples
        self.obfuscation_log_step = obfuscation_log_step
        self.noise_component_histogram_log_step = (
            noise_component_histogram_log_step
        )
        self.save_hyperparameters()
        self.base_model_generation_cache: dict[int, list[str]] = {}
        """A mapping of `batch_idx` to first-epoch base model generation results that we can use to speed up the generation step of
        validation if the generation dataloader is not shuffled. Alternatively, we could do this as a dataset pre-processing step to get
        more accurate labels for training.
        """
        with sg_utils.torch.dtypes.default_dtype(torch.float32):
            self._train_metrics = TrainingStepMetrics()
            self._val_metrics = ValidationStepMetrics()
            self.train_val_metrics: dict[str, TrainingStepMetrics] = {
                "train": self._train_metrics,
                "valid": self._val_metrics,
            }
        self.config = transformers.AutoConfig.from_pretrained(
            PRETRAINED_MODEL_NAME_OR_PATH
        )
    @functools.cached_property
    def tb_writer(self) -> torch.utils.tensorboard.SummaryWriter | None:  # pyright: ignore[reportPrivateImportUsage]
        assert self.loggers is not None
        for logger in self.loggers:
            if isinstance(logger, lightning.pytorch.loggers.TensorBoardLogger):
                return logger.experiment
        return None
    @functools.cached_property
    def wandb_logger(self) -> lightning.pytorch.loggers.WandbLogger | None:
        assert self.loggers is not None
        for logger in self.loggers:
            if isinstance(logger, lightning.pytorch.loggers.WandbLogger):
                return logger
        return None
    @functools.cached_property
    def wandb_run(self) -> wandb.wandb_run.Run | None:
        assert self.loggers is not None
        for logger in self.loggers:
            if isinstance(logger, lightning.pytorch.loggers.WandbLogger):
                return logger.experiment
        return None
    def configure_model(self) -> None:
        """Configure the models and loss function."""
        base_model = transformers.MistralForCausalLM.from_pretrained(
            PRETRAINED_MODEL_NAME_OR_PATH,
            attn_implementation="flash_attention_2",
            torch_dtype=torch.bfloat16,
        )
        assert isinstance(base_model, transformers.MistralForCausalLM)
        for param in base_model.parameters():
            param.requires_grad = False
        base_model = base_model.eval()
        self.noisy_model = (
            noisy_transformer_masking_model.NoiseMaskedNoisyTransformerModel(
                noise_layer_class=transformer_cloak.TransformerCloak,
                base_model=base_model,
                target_layer="model.embed_tokens",
                truncated_layer_index=TRUNCATED_LAYER_INDEX,
                scale=(1e-8, 1.0),
                shallow=1.0,
                mean_dropout=0.1,
                std_dropout=0.1,
                config=self.config,
                use_causal_mask=True,
                transformer_type=transformers.MistralModel,
                directly_learn_stds=True,
                rho_init=0.0,
                seed=SEED,
                noise_layer_dtype=torch.float32,
            )
        )
        self.distillation_loss, self.get_losses, self.get_hyperparameters = (
            sg_loss.distillation.distillation_loss_factory(
                self.noisy_model,
                distillation_layer_index=TRUNCATED_LAYER_INDEX,
                alpha=0.54,
                std_log_ratio_loss_weight=0.01,
                input_embedding_similarity_loss_weight=0.75,
                distillation_layer_cosine_distance_loss_weight=12.0,
            )
        )
        self.get_transformed_embeddings = (
            self.noisy_model.noise_layer.get_transformed_output_factory()
        )
        self.get_applied_transform_components = self.noisy_model.noise_layer.get_applied_transform_components_factory()
    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Configure the model optimizer."""
        no_weight_decay_param_group = {
            "params": (
                [
                    param
                    for param in self.noisy_model.noise_layer.std_estimator.module.linear.parameters()
                    if param.requires_grad
                ]
                + [
                    param
                    for param in self.noisy_model.noise_layer.mean_estimator.module.linear.parameters()
                    if param.requires_grad
                ]
            ),
            "weight_decay": 0.0,
        }
        no_weight_decay_params = set(no_weight_decay_param_group["params"])
        default_param_group = {
            "params": [
                param
                for param in self.noisy_model.noise_layer.parameters()
                if param.requires_grad and param not in no_weight_decay_params
            ],
        }
        return torch.optim.AdamW(
            params=[no_weight_decay_param_group, default_param_group],
            lr=3e-5,
            amsgrad=False,
            betas=(0.9, 0.95),
            eps=1e-5,
            weight_decay=0.1,
        )
    def prepare_data(self) -> None:
        """Download and prepare the dataset splits.
    This function is only called on the local rank 0 process, so we won't attempt to download or map the dataset multiple times here.
        """
        dataset = load_dataset()
        training_dataset = prepare_dataset_split(  # noqa: F841
            dataset["train"], "train", TRAIN_TOKENIZATION_FN
        )
        validation_dataset = prepare_dataset_split(  # noqa: F841
            dataset["test"], "val", TRAIN_TOKENIZATION_FN
        )
        generation_dataset = prepare_dataset_split(  # noqa: F841
            dataset["test"], "generation", TEST_TOKENIZATION_FN
        )
    def setup(self, stage: str) -> None:
        """Set up the datasets."""
        dataset = load_dataset()
        if stage == "fit":
            self.train_dataset = prepare_dataset_split(
                dataset["train"], "train", TRAIN_TOKENIZATION_FN
            )
            self.val_dataset = prepare_dataset_split(
                dataset["test"], "val", TRAIN_TOKENIZATION_FN
            )
            self.generation_dataset = prepare_dataset_split(
                dataset["test"], "generation", TEST_TOKENIZATION_FN
            )
        if stage == "test":
            ...
        if stage == "predict":
            ...
    def train_dataloader(
        self,
    ) -> torch.utils.data.DataLoader[noise_tokenizer.NoiseEncoding]:
        return torch.utils.data.DataLoader(
            self.train_dataset,  # pyright: ignore[reportArgumentType]
            collate_fn=sg_data_collator.DataCollatorForStainedGlassSeq2Seq(
                tokenizer=TRAIN_TOKENIZER, pad_to_multiple_of=8
            ),
            batch_size=2,
            shuffle=True,
            num_workers=4,
            pin_memory=True,
        )
    def val_dataloader(
        self,
    ) -> lightning.pytorch.utilities.combined_loader.CombinedLoader:
        return lightning.pytorch.utilities.combined_loader.CombinedLoader(
            (
                torch.utils.data.DataLoader(
                    self.val_dataset,  # pyright: ignore[reportArgumentType]
                    collate_fn=sg_data_collator.DataCollatorForStainedGlassSeq2Seq(
                        tokenizer=TRAIN_TOKENIZER, pad_to_multiple_of=8
                    ),
                    batch_size=2,
                    shuffle=False,
                    num_workers=4,
                    pin_memory=True,
                ),
                torch.utils.data.DataLoader(
                    self.generation_dataset,  # pyright: ignore[reportArgumentType]
                    collate_fn=sg_data_collator.DataCollatorForStainedGlassSeq2Seq(
                        tokenizer=TRAIN_TOKENIZER, pad_to_multiple_of=8
                    ),
                    batch_size=2,
                    shuffle=False,
                    num_workers=4,
                    pin_memory=True,
                ),
            ),
            mode="min_size",
        )
    @functools.cached_property
    def is_generation_dataloader_shuffled(self) -> bool:
        """Whether the generation dataloader is shuffled.
        Used to determine if we can cache and reuse base model generations from the first validation epoch.
        """
        val_dataloaders = self.trainer.val_dataloaders
        assert isinstance(val_dataloaders, tuple)
        _, generation_dataloader = val_dataloaders
        assert isinstance(generation_dataloader, torch.utils.data.DataLoader)
        assert isinstance(
            generation_dataloader.sampler, torch.utils.data.Sampler
        )
        return not (
            isinstance(
                generation_dataloader.sampler,
                torch.utils.data.SequentialSampler,
            )
            or (
                isinstance(
                    generation_dataloader.sampler,
                    torch.utils.data.DistributedSampler,
                )
                and generation_dataloader.sampler.shuffle
            )
        )
    def forward(
        self, **kwargs: Any
    ) -> transformers.modeling_outputs.CausalLMOutputWithPast:
        with self.noisy_model.distillation_context():
            return self.noisy_model(**kwargs)
    def on_train_epoch_start(self) -> None:
        """Ensure that the model is truncated prior to starting training."""
        self.noisy_model.truncate_and_offload()
    def _training_step(
        self,
        batch: noise_tokenizer.NoiseEncoding,
        batch_idx: int,
        dataloader: torch.utils.data.DataLoader[noise_tokenizer.NoiseEncoding],
        num_batches: float,
        mode: Literal["train", "valid"],
        metrics: dict[str, float],
    ) -> torch.Tensor:
        input_ids = batch["input_ids"]
        noise_mask = batch["noise_mask"]
        assert "attention_mask" in batch
        assert "loss_mask" in batch
        _ = self(
            input_ids=input_ids,
            attention_mask=batch["attention_mask"],
            use_cache=True,
            noise_mask=noise_mask,
        )
        loss = self.distillation_loss(batch["loss_mask"])
        batch_size = dataloader.batch_size
        assert batch_size is not None
        dataset = dataloader.dataset
        assert isinstance(dataset, Sized)
        num_examples_per_epoch = (num_batches - 1) * batch_size + (
            batch_size if dataloader.drop_last else len(dataset) % batch_size
        )
        dataset_size = (
            len(dataset) - len(dataset) % batch_size
            if dataloader.drop_last
            else len(dataset)
        )
        current_batch_size = (
            len(dataset) % batch_size
            if (batch_idx + 1) == len(dataloader) and not dataloader.drop_last
            else batch_size
        )
        num_examples = self.trainer.world_size * (
            self.current_epoch * num_examples_per_epoch
            + batch_idx * batch_size
            + current_batch_size
        )
        # TODO: handle intra-epoch validation
        metrics[f"num_examples/{mode}/batch"] = (
            num_examples  # allows for head-to-head comparisons between datasets
        )
        metrics[f"percent_dataset/{mode}/batch"] = (
            num_examples / dataset_size
        )  # allows for head-to-head-comparisons on the same dataset with different fractional sizes (e.g. limit_train_batches)
        metrics[f"epoch/{mode}/batch"] = (
            self.current_epoch + (batch_idx + 1) / num_batches
        )  # tracks our progress through the currently configured dataloader
        obfuscation_log_step = self.obfuscation_log_step or int(
            math.sqrt(num_batches)
        )
        if (
            batch_idx % obfuscation_log_step == 0
            or batch_idx == num_batches - 1
        ):
            transformed_embeddings = self.get_transformed_embeddings()
            reconstructed_input_ids = (
                self.noisy_model.reconstruct_ids_from_embeddings(
                    transformed_embeddings
                )
            )
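            # Obfuscation score: the fraction of noise-masked input ids that no longer
            # match the originals after reconstructing token ids from the transformed
            # embeddings.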
            percentage_changed_input_ids = sg_metrics.percentage_changed_ids(
                input_ids, reconstructed_input_ids, noise_mask
            )
            if mode == "valid":
                self._val_metrics.obfuscation_scores_cat.update(
                    percentage_changed_input_ids
                )
                self._val_metrics.obfuscation_quantiles.update(
                    percentage_changed_input_ids
                )
            metrics[f"obfuscation/{mode}/batch"] = (
                percentage_changed_input_ids.mean().item()
            )
            for metric in self.train_val_metrics[
                mode
            ].obfuscation_scores.values():
                metric.update(percentage_changed_input_ids)
        losses = self.get_losses()
        for loss_name, loss in losses.items():
            self.train_val_metrics[mode].mean_losses[loss_name].update(loss)
            metrics[f"{loss_name}/{mode}/batch"] = loss.item()
        applied_transform_components = self.get_applied_transform_components()
        noise_component_histogram_log_step = (
            self.noise_component_histogram_log_step
            or int(math.sqrt(num_batches))
        )
        if (
            self.tb_writer is not None
            and self.trainer.is_global_zero
            and (
                batch_idx % noise_component_histogram_log_step == 0
                or batch_idx == num_batches - 1
            )
        ):
            for name, values in applied_transform_components.items():
                self.tb_writer.add_histogram(
                    f"{name}_histogram/{mode}/batch",
                    values=values,
                    global_step=int(
                        self.current_epoch * num_batches + batch_idx
                    ),
                    bins=512,  # pyright: ignore[reportArgumentType]
                )
        return loss
    def training_step(
        self,
        batch: noise_tokenizer.NoiseEncoding,
        batch_idx: int,
    ) -> torch.Tensor:
        """Compute the training distillation loss."""
        train_dataloader = self.trainer.train_dataloader
        assert train_dataloader is not None
        num_batches = self.trainer.num_training_batches
        mode = "train"
        step = int(self.current_epoch * num_batches + batch_idx)
        metrics: dict[str, float] = {}
        loss = self._training_step(
            batch,
            batch_idx,
            train_dataloader,
            num_batches=num_batches,
            mode=mode,
            metrics=metrics,
        )
        if self.logger is not None and self.trainer.is_global_zero:
            self.logger.log_metrics(metrics, step=step)
        return loss
    def _on_train_epoch_end(
        self,
        mode: Literal["train", "valid"],
        metrics: dict[str, float],
    ) -> None:
        """Compute epoch-level training metrics."""
        metrics[f"epoch/{mode}/epoch"] = self.current_epoch
        for loss_name, metric in self.train_val_metrics[
            mode
        ].mean_losses.items():
            metrics[f"{loss_name}/{mode}/epoch"] = metric.compute().item()
            metric.reset()
        for metric_name, metric in self.train_val_metrics[
            mode
        ].obfuscation_scores.items():
            metrics[f"{metric_name}_obfuscation/{mode}/epoch"] = (
                metric.compute().item()
            )
            metric.reset()
        if self.train_val_metrics[mode].perplexity.update_called:
            metrics[f"perplexity/{mode}/epoch"] = (
                self.train_val_metrics[mode].perplexity.compute().item()
            )
            self.train_val_metrics[mode].perplexity.reset()
    def on_train_epoch_end(self) -> None:
        """Compute and log the epoch-level training metrics."""
        metrics: dict[str, Any] = {}
        self._on_train_epoch_end(mode="train", metrics=metrics)
        if self.logger is not None and self.trainer.is_global_zero:
            self.logger.log_metrics(metrics, step=self.current_epoch)
    def on_validation_epoch_start(self) -> None:
        """Ensure that the model is fully-loaded prior to starting validation."""
        self.noisy_model.restore_and_load()
    def validation_step(
        self,
        batch: tuple[
            noise_tokenizer.NoiseEncoding, noise_tokenizer.NoiseEncoding
        ],
        batch_idx: int,
    ) -> None:
        """Compute the validation distillation loss and perform generation."""
        val_batch, generation_batch = batch
        mode = "valid"
        assert self.trainer.val_dataloaders is not None
        val_dataloader = self.trainer.val_dataloaders[0]
        num_batches = self.trainer.num_val_batches[0]
        step = int(self.current_epoch * num_batches + batch_idx)
        metrics: dict[str, float] = {}
        _ = self._training_step(
            val_batch,
            batch_idx,
            val_dataloader,
            num_batches=num_batches,
            mode=mode,
            metrics=metrics,
        )
        if (
            len(self._val_metrics.generated_text_table)
            > self.max_generation_examples
        ):
            if self.logger is not None and self.trainer.is_global_zero:
                self.logger.log_metrics(metrics, step=step)
            return
        input_ids = generation_batch["input_ids"]
        attention_mask = generation_batch["attention_mask"]
        noise_mask = generation_batch["noise_mask"]
        input_text = TRAIN_TOKENIZER.batch_decode(
            input_ids, skip_special_tokens=True
        )
        sequence_length = input_ids.shape[-1]
        generation_config = (
            huggingface.generation.StainedGlassGenerationConfig.from_tokenizer(
                TRAIN_TOKENIZER, max_length=MAX_LENGTH
            )
        )
        if (
            not self.is_generation_dataloader_shuffled
            and batch_idx in self.base_model_generation_cache
        ):
            generated_text = self.base_model_generation_cache[batch_idx]
        else:
            with torch.random.fork_rng(
                devices=[self.device] if self.device.type != "cpu" else [],
                device_type=self.device.type,
            ):
                if self.noisy_model.noise_layer._generators is not None:
                    torch.manual_seed(
                        self.noisy_model.noise_layer.initial_seed()
                    )
                generated_ids = self.noisy_model.base_model.generate(
                    inputs=input_ids,
                    generation_config=generation_config,
                    attention_mask=attention_mask,
                    use_cache=True,
                )
            generated_text = TRAIN_TOKENIZER.batch_decode(
                generated_ids[:, sequence_length:], skip_special_tokens=True
            )
            self.base_model_generation_cache[batch_idx] = generated_text
        generated_ids_from_transformed_embeddings, transformed_embeddings = (
            self.noisy_model.generate(
                inputs=input_ids,
                generation_config=generation_config,
                attention_mask=attention_mask,
                use_cache=True,
                noise_mask=noise_mask,
                return_transformed_embeddings=True,
            )
        )
        generated_text_from_transformed_embeddings = (
            TRAIN_TOKENIZER.batch_decode(
                generated_ids_from_transformed_embeddings[:, sequence_length:],
                skip_special_tokens=True,
            )
        )
        reconstructed_input_ids = (
            self.noisy_model.reconstruct_ids_from_embeddings(
                transformed_embeddings
            )
        )
        reconstructed_input_text = TRAIN_TOKENIZER.batch_decode(
            reconstructed_input_ids, skip_special_tokens=True
        )
        percentage_changed_input_ids = sg_metrics.percentage_changed_ids(
            input_ids, reconstructed_input_ids, noise_mask
        )
        rouge_scores: dict[str, torch.Tensor] = self._val_metrics.rouge(
            generated_text_from_transformed_embeddings, generated_text
        )
        metrics.update(
            {
                f"{metric_name}/{mode}/batch": value.item()
                for metric_name, value in rouge_scores.items()
                if metric_name.endswith("_fmeasure")
            }
        )
        self._val_metrics.generated_text_table.update(
            {
                "input_text": input_text,
                "reconstructed_input_text": reconstructed_input_text,
                "obfuscation_score": [
                    f"{score.item() * 100:0.1f}%"
                    for score in percentage_changed_input_ids
                ],
                "generated_text": generated_text,
                "generated_text_from_transformed_embeddings": generated_text_from_transformed_embeddings,
            }
        )
        if self.logger is not None and self.trainer.is_global_zero:
            self.logger.log_metrics(metrics, step=step)
    def on_validation_epoch_end(self) -> None:
        """Compute and log the epoch-level validation metrics."""
        step = self.current_epoch
        mode = "valid"
        metrics: dict[str, Any] = {}
        self._on_train_epoch_end(mode=mode, metrics=metrics)
        observations, obfuscation_quantiles = (
            self._val_metrics.obfuscation_quantiles.compute()
        )
        metrics.update(
            {
                f"obfuscation_quantile/{q * 100:.0f}%/{mode}/epoch": value
                for q, value in obfuscation_quantiles.items()
            }
        )
        self._val_metrics.obfuscation_quantiles.reset()
        observed_obfuscation_scores = (
            self._val_metrics.obfuscation_scores_cat.compute()
        )
        self._val_metrics.obfuscation_scores_cat.reset()
        generated_text_table = self._val_metrics.generated_text_table.compute()
        self._val_metrics.generated_text_table.reset()
        if self.trainer.is_global_zero:
            if self.tb_writer is not None:
                self.tb_writer.add_histogram(
                    f"obfuscation_histogram/{mode}/epoch",
                    values=observed_obfuscation_scores.to("cpu", torch.float32),
                    global_step=step,
                    bins=torch.linspace(0.0, 1.0, 513),  # pyright: ignore[reportArgumentType]
                )
                self.tb_writer.add_text(
                    f"generated_text/{mode}/epoch",
                    text_string=sg_utils.torch.tensorboard.to_markdown_table(
                        generated_text_table
                    ),
                    global_step=step,
                )
            if self.wandb_logger is not None:
                self.wandb_logger.log_text(
                    f"generated_text/{self.current_epoch}/{mode}/epoch",
                    columns=list(generated_text_table.keys()),
                    data=list(zip(*generated_text_table.values())),
                )
            hyperparameters = self.get_hyperparameters()
            if self.wandb_run is not None:
                self.wandb_run.config.update(hyperparameters)
            if self.logger is not None:
                self.logger.log_metrics(metrics, step=step)
                # TODO: log_hyperparams doesn't currently accept a step and so the valid/epoch metrics aren't logged correctly
                # Fixed by: https://github.com/Lightning-AI/pytorch-lightning/pull/20176
                # self.logger.log_hyperparams(hyperparameters, metrics=metrics, step=step)
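With the LightningModule defined, training can be launched with a standard lightning.Trainer. The sketch below is a minimal, illustrative setup rather than the exact configuration of the full recipe: the epoch count, device settings, and logger save directory are placeholder assumptions, and the commented-out WandbLogger line marks where Weights & Biases logging would plug in.

module = StainedGlassDistillationLightningModule(max_generation_examples=10)
trainer = lightning.Trainer(
    max_epochs=1,  # placeholder; tune for your dataset and compute budget
    accelerator="gpu",
    devices=1,
    logger=[
        lightning.pytorch.loggers.TensorBoardLogger(save_dir="lightning_logs"),
        # lightning.pytorch.loggers.WandbLogger(project="stained-glass-transform"),
    ],
)
trainer.fit(module)  # prepare_data/setup build the datasets; the module supplies its own dataloaders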
EXPECTED_COLUMNS: set[str] = {
    "input_ids",
    "attention_mask",
    "noise_mask",
    "loss_mask",
}
TRAIN_SCHEMA_MAPPER: Final[sg_tokenization_utils.universal.ChatSchemaMapper] = (
    sg_tokenization_utils.universal.ChatSchemaMapper(
        system_prompt_key="instruction",
        instruction_key="input",
        response_key="output",
    )
)
TRAIN_TOKENIZER: Final[transformers.PreTrainedTokenizerBase] = (
    transformers.AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)
)
TRAIN_TOKENIZER.pad_token = TRAIN_TOKENIZER.eos_token
TRAIN_TOKENIZER.padding_side = "left"
noise_tokenizer_train = noise_tokenizer.NoiseTokenizer(TRAIN_TOKENIZER)
noise_tokenizer_train_fn = functools.partial(
    noise_tokenizer_train.apply_chat_template, ignore_prompt_loss=True
)
TRAIN_TOKENIZATION_FN = sg_utils.functional.sequential(
    TRAIN_SCHEMA_MAPPER, noise_tokenizer_train_fn
)
TEST_SCHEMA_MAPPER: Final[sg_tokenization_utils.universal.ChatSchemaMapper] = (
    sg_tokenization_utils.universal.ChatSchemaMapper(
        system_prompt_key="instruction",
        instruction_key="input",
        response_key=None,
    )
)
TEST_TOKENIZER: Final[transformers.PreTrainedTokenizerBase] = (
    transformers.AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)
)
TEST_TOKENIZER.pad_token = TEST_TOKENIZER.eos_token
TEST_TOKENIZER.padding_side = "left"
noise_tokenizer_test = noise_tokenizer.NoiseTokenizer(TEST_TOKENIZER)
noise_tokenizer_test_fn = functools.partial(
    noise_tokenizer_train.apply_chat_template,
    ignore_prompt_loss=True,
    add_generation_prompt=True,
)
TEST_TOKENIZATION_FN = sg_utils.functional.sequential(
    TEST_SCHEMA_MAPPER, noise_tokenizer_test_fn
)
def get_cache_dir(dataset: datasets.Dataset) -> str:
    """Get the directory of the first cache file of the loaded dataset."""
    if not dataset.cache_files:
        raise ValueError("The loaded dataset has no cache files.")
    return os.path.dirname(dataset.cache_files[0]["filename"])
def max_length_filter(
    sample: (noise_tokenizer.NoiseEncoding),
) -> bool:
    """Filter `input_ids` greater than the max length."""
    return sample["input_ids"].shape[-1] <= MAX_LENGTH
def load_dataset() -> datasets.DatasetDict:
    """Load the dataset and split it into training and validation sets."""
    if LOAD_FROM_DISK:
        dataset = datasets.load_from_disk(
            DATASET_NAME,
        )
    else:
        dataset = datasets.load_dataset(
            DATASET_NAME,
            num_proc=PHYSICAL_CPU_COUNT,
        )
    assert not isinstance(
        dataset, (datasets.IterableDataset, datasets.IterableDatasetDict)
    )
    if VALIDATION_SPLIT_RATIO is not None:
        if isinstance(dataset, datasets.DatasetDict):
            assert dataset.keys() == {"train"}, (
                "VALIDATION_SPLIT_RATIO was set, but there are already multiple splits in the loaded dataset."
            )
            dataset = dataset["train"]
        cache_dir = os.path.join(get_cache_dir(dataset), "cache")
        os.makedirs(
            cache_dir, exist_ok=True
        )  # TODO: datasets should probably handle this for us https://github.com/huggingface/datasets/pull/7096
        train_indices_cache_file_name = os.path.join(
            cache_dir,
            f"seed_{SEED}_train_split_{1 - VALIDATION_SPLIT_RATIO:.3f}.cache",
        )
        test_indices_cache_file_name = os.path.join(
            cache_dir,
            f"seed_{SEED}_val_split_{VALIDATION_SPLIT_RATIO:.3f}.cache",
        )
        dataset = dataset.train_test_split(
            test_size=VALIDATION_SPLIT_RATIO,
            load_from_cache_file=True,
            shuffle=True,
            generator=np.random.default_rng(seed=SEED),
            train_indices_cache_file_name=train_indices_cache_file_name,
            test_indices_cache_file_name=test_indices_cache_file_name,
        )
    dataset.set_format("torch")
    return dataset
def prepare_dataset_split(
    dataset: datasets.Dataset,
    split: str,
    tokenization_fn: Callable[[Mapping[str, str]], Any],
) -> datasets.Dataset:
    """Prepare a dataset split by tokenizing and filtering the samples."""
    cache_dir = os.path.join(
        get_cache_dir(dataset), "cache", MODEL_CACHE_DIR_NAME
    )
    os.makedirs(
        cache_dir, exist_ok=True
    )  # TODO: datasets should probably handle this for us https://github.com/huggingface/datasets/pull/7096
    map_prefix = os.path.join(cache_dir, f"{split}-SGT-tokenized")
    filter_prefix = map_prefix + f"-max_length-{MAX_LENGTH}"
    dataset = dataset.map(
        tokenization_fn,
        cache_file_name=map_prefix + ".map",
        load_from_cache_file=True,
        batched=False,
        num_proc=PHYSICAL_CPU_COUNT,
    ).filter(
        max_length_filter,
        cache_file_name=filter_prefix + ".filter",
        load_from_cache_file=True,
        batched=False,
        num_proc=PHYSICAL_CPU_COUNT,
    )
    columns_to_drop = [
        column
        for column in dataset.column_names
        if column not in EXPECTED_COLUMNS
    ]
    return dataset.remove_columns(columns_to_drop)
class TrainingStepMetrics(nn.Module):
    """Common training metrics to be collected during [`StainedGlassDistillationLightningModule._training_step`][] in both training and
    validation.
    """
    def __init__(self) -> None:
        super().__init__()
        self.mean_losses = nn.ModuleDict(
            {
                "distillation_layer_cosine_distance_loss": torchmetrics.MeanMetric(),
                "distillation_layer_l2_distance_loss": torchmetrics.MeanMetric(),
                "normalized_input_embedding_cosine_similarity_loss": torchmetrics.MeanMetric(),
                "std_log_ratio_loss": torchmetrics.MeanMetric(),
                "composite_loss": torchmetrics.MeanMetric(),
            }
        )
        self.perplexity = torchmetrics.text.Perplexity(
            ignore_index=TRAIN_TOKENIZER.pad_token_id
        )
        self.obfuscation_scores = nn.ModuleDict(
            {
                "mean": torchmetrics.MeanMetric(),
                "min": torchmetrics.MinMetric(),
                "max": torchmetrics.MaxMetric(),
            }
        )
class ValidationStepMetrics(TrainingStepMetrics):
    """Additional metrics only to be collected during [`StainedGlassDistillationLightningModule.validation_step`][]."""
    def __init__(self) -> None:
        super().__init__()
        self.rouge = torchmetrics.text.rouge.ROUGEScore(
            rouge_keys=("rouge1", "rouge2", "rougeL")
        )
        self.obfuscation_quantiles = sg_torchmetrics.QuantileMetric(
            q=torch.linspace(0.1, 0.9, steps=9)
        )
        self.obfuscation_scores_cat = torchmetrics.CatMetric()
        self.generated_text_table = sg_torchmetrics.TableMetric()
        self.config = transformers.AutoConfig.from_pretrained(
            PRETRAINED_MODEL_NAME_OR_PATH
        )
class StainedGlassDistillationLightningModule(lightning.LightningModule):
    """[`lightning.LightningModule`][] for training an LLM Stained Glass Transform via distillation."""
    def __init__(
        self,
        max_generation_examples: int = 10,
        obfuscation_log_step: int | None = None,
        noise_component_histogram_log_step: int | None = None,
    ) -> None:
        """Initialize a `StainedGlassDistillationLightningModule`.
        Args:
            max_generation_examples: The cutoff for the number of validation examples for which to run generation. The table will have the
                nearest multiple of the batch size over this number of rows.
            obfuscation_log_step: How often to check the obfuscation score.
            noise_component_histogram_log_step: How often to log the noise component histograms (means and standard deviations) during
                training.
        """
        super().__init__()
        self.max_generation_examples = max_generation_examples
        self.obfuscation_log_step = obfuscation_log_step
        self.noise_component_histogram_log_step = (
            noise_component_histogram_log_step
        )
        self.save_hyperparameters()
        self.base_model_generation_cache: dict[int, list[str]] = {}
        """A mapping of `batch_idx` to first-epoch base model generation results that we can use to speed up the generation step of
        validation if the generation dataloader is not shuffled. Alternatively, we could do this as a dataset pre-processing step to get
        more accurate labels for training.
        """
        with sg_utils.torch.dtypes.default_dtype(torch.float32):
            self._train_metrics = TrainingStepMetrics()
            self._val_metrics = ValidationStepMetrics()
            self.train_val_metrics: dict[str, TrainingStepMetrics] = {
                "train": self._train_metrics,
                "valid": self._val_metrics,
            }
        self.config = transformers.AutoConfig.from_pretrained(
            PRETRAINED_MODEL_NAME_OR_PATH
        )
    @functools.cached_property
    def tb_writer(self) -> torch.utils.tensorboard.SummaryWriter | None:  # pyright: ignore[reportPrivateImportUsage]
        assert self.loggers is not None
        for logger in self.loggers:
            if isinstance(logger, lightning.pytorch.loggers.TensorBoardLogger):
                return logger.experiment
        return None
    @functools.cached_property
    def wandb_logger(self) -> lightning.pytorch.loggers.WandbLogger | None:
        assert self.loggers is not None
        for logger in self.loggers:
            if isinstance(logger, lightning.pytorch.loggers.WandbLogger):
                return logger
        return None
    @functools.cached_property
    def wandb_run(self) -> wandb.wandb_run.Run | None:
        assert self.loggers is not None
        for logger in self.loggers:
            if isinstance(logger, lightning.pytorch.loggers.WandbLogger):
                return logger.experiment
        return None
    def configure_model(self) -> None:
        """Configure the models and loss function."""
        base_model = transformers.MistralForCausalLM.from_pretrained(
            PRETRAINED_MODEL_NAME_OR_PATH,
            attn_implementation="flash_attention_2",
            torch_dtype=torch.bfloat16,
        )
        assert isinstance(base_model, transformers.MistralForCausalLM)
        for param in base_model.parameters():
            param.requires_grad = False
        base_model = base_model.eval()
        self.noisy_model = (
            noisy_transformer_masking_model.NoiseMaskedNoisyTransformerModel(
                noise_layer_class=transformer_cloak.TransformerCloak,
                base_model=base_model,
                target_layer="model.embed_tokens",
                truncated_layer_index=TRUNCATED_LAYER_INDEX,
                scale=(1e-8, 1.0),
                shallow=1.0,
                mean_dropout=0.1,
                std_dropout=0.1,
                config=self.config,
                use_causal_mask=True,
                transformer_type=transformers.MistralModel,
                directly_learn_stds=True,
                rho_init=0.0,
                seed=SEED,
                noise_layer_dtype=torch.float32,
            )
        )
        self.distillation_loss, self.get_losses, self.get_hyperparameters = (
            sg_loss.distillation.distillation_loss_factory(
                self.noisy_model,
                distillation_layer_index=TRUNCATED_LAYER_INDEX,
                alpha=0.54,
                std_log_ratio_loss_weight=0.01,
                input_embedding_similarity_loss_weight=0.75,
                distillation_layer_cosine_distance_loss_weight=12.0,
            )
        )
        self.get_transformed_embeddings = (
            self.noisy_model.noise_layer.get_transformed_output_factory()
        )
        self.get_applied_transform_components = self.noisy_model.noise_layer.get_applied_transform_components_factory()
    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Configure the model optimizer."""
        no_weight_decay_param_group = {
            "params": (
                [
                    param
                    for param in self.noisy_model.noise_layer.std_estimator.module.linear.parameters()
                    if param.requires_grad
                ]
                + [
                    param
                    for param in self.noisy_model.noise_layer.mean_estimator.module.linear.parameters()
                    if param.requires_grad
                ]
            ),
            "weight_decay": 0.0,
        }
        no_weight_decay_params = set(no_weight_decay_param_group["params"])
        default_param_group = {
            "params": [
                param
                for param in self.noisy_model.noise_layer.parameters()
                if param.requires_grad and param not in no_weight_decay_params
            ],
        }
        return torch.optim.AdamW(
            params=[no_weight_decay_param_group, default_param_group],
            lr=3e-5,
            amsgrad=False,
            betas=(0.9, 0.95),
            eps=1e-5,
            weight_decay=0.1,
        )
    def prepare_data(self) -> None:
        """Download and prepare the dataset splits.
        This function is only called on the local rank 0 process, so we won't attempt to download or map dataset multiple times here.
        """
        dataset = load_dataset()
        training_dataset = prepare_dataset_split(  # noqa: F841
            dataset["train"], "train", TRAIN_TOKENIZATION_FN
        )
        validation_dataset = prepare_dataset_split(  # noqa: F841
            dataset["test"], "val", TRAIN_TOKENIZATION_FN
        )
        generation_dataset = prepare_dataset_split(  # noqa: F841
            dataset["test"], "generation", TEST_TOKENIZATION_FN
        )
    def setup(self, stage: str) -> None:
        """Set up the datasets."""
        dataset = load_dataset()
        if stage == "fit":
            self.train_dataset = prepare_dataset_split(
                dataset["train"], "train", TRAIN_TOKENIZATION_FN
            )
            self.val_dataset = prepare_dataset_split(
                dataset["test"], "val", TRAIN_TOKENIZATION_FN
            )
            self.generation_dataset = prepare_dataset_split(
                dataset["test"], "generation", TEST_TOKENIZATION_FN
            )
        if stage == "test":
            ...
        if stage == "predict":
            ...
    def train_dataloader(
        self,
    ) -> torch.utils.data.DataLoader[noise_tokenizer.NoiseEncoding]:
        return torch.utils.data.DataLoader(
            self.train_dataset,  # pyright: ignore[reportArgumentType]
            collate_fn=sg_data_collator.DataCollatorForStainedGlassSeq2Seq(
                tokenizer=TRAIN_TOKENIZER, pad_to_multiple_of=8
            ),
            batch_size=2,
            shuffle=True,
            num_workers=4,
            pin_memory=True,
        )
    def val_dataloader(
        self,
    ) -> lightning.pytorch.utilities.combined_loader.CombinedLoader:
        return lightning.pytorch.utilities.combined_loader.CombinedLoader(
            (
                torch.utils.data.DataLoader(
                    self.val_dataset,  # pyright: ignore[reportArgumentType]
                    collate_fn=sg_data_collator.DataCollatorForStainedGlassSeq2Seq(
                        tokenizer=TRAIN_TOKENIZER, pad_to_multiple_of=8
                    ),
                    batch_size=2,
                    shuffle=False,
                    num_workers=4,
                    pin_memory=True,
                ),
                torch.utils.data.DataLoader(
                    self.generation_dataset,  # pyright: ignore[reportArgumentType]
                    collate_fn=sg_data_collator.DataCollatorForStainedGlassSeq2Seq(
                        tokenizer=TRAIN_TOKENIZER, pad_to_multiple_of=8
                    ),
                    batch_size=2,
                    shuffle=False,
                    num_workers=4,
                    pin_memory=True,
                ),
            ),
            mode="min_size",
        )
    @functools.cached_property
    def is_generation_dataloader_shuffled(self) -> bool:
        """Whether the generation dataloader is shuffled.
        Used to determine if we can cache and reuse base model generations from the first validation epoch.
        """
        val_dataloaders = self.trainer.val_dataloaders
        assert isinstance(val_dataloaders, tuple)
        _, generation_dataloader = val_dataloaders
        assert isinstance(generation_dataloader, torch.utils.data.DataLoader)
        assert isinstance(
            generation_dataloader.sampler, torch.utils.data.Sampler
        )
        return not (
            isinstance(
                generation_dataloader.sampler,
                torch.utils.data.SequentialSampler,
            )
            or (
                isinstance(
                    generation_dataloader.sampler,
                    torch.utils.data.DistributedSampler,
                )
                and generation_dataloader.sampler.shuffle
            )
        )
    def forward(
        self, **kwargs: Any
    ) -> transformers.modeling_outputs.CausalLMOutputWithPast:
        with self.noisy_model.distillation_context():
            return self.noisy_model(**kwargs)
    def on_train_epoch_start(self) -> None:
        """Ensure that the model is truncated prior to starting training."""
        self.noisy_model.truncate_and_offload()
    def _training_step(
        self,
        batch: noise_tokenizer.NoiseEncoding,
        batch_idx: int,
        dataloader: torch.utils.data.DataLoader[noise_tokenizer.NoiseEncoding],
        num_batches: float,
        mode: Literal["train", "valid"],
        metrics: dict[str, float],
    ) -> torch.Tensor:
        input_ids = batch["input_ids"]
        noise_mask = batch["noise_mask"]
        assert "attention_mask" in batch
        assert "loss_mask" in batch
        _ = self(
            input_ids=input_ids,
            attention_mask=batch["attention_mask"],
            use_cache=True,
            noise_mask=noise_mask,
        )
        loss = self.distillation_loss(batch["loss_mask"])
        batch_size = dataloader.batch_size
        assert batch_size is not None
        dataset = dataloader.dataset
        assert isinstance(dataset, Sized)
        num_examples_per_epoch = (num_batches - 1) * batch_size + (
            len(dataset) % batch_size if dataloader.drop_last else batch_size
        )
        dataset_size = (
            len(dataset) - len(dataset) % batch_size
            if dataloader.drop_last
            else len(dataset)
        )
        current_batch_size = (
            len(dataset) % batch_size
            if (batch_idx + 1) == len(dataloader) and not dataloader.drop_last
            else batch_size
        )
        num_examples = self.trainer.world_size * (
            self.current_epoch * num_examples_per_epoch
            + batch_idx * batch_size
            + current_batch_size
        )
        # TODO: handle intra-epoch validation
        metrics[f"num_examples/{mode}/batch"] = (
            num_examples  # allows for head-to-head comparisons between datasets
        )
        metrics[f"percent_dataset/{mode}/batch"] = (
            num_examples / dataset_size
        )  # allows for head-to-head-comparisons on the same dataset with different fractional sizes (e.g. limit_train_batches)
        metrics[f"epoch/{mode}/batch"] = (
            self.current_epoch + (batch_idx + 1) / num_batches
        )  # tracks our progress through the currently configured dataloader
        obfuscation_log_step = self.obfuscation_log_step or int(
            math.sqrt(num_batches)
        )
        if (
            batch_idx % obfuscation_log_step == 0
            or batch_idx == num_batches - 1
        ):
            transformed_embeddings = self.get_transformed_embeddings()
            reconstructed_input_ids = (
                self.noisy_model.reconstruct_ids_from_embeddings(
                    transformed_embeddings
                )
            )
            percentage_changed_input_ids = sg_metrics.percentage_changed_ids(
                input_ids, reconstructed_input_ids, noise_mask
            )
            if mode == "valid":
                self._val_metrics.obfuscation_scores_cat.update(
                    percentage_changed_input_ids
                )
                self._val_metrics.obfuscation_quantiles.update(
                    percentage_changed_input_ids
                )
            metrics[f"obfuscation/{mode}/batch"] = (
                percentage_changed_input_ids.mean().item()
            )
            for metric in self.train_val_metrics[
                mode
            ].obfuscation_scores.values():
                metric.update(percentage_changed_input_ids)
        losses = self.get_losses()
        for loss_name, loss in losses.items():
            self.train_val_metrics[mode].mean_losses[loss_name].update(loss)
            metrics[f"{loss_name}/{mode}/batch"] = loss.item()
        applied_transform_components = self.get_applied_transform_components()
        noise_component_histogram_log_step = (
            self.noise_component_histogram_log_step
            or int(math.sqrt(num_batches))
        )
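        # Histograms of the applied transform components are written only from the global-zero rank, at the configured cadence.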
        if (
            self.tb_writer is not None
            and self.trainer.is_global_zero
            and (
                batch_idx % noise_component_histogram_log_step == 0
                or batch_idx == num_batches - 1
            )
        ):
            for name, values in applied_transform_components.items():
                self.tb_writer.add_histogram(
                    f"{name}_histogram/{mode}/batch",
                    values=values,
                    global_step=int(
                        self.current_epoch * num_batches + batch_idx
                    ),
                    bins=512,  # pyright: ignore[reportArgumentType]
                )
        return loss
    def training_step(
        self,
        batch: noise_tokenizer.NoiseEncoding,
        batch_idx: int,
    ) -> torch.Tensor:
        """Compute the training distillation loss."""
        train_dataloader = self.trainer.train_dataloader
        assert train_dataloader is not None
        num_batches = self.trainer.num_training_batches
        mode = "train"
        step = int(self.current_epoch * num_batches + batch_idx)
        metrics: dict[str, float] = {}
        loss = self._training_step(
            batch,
            batch_idx,
            train_dataloader,
            num_batches=num_batches,
            mode=mode,
            metrics=metrics,
        )
        if self.logger is not None and self.trainer.is_global_zero:
            self.logger.log_metrics(metrics, step=step)
        return loss
    def _on_train_epoch_end(
        self,
        mode: Literal["train", "valid"],
        metrics: dict[str, float],
    ) -> None:
        """Compute epoch-level training metrics."""
        metrics[f"epoch/{mode}/epoch"] = self.current_epoch
        for loss_name, metric in self.train_val_metrics[
            mode
        ].mean_losses.items():
            metrics[f"{loss_name}/{mode}/epoch"] = metric.compute().item()
            metric.reset()
        for metric_name, metric in self.train_val_metrics[
            mode
        ].obfuscation_scores.items():
            metrics[f"{metric_name}_obfuscation/{mode}/epoch"] = (
                metric.compute().item()
            )
            metric.reset()
        if self.train_val_metrics[mode].perplexity.update_called:
            metrics[f"perplexity/{mode}/epoch"] = (
                self.train_val_metrics[mode].perplexity.compute().item()
            )
            self.train_val_metrics[mode].perplexity.reset()
    def on_train_epoch_end(self) -> None:
        """Compute and log the epoch-level training metrics."""
        metrics: dict[str, Any] = {}
        self._on_train_epoch_end(mode="train", metrics=metrics)
        if self.logger is not None and self.trainer.is_global_zero:
            self.logger.log_metrics(metrics, step=self.current_epoch)
    def on_validation_epoch_start(self) -> None:
        """Ensure that the model is fully-loaded prior to starting validation."""
        self.noisy_model.restore_and_load()
    def validation_step(
        self,
        batch: tuple[
            noise_tokenizer.NoiseEncoding, noise_tokenizer.NoiseEncoding
        ],
        batch_idx: int,
    ) -> None:
        """Compute the validation distillation loss and perform generation."""
        val_batch, generation_batch = batch
        mode = "valid"
        assert self.trainer.val_dataloaders is not None
        val_dataloader = self.trainer.val_dataloaders[0]
        num_batches = self.trainer.num_val_batches[0]
        step = int(self.current_epoch * num_batches + batch_idx)
        metrics: dict[str, float] = {}
        _ = self._training_step(
            val_batch,
            batch_idx,
            val_dataloader,
            num_batches=num_batches,
            mode=mode,
            metrics=metrics,
        )
        if (
            len(self._val_metrics.generated_text_table)
            > self.max_generation_examples
        ):
            if self.logger is not None and self.trainer.is_global_zero:
                self.logger.log_metrics(metrics, step=step)
            return
        input_ids = generation_batch["input_ids"]
        attention_mask = generation_batch["attention_mask"]
        noise_mask = generation_batch["noise_mask"]
        input_text = TRAIN_TOKENIZER.batch_decode(
            input_ids, skip_special_tokens=True
        )
        sequence_length = input_ids.shape[-1]
        generation_config = (
            huggingface.generation.StainedGlassGenerationConfig.from_tokenizer(
                TRAIN_TOKENIZER, max_length=MAX_LENGTH
            )
        )
        if (
            not self.is_generation_dataloader_shuffled
            and batch_idx in self.base_model_generation_cache
        ):
            generated_text = self.base_model_generation_cache[batch_idx]
        else:
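            # Fork and reseed the RNG so base-model generation is reproducible and leaves the global RNG state untouched.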
            with torch.random.fork_rng(
                devices=[self.device] if self.device.type != "cpu" else [],
                device_type=self.device.type,
            ):
                if self.noisy_model.noise_layer._generators is not None:
                    torch.manual_seed(
                        self.noisy_model.noise_layer.initial_seed()
                    )
                generated_ids = self.noisy_model.base_model.generate(
                    inputs=input_ids,
                    generation_config=generation_config,
                    attention_mask=attention_mask,
                    use_cache=True,
                )
            generated_text = TRAIN_TOKENIZER.batch_decode(
                generated_ids[:, sequence_length:], skip_special_tokens=True
            )
            self.base_model_generation_cache[batch_idx] = generated_text
        generated_ids_from_transformed_embeddings, transformed_embeddings = (
            self.noisy_model.generate(
                inputs=input_ids,
                generation_config=generation_config,
                attention_mask=attention_mask,
                use_cache=True,
                noise_mask=noise_mask,
                return_transformed_embeddings=True,
            )
        )
        generated_text_from_transformed_embeddings = (
            TRAIN_TOKENIZER.batch_decode(
                generated_ids_from_transformed_embeddings[:, sequence_length:],
                skip_special_tokens=True,
            )
        )
        reconstructed_input_ids = (
            self.noisy_model.reconstruct_ids_from_embeddings(
                transformed_embeddings
            )
        )
        reconstructed_input_text = TRAIN_TOKENIZER.batch_decode(
            reconstructed_input_ids, skip_special_tokens=True
        )
        percentage_changed_input_ids = sg_metrics.percentage_changed_ids(
            input_ids, reconstructed_input_ids, noise_mask
        )
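        # ROUGE compares generations produced from the transformed embeddings against the base model's generations on the raw inputs.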
        rouge_scores: dict[str, torch.Tensor] = self._val_metrics.rouge(
            generated_text_from_transformed_embeddings, generated_text
        )
        metrics.update(
            {
                f"{metric_name}/{mode}/batch": value.item()
                for metric_name, value in rouge_scores.items()
                if metric_name.endswith("_fmeasure")
            }
        )
        self._val_metrics.generated_text_table.update(
            {
                "input_text": input_text,
                "reconstructed_input_text": reconstructed_input_text,
                "obfuscation_score": [
                    f"{score.item() * 100:0.1f}%"
                    for score in percentage_changed_input_ids
                ],
                "generated_text": generated_text,
                "generated_text_from_transformed_embeddings": generated_text_from_transformed_embeddings,
            }
        )
        if self.logger is not None and self.trainer.is_global_zero:
            self.logger.log_metrics(metrics, step=step)
    def on_validation_epoch_end(self) -> None:
        """Compute and log the epoch-level validation metrics."""
        step = self.current_epoch
        mode = "valid"
        metrics: dict[str, Any] = {}
        self._on_train_epoch_end(mode=mode, metrics=metrics)
        observations, obfuscation_quantiles = (
            self._val_metrics.obfuscation_quantiles.compute()
        )
        metrics.update(
            {
                f"obfuscation_quantile/{q * 100:.0f}%/{mode}/epoch": value
                for q, value in obfuscation_quantiles.items()
            }
        )
        self._val_metrics.obfuscation_quantiles.reset()
        observed_obfuscation_scores = (
            self._val_metrics.obfuscation_scores_cat.compute()
        )
        self._val_metrics.obfuscation_scores_cat.reset()
        generated_text_table = self._val_metrics.generated_text_table.compute()
        self._val_metrics.generated_text_table.reset()
        if self.trainer.is_global_zero:
            if self.tb_writer is not None:
                self.tb_writer.add_histogram(
                    f"obfuscation_histogram/{mode}/epoch",
                    values=observed_obfuscation_scores.to("cpu", torch.float32),
                    global_step=step,
                    bins=torch.linspace(0.0, 1.0, 513),  # pyright: ignore[reportArgumentType]
                )
                self.tb_writer.add_text(
                    f"generated_text/{mode}/epoch",
                    text_string=sg_utils.torch.tensorboard.to_markdown_table(
                        generated_text_table
                    ),
                    global_step=step,
                )
            if self.wandb_logger is not None:
                self.wandb_logger.log_text(
                    f"generated_text/{self.current_epoch}/{mode}/epoch",
                    columns=list(generated_text_table.keys()),
                    data=list(zip(*generated_text_table.values())),
                )
            hyperparameters = self.get_hyperparameters()
            if self.wandb_run is not None:
                self.wandb_run.config.update(hyperparameters)
            if self.logger is not None:
                self.logger.log_metrics(metrics, step=step)
                # TODO: log_hyperparams doesn't currently accept a step and so the valid/epoch metrics aren't logged correctly
                # Fixed by: https://github.com/Lightning-AI/pytorch-lightning/pull/20176
                # self.logger.log_hyperparams(hyperparameters, metrics=metrics, step=step)
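The obfuscation score logged throughout training and validation measures how thoroughly the protected portion of the prompt is obscured: the transformed embeddings are mapped back to token ids via reconstruct_ids_from_embeddings and compared against the original tokens at the positions selected by the noise mask. The cell below is a minimal, self-contained sketch of that idea; it is an illustration only, not the stainedglass_core.metrics.percentage_changed_ids implementation, and the helper name fraction_changed_ids is hypothetical.
In [ ]:
Copied!
import torch


def fraction_changed_ids(
    input_ids: torch.Tensor,  # (batch, seq) original token ids
    reconstructed_ids: torch.Tensor,  # (batch, seq) ids reconstructed from the transformed embeddings
    noise_mask: torch.Tensor,  # (batch, seq) True where the transform was applied
) -> torch.Tensor:
    """Per-example fraction of masked positions whose reconstructed id differs (illustrative only)."""
    changed = (input_ids != reconstructed_ids) & noise_mask
    # Clamp to avoid division by zero for examples with an empty noise mask.
    masked_counts = noise_mask.sum(dim=-1).clamp_min(1)
    return changed.sum(dim=-1) / masked_counts


ids = torch.tensor([[5, 6, 7, 8], [1, 2, 3, 4]])
recon = torch.tensor([[5, 9, 9, 8], [1, 2, 3, 4]])
mask = torch.tensor([[False, True, True, True], [True, True, False, False]])
fraction_changed_ids(ids, recon, mask)  # tensor([0.6667, 0.0000])
Higher values mean that more of the protected tokens can no longer be recovered from the transformed embeddings; the quantiles and histograms logged during validation summarize this distribution across the dataset.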
In [ ]:
Copied!
import lightning.pytorch.loggers
import wandb
import stainedglass_core
In [ ]:
Copied!
WANDB_PROJECT: Final[str] = "llm-transform-training-notebook"
SAVE_DIR: Final[str] = "saved/"
LOG_STEP: Final[int] = 100
# Reduce CUDA memory fragmentation by letting the caching allocator use expandable segments.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# Permit TF32 matrix multiplications for faster float32 math on supported GPUs.
torch.set_float32_matmul_precision("high")
ACCELERATOR: Final[str] = "cuda"
In [ ]:
Copied!
tb_logger = lightning.pytorch.loggers.TensorBoardLogger(
    SAVE_DIR,
    name="mistral_distillation",
    default_hp_metric=False,
)
wandb.login()
wandb_logger = lightning.pytorch.loggers.WandbLogger(
    project=WANDB_PROJECT,
    save_dir=tb_logger.log_dir,
    log_model=True,
    sync_tensorboard=True,
    tags=[
        f"stainedglass_core=={stainedglass_core.__version__}",
    ],
)
# wandb.init must be called before any tensorboard writers are created in order to sync tensorboard logs to wandb:
# https://github.com/wandb/wandb/issues/1782#issuecomment-779161203
_ = wandb_logger.experiment
trainer = lightning.Trainer(
    max_epochs=1,
    accumulate_grad_batches=36,
    accelerator=ACCELERATOR,
    strategy="auto",
    precision="bf16-true",
    log_every_n_steps=LOG_STEP,
    default_root_dir=SAVE_DIR,
    logger=[
        tb_logger,
        wandb_logger,
    ],
)
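# Wrap the precision plugin so that quantile metrics are kept in full precision even though training runs in bf16.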
if trainer.strategy._precision_plugin is not None:
    trainer.strategy._precision_plugin = sg_lightning.ReducedPrecisionFilter(
        trainer.strategy._precision_plugin,
        full_precision_module_types=(sg_torchmetrics.QuantileMetric,),
    )
with trainer.init_module():
    distillation_module = StainedGlassDistillationLightningModule(
        max_generation_examples=320,
        obfuscation_log_step=LOG_STEP,
        noise_component_histogram_log_step=LOG_STEP,
    )
trainer.fit(model=distillation_module)