Recommendations for Training a Stained Glass Transform for a Causal Language Model¶
Training a Stained Glass Transform (SGT) with Stained Glass Engine (SGE) for large language models involves many details that can significantly impact the quality of your SGT. This article describes these details comprehensively and acts as a recommendation guide for users who plan to train an SGT for a generative, decoder-only large language model such as Llama2, Llama3, or Mistral.
Staying Organized¶
Training a Stained Glass Transform is training a deep learning model. As a result, the usual recommendations for systematic organization apply. Stained Glass Transforms created with SGE can be produced either from a Python file or from a Jupyter notebook. In either case, having a unique file that is clearly named helps users and scientists know the intent of the file.
Stained Glass Engine is compatible with experiment tracking tools such as Weights & Biases (WandB) or TensorBoard. If these tools are available, use them to track the creation of the Stained Glass Transform and the resulting artifacts.
SGT Arguments and Hyperparameters¶
A Stained Glass Transform has many hyperparameters which directly affect SGE's ability to produce a useful transformation. The purpose of, and recommendations for, these hyperparameters are summarized below.
- `scale`: The range of allowed strengths of the randomness of the SGT. The scale is a tuple of two numbers of the form `(min_scale, max_scale)`. Setting `min_scale` and `max_scale` appropriately can have dramatic effects on the quality of the model. Generally, `min_scale` is recommended to be approximately 10 times smaller than the minimum value of all features, and `max_scale` should be 10 times larger than the maximum value of all features. One can check the scales of the embeddings by simply examining the features of every token embedding; for Hugging Face transformers this can easily be computed by accessing the weights of the embedding module at the beginning of the LLM (see the sketch after this list). If `min_scale` is too small, the applied transformation may not be random enough to produce sufficiently stochastic embeddings. If `min_scale` is too large, too much noise may be applied, causing immediate divergence during the training of the SGT. If `max_scale` is too small, the transform may be left unnecessarily weak: there may still be error budget available in the features that would have allowed stronger transforms to be learned while maintaining utility. If `max_scale` is too large, you may incur a scale separation in your data which can decrease the learnability of the SGT and hurt the utility of the transformed data; in other words, too large a max scale hurts the utility side of the privacy-utility tradeoff. In the case of the Llama2 and Mistral embeddings, we have found that two cases for scale exist depending on the value of `directly_learn_stds`. When `directly_learn_stds` is `True`, a scale of `(1e-8, 1.0)` is a good starting recommendation, and when `directly_learn_stds` is `False`, a scale of `(1e-3, 1.0)` is a good initial point to sweep.
- `transformer_type`: The `TransformerCloak` SGT uses transformers to model the sequence-dependent stochastic re-representations of the input embeddings. In principle any transformer block could be used, but it is strongly encouraged to use a transformer of the same type as the model for which the Stained Glass Transform is being generated. For example, when training an SGT for Mistral-7B-Instruct-v0.2, the `transformer_type` is `transformers.MistralModel`, since all Mistral models share the same type.
- `config_path`: We currently require that the transformers underlying `TransformerCloak` exist on disk prior to the training of an SGT. The `config_path` should be the path to the directory where the transformer and its associated files are saved. This could look something like `/models/huggingface/mistralai/Mistral-7B-Instruct-v0.2`. If the `transformer_type` is not compatible with the transformer specified by `config_path`, an exception will be raised.
- `percent_to_mask`: An experimental feature which replaces a procedurally determined fixed percentage of features (valued between 0 and 1 inclusive) with features learned from another deterministic transformation. This setting should remain at 0.0. Using a non-zero value in [0, 1], while possible, is not recommended. Choosing other values may affect convergence during training of the SGT.
- `shallow`: A fixed temperature-like parameter which alters the scale of the standard deviation of the noise when `directly_learn_stds` is `False`. Using values other than 1.0 is highly experimental and not recommended. Other values may affect convergence during training of the SGT.
- `seed`: Seed for the random number generator used to generate noise. For debugging, this can be any integer for reproducibility. For all other cases it should be `None`. Non-`None` values should not be considered secure.
- `rho_init`: When `directly_learn_stds` is `False`, this value controls the starting point of the parameterized optimization. A value of -3.0 is recommended in this case as it balances the distance to the min scale and the rate of convergence of the optimization. Other values in this case will either slow down the optimization or not respect the `min_scale` value. When `directly_learn_stds` is `True`, this argument must be 0.0; all other values will cause immediate divergence of the transform.
- `std_dropout` and `mean_dropout`: The SGT uses two transformer blocks in its internals, one responsible for the stochastic component of the transform and one responsible for the deterministic component. The dropout of the final layer of these blocks is exposed as a hyperparameter. It is recommended to keep this at 0.1. Other values could be used depending on the dropout of the base model for which the SGT is being learned. When the model is placed in PyTorch's eval mode, the dropout layers are deactivated, becoming identity layers instead.
- `directly_learn_stds`: There are two parameterizations of the SGT which may be selected depending on the value of this boolean variable. When `directly_learn_stds` is `False`, a parameterization of the transformation is learned which uses a hidden tensor of parameters, known as rho, to control the learned strength of the stochastic component of the transformation. When `directly_learn_stds` is `True`, no hidden variables are used and the SGT directly learns parameters which act on the embeddings. Both `True` and `False` can lead to good models; however, the values of `scale` and `rho_init` should be chosen appropriately in each case, as outlined above. It is recommended to set `directly_learn_stds` to `True` to begin with and only consider the other parameterization if it produces transformations which diverge immediately and changing the other hyper-parameters does not fix the divergence.
- `mean_num_experts` and `std_num_experts`: The SGT optionally supports the use of mixture-of-experts. It is recommended that these values are 0 unless you observe that you have more data, or data of wider modality, than the learning capability of the SGT can accommodate. Note that setting these values to positive integers requires DeepSpeed to be installed and significantly increases the VRAM usage of the model.
- `use_causal_mask`: By design, `TransformerCloak` uses the underlying structure of the decoder transformer to define the structure of the SGT. This results in causal masks being used when learning the SGT. While this keeps the semantic information of the SGT similar to that of the base model, non-causal masks in the internal model of the SGT are more natural when learning sequence-to-sequence maps. The SGT has excellent performance with `use_causal_mask` set to either `True` or `False`. It is recommended to start with `use_causal_mask` set to `True`. If the SGT seems to be learning too slowly, but larger learning rates cause divergences (even with LR scheduling), then setting `use_causal_mask` to `False` may allow for richer distributions to be learned.
- `**kwargs`: The `TransformerCloak` SGT uses Hugging Face Transformers models as constituent components. These variadic keyword arguments passed to the constructor of `TransformerCloak` are themselves passed on to the constructors of the transformer models comprising `TransformerCloak`.
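As an illustration of the scale recommendation above, the following is a minimal sketch (not part of SGE) of inspecting the value range of a Hugging Face causal LM's token embeddings. The model path is a placeholder, and the use of absolute values in the heuristic is an interpretation of the guidance above.

```python
from transformers import AutoModelForCausalLM

# Placeholder path; point this at the base model for which the SGT will be trained.
model = AutoModelForCausalLM.from_pretrained(
    "/models/huggingface/mistralai/Mistral-7B-Instruct-v0.2"
)

# The token embedding matrix has shape (vocab_size, hidden_size).
embedding_weights = model.get_input_embeddings().weight.detach()

print(f"min feature value: {embedding_weights.min().item():.3e}")
print(f"max feature value: {embedding_weights.max().item():.3e}")

# One reading of the heuristic above (an assumption): work with feature magnitudes.
abs_values = embedding_weights.abs()
suggested_min_scale = (abs_values[abs_values > 0].min() / 10.0).item()
suggested_max_scale = (abs_values.max() * 10.0).item()
print(f"suggested scale: ({suggested_min_scale:.3e}, {suggested_max_scale:.3e})")
```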
Attention Mechanisms¶
Attention computations in transformers are \(O(n^2)\) in memory and time-complexity by default when eager attention is used, where n is the sequence length of the input tokens. Significant improvements to the attention mechanism in transformers have occurred, of which Flash Attention 2 (and now Flash Attention 3) are the most significant during training. Flash Attention 2 and its successor techniques significantly reduce the memory occupancy and time-complexity of the attention computation. Moreover, they are not approximations; they are provably exact algorithms. They require that the model parameters are in reduced precision (bfloat16, for example). While using Flash Attention in the base model is highly encouraged, because the SGT is run in single precision, use of Flash Attention in the SGT layer is not currently possible.
Use of Flash Attention 2 is natively supported in both PyTorch and Hugging Face transformers. See Dao-AILab/flash-attention for more details.
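For example, in recent versions of Hugging Face transformers the base model (not the SGT) can be loaded with Flash Attention 2 as follows. The model path is a placeholder and the `flash-attn` package must be installed.

```python
import torch
from transformers import AutoModelForCausalLM

base_model = AutoModelForCausalLM.from_pretrained(
    "/models/huggingface/mistralai/Mistral-7B-Instruct-v0.2",  # placeholder path
    torch_dtype=torch.bfloat16,               # Flash Attention requires reduced precision
    attn_implementation="flash_attention_2",  # raises an error if flash-attn is not installed
)
```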
Dataset Recommendations¶
Training an SGT for an LLM is akin to pre-training a smaller LLM. Ideally, the data used for SGT training should be a subset of the training data used to train the LLM itself. As this is not always possible, a medium-sized text dataset which is in-distribution with the data the SGT will be used on at inference time is sufficient. At Protopia, we often use an aligned version of OpenOrca for training an SGT at scale. Direct use of OpenOrca itself is also recommended.
When setting up and testing the training loop prior to training the SGT, a smaller dataset is useful to decrease the training and experimentation time, reducing the time-to-value for Stained Glass Engine. Subsets of this large dataset can be used, or medium to small datasets such as Alpaca fill this need very well.
If Hugging Face Datasets are being used and cached, make sure to clear or delete the cache directory if updates were made to the format of the data or the data itself during pre-processing. Using previous caches can lead to subtle bugs which are difficult to track down.
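As a sketch of this workflow (the dataset name and formatting step are examples, not requirements), a small dataset such as Alpaca can be loaded with Hugging Face Datasets and its cache cleared after the pre-processing changes:

```python
from datasets import load_dataset

dataset = load_dataset("tatsu-lab/alpaca", split="train")  # example small dataset

def to_text(example):
    # Hypothetical formatting step; replace with the chat template of your base model.
    return {"text": example["instruction"] + "\n" + example["output"]}

processed = dataset.map(to_text)

# If the formatting above changes, clear the cached files so stale
# pre-processed data is not silently reused.
dataset.cleanup_cache_files()
```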
Dataloading Recommendations¶
Dataloading and collation of the dataset is necessary when using large datasets and training language models as the data itself will not fit into memory on the GPU. Moreover, it is dataloading which can often be the performance bottleneck when training deep neural nets. This can be mitigated by using sufficiently many dataloading processes (typically at least 4 dataloading processes per node), but the exact value is system and dataset dependent.
Collating is typically done for LLMs in one of two ways: padding to the maximum context length, or padding to the longest example in the batch. This is managed by the Stained Glass Transform Engine, which provides collators allowing for variable-length padding to the longest example in a batch. Variable-length batches can cause memory fragmentation during training, resulting in significant allocated but unusable memory. This can be mitigated in recent versions of PyTorch with the experimental environment variable `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True`.
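A minimal sketch of this setup is below. It assumes an already tokenized dataset and uses `DataCollatorWithPadding` as a generic stand-in for the collators provided by Stained Glass Engine.

```python
import os

from torch.utils.data import DataLoader
from transformers import AutoTokenizer, DataCollatorWithPadding

# Must be set before CUDA memory is allocated (or exported in the launch environment).
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

tokenizer = AutoTokenizer.from_pretrained(
    "/models/huggingface/mistralai/Mistral-7B-Instruct-v0.2"  # placeholder path
)
tokenizer.pad_token = tokenizer.eos_token

collator = DataCollatorWithPadding(tokenizer)  # pads each batch to its longest example

train_loader = DataLoader(
    tokenized_dataset,   # assumption: a dataset of tokenized examples defined elsewhere
    batch_size=8,
    shuffle=True,
    num_workers=4,       # at least 4 dataloading processes per node
    collate_fn=collator,
    pin_memory=True,
)
```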
Optimizer Recommendations¶
In general when training a SGT, it is good practice to use the same optimization settings that were used to train the base foundational model. For LLMs it is common practice to use AdamW. We recommend using the modern Transformer AdamW settings for training a Stained Glass Transform:
- `betas = [0.9, 0.95]` and `eps = 1e-5`: A good starting place for choosing a learning rate is the final learning rate used to train the base model for which the Stained Glass Transform is being learned. For Llama2, for example, this would be 3e-5. Llama2 also used a weight decay of 0.1 for all the model weights.
- `weight_decay`: We follow the recommendation that the weight decay hyperparameters are the same as the ones used during pre-training, with one critical exception: the weight decay for the final linear layers of the `std_estimator` and `mean_estimator` modules of the SGT must be 0.0. If non-zero values are used, this will significantly hamper the ability of the model to learn meaningful transformations quickly. A sketch of an optimizer configured this way follows the list.
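A minimal sketch of such an optimizer is below. The `std_estimator`/`mean_estimator` names follow this guide, but the exact parameter names of their final linear layers depend on the SGT, so the name matching here is an assumption to adapt.

```python
import torch

decay_params, no_decay_params = [], []
for name, param in noise_layer.named_parameters():  # `noise_layer` is the SGT module (assumed)
    # Assumption: identify the estimators' final linear layers by name; adapt to your SGT.
    if ("std_estimator" in name or "mean_estimator" in name) and "linear" in name:
        no_decay_params.append(param)
    else:
        decay_params.append(param)

optimizer = torch.optim.AdamW(
    [
        {"params": decay_params, "weight_decay": 0.1},     # e.g. the Llama2 pre-training value
        {"params": no_decay_params, "weight_decay": 0.0},  # required for the estimator output layers
    ],
    lr=3e-5,            # final pre-training learning rate of Llama2, per the recommendation above
    betas=(0.9, 0.95),
    eps=1e-5,
)
```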
The SGT Loss Function¶
Training a SGT with SGE for an LLM uses a three component loss function balancing utility, semantic difference, and stochasticity of the model. Let \(L_M\) be a loss modeling the quality of the base model performance (this can either be the language modeling loss, or a self-supervised loss), \(L_S\) be a loss modeling the semantic similarity between tensors, and \(L_N\) a loss modeling the stochasticity of the transformation. Further let \(\alpha \in [0,1]\) and \(\beta, \gamma, w \in \mathbb{R}^+\). Then the SGT loss is generically formed as
\((1-\alpha)wL_M + \alpha(\beta L_N + \gamma L_S)\)
The choices of \(\alpha, \beta, \gamma\) and \(w\) all dramatically affect the creation of the SGT.
In SGE we expose these hyperparameters under more descriptive names:
- \(\alpha\) is `alpha`
- \(\beta\) is `transform_variance_priority`
- \(\gamma\) is `transform_strength_priority`
- \(w\) is `weight`
We recommend the following choices as reasonable starting points for the optimization (see the sketch below): `alpha = 0.54`, `transform_variance_priority = 0.01`, `transform_strength_priority = 0.75`, and `weight = 12.0`.
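As a sketch of how these pieces combine (SGE composes the loss internally; the individual loss terms are assumed to be computed elsewhere):

```python
alpha = 0.54                         # utility vs. transform balance
transform_variance_priority = 0.01   # beta
transform_strength_priority = 0.75   # gamma
weight = 12.0                        # w

def sgt_loss(model_loss, noise_loss, semantic_loss):
    """(1 - alpha) * w * L_M + alpha * (beta * L_N + gamma * L_S)."""
    return (1.0 - alpha) * weight * model_loss + alpha * (
        transform_variance_priority * noise_loss
        + transform_strength_priority * semantic_loss
    )
```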
Metrics¶
Training a Stained Glass Transform requires tracking metrics, like training any other model. The metrics can be broken down into two primary categories: forward based metrics and generation based metrics.
Forward Based Metrics¶
Forward based metrics are defined as functions of the logits predicted by the language model. Examples of these are the losses computed during training and perplexity. These forward based metrics should be computed on both the training set and a validation set in the loop during training. When computing perplexity, it is recommended to set the ignore index so that padding tokens do not contribute to the perplexity score.
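A minimal sketch of perplexity with an ignore index is below. It assumes labels use -100 at padding positions (the Hugging Face convention) and that logits and labels are already shifted for next-token prediction.

```python
import torch
import torch.nn.functional as F

def perplexity(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # logits: (batch, seq_len, vocab_size); labels: (batch, seq_len)
    loss = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        labels.reshape(-1),
        ignore_index=-100,  # padding tokens do not contribute
    )
    return torch.exp(loss)
```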
Forward based test metrics also exist and are common parts of EleutherAI's evaluation harness. Examples of these metrics are multiple choice test problems such as the MMLU tasks. Forward based metrics are independent of the generation parameters, which only affect the form and quality of the generated text. Protopia AI currently offers adapters into EleutherAI's evaluation harness, which supports dozens of common language model tests. It is recommended that you choose test metrics that are relevant to the intended use of the LLM to evaluate the model's performance with the SGT. In general, the test metrics used to evaluate the model without the SGT should be a subset of the metrics used to test the model with the SGT.
Forward based metrics also contain the subcategory of obfuscation metrics. These metrics measure the ability of the SGT to transform text to a different text representation. We currently offer an API in SGE to compute the tokens in the vocabulary whose embeddings are closest to the transformed embeddings. We can then calculate the percentage of tokens which are transformed in a given sequence of tokens. We recommend looking at the mean, min, and quantiles of this distribution over a set of validation and test examples in order to understand the obfuscation performance of the transformation.
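The sketch below illustrates the idea behind this metric (it is not the SGE API): map each transformed embedding to its nearest vocabulary token and compute the fraction of positions whose nearest token differs from the original token.

```python
import torch

def percent_tokens_changed(transformed_embeddings, input_ids, embedding_matrix):
    # transformed_embeddings: (seq_len, hidden); input_ids: (seq_len,);
    # embedding_matrix: (vocab_size, hidden)
    distances = torch.cdist(transformed_embeddings, embedding_matrix)  # (seq_len, vocab_size)
    nearest_tokens = distances.argmin(dim=-1)
    return (nearest_tokens != input_ids).float().mean().item() * 100.0
```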
Generation Based Metrics¶
Generation based metrics are metrics which evaluate qualities of the tokens and/or text generated by language models. These metrics depend on the generation configuration settings and the sampling algorithm used; thus it is recommended to set these with care, in line with the settings used for the tasks and use cases of the base models being trained. Our specific recommendations for these settings are covered in a subsequent section. The metrics we recommend for measuring the quality of generation involve comparing the generated text to known results. These known results can be custom datasets of labels generated by the original base model without the Stained Glass Transform present.
The particular metrics for this case are the ROUGE metrics, which compute the precision, recall, and F-measures of overlapping N-grams of tokens between the candidate (transformed) generations and the ground truth labels. These metrics become less reliable when the generated text of the base model is not in distribution with the labels. In this case, BERTScores of the generated text, as well as automated grader pipelines in which a third LLM (such as ChatGPT or similar LLMs) grades the quality of the responses generated with and without the transformation present, are recommended.
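One possible way (among several) to compute ROUGE between generations produced with the transform and reference generations from the base model is the Hugging Face `evaluate` package:

```python
import evaluate  # requires the `rouge_score` package to be installed

rouge = evaluate.load("rouge")
scores = rouge.compute(
    predictions=transformed_generations,  # assumption: list[str] of generations with the SGT
    references=base_model_generations,    # assumption: list[str] of base model generations
)
print(scores)  # rouge1 / rouge2 / rougeL / rougeLsum (F-measures by default)
```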
Training Strategies¶
Using SGE to create an SGT is like training any other machine learning model. The first step is to create a Python application which trains a model and saves an artifact. When testing the functionality of the pipeline using SGE, it is recommended to use a small dataset of a few samples to demonstrate that the code runs without error. This same dataset can then be used to demonstrate overfitting of the model. Hyper-parameter sweeps may be useful to perform on a small to medium sized dataset on the order of 10,000 examples (such as Alpaca) to reduce the time spent during this phase of the machine learning lifecycle. When an optimal set of hyper-parameters is discovered, training should be performed on a larger dataset on the order of millions of examples.
When training on a larger dataset, convergence can often happen as early as epoch 5. Unfortunately, this can still take several days even when using significant amounts of hardware. As a result, it is recommended that frequent time-based checkpointing occur within an epoch to prevent significant losses if a system crash were to occur.
Moreover, it is also useful to have intra-epoch validation for these larger datasets, as the amount of validation data can be limited and the time until validation results first appear can otherwise be significant.
Other Hyperparameters¶
Gradient accumulation is a hyperparameter which affects both the convergence of a model at the optimization level and the wall-clock time for distributed models to converge to their final value. These objectives are not necessarily contradictory, so it is recommended to find a sweet spot that satisfies both of them. However, to gain the performance boosts of gradient accumulation, make sure your ML framework disables the communication of the gradient updates until the step that will actually perform the update. Otherwise, the performance boosts of gradient accumulation will not exist and only the theoretical balance between exploration and exploitation in the non-linear optimization will be affected. In the case of the SGT, since the transformation itself is stochastic, we find that we can use significant gradient accumulation (and hence batch sizes) without sacrificing too much in the exploration/exploitation tradeoff.
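A minimal sketch of this pattern with PyTorch `DistributedDataParallel` is below. It assumes `ddp_model`, `optimizer`, and `train_loader` already exist and that the model returns a Hugging Face style output with a `.loss` attribute.

```python
import contextlib

accumulation_steps = 8

for step, batch in enumerate(train_loader):
    is_update_step = (step + 1) % accumulation_steps == 0
    # Skip the gradient all-reduce on non-update steps so communication only
    # happens on the step that actually updates the weights.
    sync_context = contextlib.nullcontext() if is_update_step else ddp_model.no_sync()
    with sync_context:
        loss = ddp_model(**batch).loss / accumulation_steps
        loss.backward()
    if is_update_step:
        optimizer.step()
        optimizer.zero_grad()
```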
When training an SGT for LLMs, we have validated that training with the base model cast to bf16 is possible, but the SGT itself must always be in single precision (float32). Stained Glass Engine manages and asserts that the transformation, and the data which flows into it, are all in single precision. If the user attempts to downcast or quantize these layers after the fact, errors will be thrown at the time of the SGT computation.
Model Truncation during Self-Supervised Training¶
Self-supervised training is the recommended way to produce an SGT for an LLM. In this training paradigm, the model loss is replaced by a loss which is a function of the hidden state of an intermediate layer of the transformer acting as the language model. Activations beyond that intermediate layer do not contribute to the loss during SSL training and thus have no impact on the learned SGT, even though they would otherwise be computed in the forward call of the model.
SGE provides an API for truncating a model and unloading the unused layers from memory, `truncate_and_offload`, allowing for significant GPU memory savings by offloading unused model parameters to the CPU during training and removing them from the computational graph. This allows larger batch sizes to be used during training, accelerating the time to convergence when training an SGT.
The layer at which this truncation occurs, and hence at which the SSL loss is evaluated, is known as the truncation layer index and is itself a hyperparameter to be discovered. For 32-layer models such as Llama2 and Mistral 7B we recommend a truncation layer index of 12 (so there are 13 base model layers loaded into GPU memory).
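The sketch below illustrates the idea only, using plain Hugging Face outputs rather than SGE's `truncate_and_offload` (which additionally removes and offloads the unused layers): the loss is computed from the hidden state at the truncation layer, here as an assumed mean-squared-error between transformed and clean hidden states, which may differ from the loss SGE actually uses.

```python
import torch
import torch.nn.functional as F

TRUNCATION_LAYER_INDEX = 12  # recommendation for 32-layer models such as Llama2 / Mistral 7B

# `base_model`, `input_ids`, `attention_mask`, and `transformed_embeddings`
# (the SGT output) are assumed to exist.
with torch.no_grad():
    clean = base_model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        output_hidden_states=True,
    )

noisy = base_model(
    inputs_embeds=transformed_embeddings,
    attention_mask=attention_mask,
    output_hidden_states=True,
)

# hidden_states[0] is the embedding output, so the output of layer index k is
# hidden_states[k + 1] (this indexing convention is an assumption).
layer = TRUNCATION_LAYER_INDEX + 1
ssl_loss = F.mse_loss(noisy.hidden_states[layer], clean.hidden_states[layer])
```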
Histograms of standard deviations and other parameters¶
Experiment tracking solutions such as WandB allow for logging of hyperparameters to runs and experiments. They also allow for logging of more complex data such as images, text, audio, histograms, and more. We highly recommend logging histograms of the learned standard deviations and other SGT parameters, as well as text tables of generated text and reconstructions of transformed embeddings back to text, at every validation step in order to observe the convergence and the performance of the SGT.
Forward Validation vs. Generation Validation¶
When training an SGT for an LLM, validation can be broken down into two parts: forward validation and generation validation. The reason we need to distinguish between these cases is that training an SGT is akin to pre-training a language model. Pre-training is closer to a form of unsupervised training, since the labels of the training data are included in the training data itself. As a result, there is no need for a labels component at all when training an SGT, and the forward based metrics outlined above can be computed with the same dataloader used for training, with the validation dataset in place of the training dataset. The generation based metrics, however, do require ground truths to compare against. These ground truths can come in the form of labels or, more ideally, generations of the base model (without the Stained Glass Transform). These base model generations can be pre-computed or cached and should only need to be computed once. This necessitates a different dataloader which provides labels together with validation examples. It is recommended to have a two-stage validation process where both the forward based metrics and the generation metrics are computed on the validation dataset, each with their respective dataloaders.
Double Check What is Being Computed¶
Mistakes happen in dependent libraries, some of which can be critical. Some implementations of the language modeling loss for the Llama2 model have historically been incorrect due to missing shifts in the labels (mathematically, computing \(\mathbb{P}(x_n \mid x_n, \ldots, x_1)\) instead of \(\mathbb{P}(x_{n+1} \mid x_n, \ldots, x_1)\)). While it is not possible to validate all dependencies, for critical components like the loss function it is highly recommended to double-check the implementation (a sketch of the correctly shifted loss follows). Another example spot to check is the chat template being used, as this can dramatically affect model efficacy.
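For reference, a minimal sketch of the correctly shifted language modeling loss (predicting \(x_{n+1}\) from \(x_1, \ldots, x_n\)):

```python
import torch.nn.functional as F

def causal_lm_loss(logits, labels):
    # logits: (batch, seq_len, vocab_size); labels: (batch, seq_len)
    shift_logits = logits[:, :-1, :]  # predictions at positions 1..n-1
    shift_labels = labels[:, 1:]      # targets are the *next* tokens
    return F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
        ignore_index=-100,  # ignore padding positions
    )
```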
Tokenizer Settings¶
The tokenizer is responsible for taking strings and transforming them into sequences of integers. These integers are referred to as tokens. The set of tokens is referred to as the vocabulary, and the number of tokens is the vocabulary size. There are optional special tokens in the vocabulary which are used to encode special meaning for the language model. The most common of these optional special tokens are the beginning of string token, the end of string token, and the pad token. In practice these special tokens can all map to the same token in the vocabulary. For example, in Llama2 and Mistral it is convention to set the pad token to be the end of string token. Functionally, the choice of pad token does not matter, since the attention mask removes it from consideration during all computations.
Another important setting for the tokenizer is which side the padding occurs on. While padding can occur on either the right or the left during training, during generation Hugging Face requires padding to occur on the left. This is because in left-to-right written languages such as English it is easy to uniformly append all the generated tokens as a new column at the end of the tensor and drop the first column, while for right padding one would have to figure out where the end of the sentence is for each example. As a result, it is recommended that the tokenizer always set padding on the left when dealing with LLMs, so that the tokenizer is never misconfigured during training.
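A minimal sketch of this tokenizer configuration for a Llama2 or Mistral style model (the model path is a placeholder):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "/models/huggingface/mistralai/Mistral-7B-Instruct-v0.2"
)
tokenizer.pad_token = tokenizer.eos_token  # conventional for Llama2 / Mistral
tokenizer.padding_side = "left"            # required by Hugging Face generation
```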
Generation Settings¶
Generation settings dramatically impact the quality of the generated text. The first and most important setting to decide is the sampling algorithm itself. For all but the most basic debugging or forward based testing, it is recommended that a non-deterministic beam search with a single beam is chosen. In order to ensure that this algorithm is selected when using a Hugging Face generate call, ensure that `do_sample` is `True` and that `num_return_sequences` is 1.
The beam search algorithm itself is parameterized by several settings which affect the generation (a consolidated example follows this list):
- `max_new_tokens`/`max_tokens`: Defines the maximum number of tokens generated (or the maximum number of tokens in the original prompt plus generated tokens) for an example during the beam search. When this threshold is hit, the generation algorithm halts early. The generation will also end early if the end of string token is generated for the example. Generally you do not want to generate more tokens than the context length of your model allows.
- `top_p`: This parameter controls the amount of probabilistic mass that is used to sample subsequent tokens. Imagine that only 10 tokens contain 90% of the total probability of being the next token. Then a `top_p` value of 0.9 would ensure that only these 10 tokens would be selectable during sampling. Each model will recommend different values of `top_p` found to be successful. Currently we recommend a value of 0.9 for Llama2 and Mistral models. This is a natural starting recommendation; however, it is worth considering exposing this value for users to set themselves for their specific application.
- `top_k`: This parameter controls the size of the pool of tokens from which the sampling algorithm is allowed to draw. If `top_k` is 10, then only the 10 most likely tokens are considered for sampling to generate the next token. We recommend using a setting of 5000 for Llama2 and Mistral models as a starting point; however, this parameter should be exposed to users to tune for their own use cases and applications.
- `temperature`: The actual output of the language model is not conditional probabilities, but rather logits representing the conditional probabilities. The softmax function takes logits to probabilities, but may optionally be parameterized by a real-valued `temperature` parameter which spreads out or concentrates the probability distribution around its natural peaks and valleys. Temperature values less than 1.0 concentrate the probability distribution around its most likely predictions, making the generations more conservative and terse. Temperatures greater than 1.0 spread out the probabilistic mass, making tokens which were previously less likely now more likely. This can be good when creativity is preferred. Models which are meant to be factual should use low `temperature` values, with starting recommendations as low as 0.3 for models summarizing medical insurance data, for example. Some applications may even benefit from lower `temperature` than that. By default we set the `temperature` for Llama2 and Mistral models to 0.6; however, we recommend exposing this parameter to users as it is completely application dependent.
- `repetition_penalty`: Sampling algorithms risk getting into loops when probabilities are concentrated around a few tokens. The repetition penalty parameter in the beam search algorithm penalizes the sampling algorithm's likelihood of repeating itself, allowing for more natural responses. By default we recommend using a value of 1.0.
- `renormalize_logits`: The above arguments transform the probability distribution in such a way that the transformed logits no longer necessarily form a probability distribution. As such, it is necessary to renormalize the logits so that, after the settings above transform the distribution, the resulting data does not cause the sampling algorithms to break down or sample tokens incorrectly. The `renormalize_logits` parameter should always be set to `True` to ensure the results always form probability distributions. For historical reasons this parameter does not default to `True` in Hugging Face, so it must be manually enabled.
- `pad_token_id`/`eos_token_id`/`bos_token_id`: A priori, the sampling algorithm does not know what the token ids for the pad, end of string, and beginning of string tokens are. Passing these ids to the generate call, especially the end of string token, is necessary in order to guarantee that the generations terminate correctly and are batched correctly together. Failing to set the `eos_token_id` in particular can cause the model to never stop generating, producing text which trails off into nonsense.
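Putting the recommendations above together, a minimal sketch of a Hugging Face `generate` call might look as follows; `model`, `tokenizer`, `input_ids`, and `attention_mask` are assumed to exist, and `max_new_tokens` is an example value to adapt to your context length.

```python
outputs = model.generate(
    input_ids=input_ids,
    attention_mask=attention_mask,
    do_sample=True,                 # non-deterministic sampling
    num_return_sequences=1,         # a single returned sequence
    max_new_tokens=512,             # example value; keep within the model's context length
    top_p=0.9,
    top_k=5000,
    temperature=0.6,
    repetition_penalty=1.0,
    renormalize_logits=True,        # not the Hugging Face default; enable explicitly
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
)
texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
```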
Things managed by SGE¶
Stained Glass Engine manages the handling and flow of data to and from the SGT, as well as the implementation details of the modifications to the optimization procedure. In particular, once Stained Glass Engine and the SGT are instantiated and attached to the LLM, the transform is called on the embeddings every time the base model is called.