Deploying Stained Glass Transform for Text with ExecuTorch 📱¶
This tutorial demonstrates how to export a pretrained Stained Glass Transform for Text (SGT4T)
to ExecuTorch, producing a quantized .pte file that
can run on mobile platforms using the XNNPACK backend
delegate for accelerated CPU inference and reduced model size.
The notebook walks through the following steps:
- Load & run the SGT4T: Load a pretrained SGT4T from the Hugging Face Hub and run example inference on CPU.
- Prepare for export: Wrap the SGT4T in ExportableStainedGlassTransformForText and inspect model metadata.
- Quantize and lower with XNNPACK: Apply Linear quantization via stainedglass_core.integrations.torchao, export with torch.export, and lower with the XNNPACK backend delegate.
- Save, load, and execute: Serialize to a .pte file, load into the ExecuTorch runtime, and run forward passes.
Prerequisites:
- stainedglass_core with executorch and torchao integration extras.
- SGT: Protopia AI SGT for Llama-3.1 8B Instruct
import gc
import shutil
import tempfile
from pathlib import Path
import torch
from executorch import exir, runtime
from executorch.backends.xnnpack import partition
from stainedglass_core import transform as sg_transform
from stainedglass_core.integrations import (
    executorch as sg_executorch,
    torchao as sg_torchao,
)
1. Load and run the SGT4T for Inference on CPU¶
MODEL_ID = "Protopia/SGT-for-Llama-3.1-8B-Instruct-stoic-firefly-4"
DEVICE = torch.device("cpu")
Load the pretrained SGT4T from the Hugging Face Hub.¶
sgt4t = (
    sg_transform.StainedGlassTransformForText.from_pretrained(
        MODEL_ID, noise_layer_attention="sdpa"
    )
    .to(DEVICE)
    .eval()
)
Fetching 11 files: 100%|██████████| 11/11 [00:00<00:00, 37755.60it/s]
Run inference¶
The SGT4T accepts chat-formatted input and returns transformed embeddings. These embeddings are what the client would send to the model provider instead of raw text.
💡 The Stained Glass Transform protects the input while preserving the model’s ability to generate accurate responses.
messages = [
    {
        "role": "user",
        "content": "The Stained Glass Transform protects the input while preserving the model's ability to generate accurate responses.",
    }
]
with torch.no_grad():
    transformed_embeddings = sgt4t(messages)
print(f"Output shape: {transformed_embeddings.shape}")
print(f"Output dtype: {transformed_embeddings.dtype}")
print(f"Sample embedding values: {transformed_embeddings}")
Output shape: torch.Size([1, 54, 4096])
Output dtype: torch.bfloat16
Sample embedding values: tensor([[[ 2.6512e-04, -4.9973e-04, -5.8365e-04, ..., 3.8147e-03,
6.3419e-05, 1.1902e-03],
[-1.6499e-04, -2.4319e-04, 1.6403e-04, ..., -1.5163e-04,
3.5095e-04, 7.3242e-04],
[ 8.1787e-03, 2.5749e-04, -3.4637e-03, ..., 5.3024e-04,
5.6152e-03, 1.6113e-02],
...,
[-9.7656e-03, -3.4637e-03, 1.8616e-03, ..., -7.1411e-03,
-4.3030e-03, 8.6060e-03],
[-4.6158e-04, -3.9291e-04, -6.5863e-06, ..., -6.2561e-04,
-5.0354e-04, 6.6757e-04],
[-2.8687e-03, 3.8910e-03, -1.7357e-04, ..., 8.0872e-04,
5.0354e-04, 2.3041e-03]]], dtype=torch.bfloat16)
2. Prepare for Export¶
To deploy an SGT4T to the ExecuTorch runtime we need to convert it into a
torch.export.ExportedProgram.
The ExportableStainedGlassTransformForText wrapper handles this.
Wrap the SGT4T for export¶
exportable = sg_executorch.ExportableStainedGlassTransformForText(sgt4t)
Inspect model metadata¶
ExportableStainedGlassTransformForText produces a metadata dictionary
that gets embedded into the .pte file as constant methods. This includes
base model architecture info as well as SGT-specific entries.
metadata = exportable.get_metadata()
metadata
WARNING:stainedglass_core.integrations.executorch.exportable_text_sgt:The default max dynamic sequence length derived from the model config is 131072. ExecuTorch's runtime uses upper-bound memory planning, which pre-allocates buffers for the maximum sequence length at load time. Large values (e.g. 131072 for Llama 3.1) can cause the runtime to OOM. Consider passing a smaller `dynamic_shapes` with a `torch.export.Dim` capped to your actual deployment needs, e.g.:
seq_dim = torch.export.Dim('seq_len', min=2, max=2048)
dynamic_shapes = {'input_ids': {1: seq_dim}, 'noise_mask': {1: seq_dim}, 'attention_mask': {1: seq_dim}}
exportable.export(dynamic_shapes=dynamic_shapes)
`torch_dtype` is deprecated! Use `dtype` instead!
{'get_dtype': 6,
'get_bos_id': 128000,
'get_eos_id': [128001, 128008, 128009],
'get_head_dim': 128.0,
'get_n_kv_heads': 8,
'get_n_layers': 1,
'get_vocab_size': 128256,
'get_max_batch_size': 1,
'get_max_seq_len': 131072,
'use_sdpa_with_kv_cache': False,
'enable_dynamic_shape': True,
'get_stainedglass_core_version': '3.23.0',
'get_hidden_size': 4096,
'get_sgt_name': 'stoic-firefly-4',
'get_model_name': 'Stained Glass Transform (stoic-firefly-4) for Unknown Base Model'}
SGT model size vs. full base model¶
The SGT contains only the noise layers, not the full base LLM weights.
💡 This is a key advantage of deploying the SGT client on-device: the client never needs the full LLM weights.
def param_size_gb(module: torch.nn.Module) -> float:
    """Calculate the total size of a module's parameters in GB."""
    return sum(p.numel() * p.element_size() for p in module.parameters()) / 1e9
full_base_model_params = 8.03e9  # Llama-3.1-8B-Instruct parameter count
full_base_model_bytes_per_param = 2  # bfloat16
full_base_model_size = (
    full_base_model_params * full_base_model_bytes_per_param / 1e9
)
exportable_size = param_size_gb(exportable)
exportable_params = sum(p.numel() for p in exportable.parameters())
print(
    f"Full Llama 3.1 8B Instruct: {full_base_model_size:.2f} GB ({full_base_model_params / 1e9:.1f}B params)"
)
print(
    f"Exported SGT: {exportable_size:.2f} GB ({exportable_params / 1e6:.0f}M params)"
)
print(
    f"Ratio: {exportable_size / full_base_model_size:.1%} of the full base model"
)
Full Llama 3.1 8B Instruct: 16.06 GB (8.0B params)
Exported SGT: 1.99 GB (995M params)
Ratio: 12.4% of the full base model
3. Quantize and Lower with Delegation to XNNPACK Backend¶
XNNPACK is a highly optimized library of neural network inference operators for ARM, x86, and WebAssembly.
We follow the approach used by
optimum-executorch:
Quantize the eager nn.Module before export, then export and lower the already-quantized model.
The stainedglass_core.integrations.torchao module provides quantize_module_ and
build_linear_config for applying XNNPACK-compatible quantization.
Note: Embedding quantization is skipped because the XNNPACK partitioner in ExecuTorch 1.2 does not include a delegation config for quantized_decomposed::embedding_byte.
The key steps are:
- Convert to float32: XNNPACK supports float32 and quantized INT8 ops, but not bfloat16. The SGT model is loaded in bfloat16, so we convert to float32 to maximize XNNPACK delegation. Without this, only the quantized layers would be delegated, while all remaining ops (noise formula, residual adds, etc.) would fall back to the portable runtime.
- Quantize Linear layers with sg_torchao.quantize_module_: INT8 dynamic activation + INT4 weight with PerGroup(32) granularity.
- Export the quantized model with torch.export.
- Lower with XNNPACK using to_edge_transform_and_lower with quantization fusion enabled.
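To make the "INT4 weight with PerGroup(32)" step more concrete, here is a minimal pure-Python sketch of symmetric per-group 4-bit weight quantization. It is an illustration of the general idea only and is not how torchao or sg_torchao implement it. The helper names (quantize_group, quantize_row, dequantize_row) and the choice of a symmetric scale of max|w| / 7 are assumptions made for this example.

```python
# Illustrative sketch only: symmetric per-group INT4 weight quantization.
# NOT torchao's implementation; helper names and the scale rule are assumptions.

GROUP_SIZE = 32     # mirrors the PerGroup(32) granularity named above
QMIN, QMAX = -8, 7  # signed 4-bit integer range


def quantize_group(weights: list[float]) -> tuple[list[int], float]:
    """Quantize one group of weights to INT4 codes with a single shared scale."""
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / QMAX if max_abs > 0 else 1.0
    codes = [min(QMAX, max(QMIN, round(w / scale))) for w in weights]
    return codes, scale


def quantize_row(
    row: list[float], group_size: int = GROUP_SIZE
) -> list[tuple[list[int], float]]:
    """Split a weight row into fixed-size groups and quantize each one."""
    if len(row) % group_size != 0:
        raise ValueError("row length must be a multiple of group_size")
    return [
        quantize_group(row[i : i + group_size])
        for i in range(0, len(row), group_size)
    ]


def dequantize_row(groups: list[tuple[list[int], float]]) -> list[float]:
    """Reconstruct approximate float weights from (codes, scale) pairs."""
    return [code * scale for codes, scale in groups for code in codes]


# Toy 64-element row: the second group has much larger magnitudes, so it gets
# its own larger scale instead of degrading the first group's precision.
row = [0.01 * ((i % 7) - 3) for i in range(32)] + [
    0.5 * ((i % 5) - 2) for i in range(32)
]
groups = quantize_row(row)
restored = dequantize_row(groups)
max_error = max(abs(a - b) for a, b in zip(row, restored))
print(f"scales={[scale for _, scale in groups]}, max_error={max_error:.6f}")
```

Because each group of 32 weights carries its own scale, the rounding error of any weight is bounded by half of that group's scale, which is why smaller groups tend to preserve accuracy at the cost of storing more scales.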
Export settings¶
The ExportableStainedGlassTransformForText.forward method accepts three tensors:
| Argument | Shape | Dtype | Description |
|---|---|---|---|
| input_ids | (batch, seq_len) | long | Token IDs |
| noise_mask | (batch, seq_len) | bool | True = apply noise, False = keep original |
| attention_mask | (batch, seq_len) | long | Transformer attention mask |
The export() method makes the sequence-length dimension dynamic by default,
so the exported program accepts inputs of varying length at runtime. The batch
dimension is fixed at 1 (typical for per-request inference on edge).
We customize two export settings:
- dynamic_shapes: By default, the max dynamic sequence length is derived from the model config (max_position_embeddings=131072 for Llama 3.1). ExecuTorch's portable runtime uses upper-bound memory planning and pre-allocates buffers for the maximum sequence length at program load time. Leaving the default could cause a runtime OOM error, so we cap max to a value large enough for the demo while keeping memory modest (e.g. 1024). Production deployments should pick a max that covers their expected input sizes; larger values linearly increase the activation buffers ExecuTorch pre-allocates at load_program time. See: https://docs.pytorch.org/executorch/main/using-executorch-export.html#supporting-varying-input-sizes-dynamic-shapes
- prefer_deferred_runtime_asserts_over_guards: Llama 3.1's RoPE scaling produces a symbolic guard that torch.export cannot statically prove. This flag converts it into a lightweight runtime assertion that is still checked with the concrete seq_len at execution time.
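The linear relationship between the max sequence length and pre-allocated memory can be sanity-checked with simple arithmetic. The sketch below computes the size of a single batch-1 float32 buffer of shape (1, max_seq_len, 4096), using the hidden size reported in the metadata. The helper name buffer_bytes is hypothetical, and the real memory plan holds many intermediate buffers, so this is a per-tensor lower bound rather than a prediction of total runtime memory.

```python
# Back-of-the-envelope sketch: size of ONE (1, max_seq_len, hidden) float32
# buffer. The actual ExecuTorch memory plan contains many such buffers, so
# this is a per-tensor figure, not the total.

HIDDEN_SIZE = 4096       # 'get_hidden_size' in the metadata above
BYTES_PER_FLOAT32 = 4


def buffer_bytes(max_seq_len: int, hidden: int = HIDDEN_SIZE) -> int:
    """Bytes needed for one batch-1 activation at the sequence-length cap."""
    return 1 * max_seq_len * hidden * BYTES_PER_FLOAT32


for max_seq_len in (1024, 2048, 131072):
    mib = buffer_bytes(max_seq_len) / 2**20
    print(f"max={max_seq_len:>6d} -> {mib:8.1f} MiB per buffer")
```

At the demo cap of 1024 each such buffer takes 16 MiB, while the default of 131072 would take 2 GiB per buffer, a 128x increase.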
MAX_SEQ_LEN = 1024
# Convert to float32: XNNPACK supports float32 and quantized INT8 ops but not
# bfloat16. Without this conversion, only the quantized layers would be delegated
# to XNNPACK while all remaining ops fall back to the portable runtime.
exportable = exportable.float()
exportable._noise_layer_dtype = torch.float32
# Quantize Linear layers: INT8 dynamic activation + INT4 weight, PerGroup(32).
# Embedding quantization is skipped: the XNNPACK partitioner in ExecuTorch 1.2
# does not support delegating quantized_decomposed::embedding_byte, and the
# portable runtime lacks this op as a fallback.
sg_torchao.quantize_module_(
    exportable,
    embedding_config=None,
    linear_config=sg_torchao.build_linear_config("int4"),
)
print("Quantized Linear layers (8da4w)")
# Export the quantized model with dynamic sequence length.
seq_dim = torch.export.Dim("seq_len", max=MAX_SEQ_LEN)
dynamic_shapes = {
    "input_ids": {1: seq_dim},
    "noise_mask": {1: seq_dim},
    "attention_mask": {1: seq_dim},
}
exported_program = exportable.export(
    dynamic_shapes=dynamic_shapes,
    prefer_deferred_runtime_asserts_over_guards=True,
)
print(f"\u2705 Export OK: {type(exported_program).__name__}")
Quantized Linear layers (8da4w)
✅ Export OK: ExportedProgram
Lower with XNNPACK¶
Note on _check_ir_validity=False: Hugging Face transformers constructs causal attention masks using torch.vmap (in masking_utils._vmap_for_bhqkv). This traces into ops that are not part of ExecuTorch's edge IR op set. The edge IR verifier rejects these ops, but they execute correctly on the portable runtime. We disable the validity check as a workaround.
edge_program = exir.to_edge_transform_and_lower(
    exported_program,
    partitioner=[partition.xnnpack_partitioner.XnnpackPartitioner()],
    constant_methods=metadata,
    compile_config=exir.EdgeCompileConfig(
        _check_ir_validity=False,
        _skip_dim_order=True,
    ),
)
et_program = edge_program.to_executorch(
    config=exir.ExecutorchBackendConfig(
        extract_delegate_segments=True,
        memory_planning_pass=exir.passes.MemoryPlanningPass(
            alloc_graph_input=False
        ),
        do_quant_fusion_and_const_prop=True,
    )
)
4. Save, Load, and Execute¶
Save the quantized XNNPACK program as a .pte file, load it into the ExecuTorch
runtime, and run forward passes at different sequence lengths.
Save the .pte file¶
pte_dir = tempfile.mkdtemp()
pte_path = Path(pte_dir) / "sgt4t_xnnpack.pte"
et_program.save(str(pte_path))
pte_size_mb = pte_path.stat().st_size / 1e6
print(f"XNNPACK .pte: {pte_size_mb:.1f} MB ({pte_path.name})")
XNNPACK .pte: 2365.8 MB (sgt4t_xnnpack.pte)
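The .pte file (2365.8 MB) is larger than the 1.99 GB bfloat16 parameter size reported earlier. One plausible contributor is the embedding table, which was converted to float32 and left unquantized. The rough estimate below assumes the SGT holds a single embedding table of shape (vocab_size, hidden_size). That layout is an assumption for illustration, not something confirmed by the metadata.

```python
# Rough estimate, assuming a single unquantized float32 embedding table of
# shape (vocab_size, hidden_size). The table layout is an assumption.

VOCAB_SIZE = 128256    # 'get_vocab_size' in the metadata
HIDDEN_SIZE = 4096     # 'get_hidden_size' in the metadata
PTE_SIZE_MB = 2365.8   # size printed by the save step above

embedding_params = VOCAB_SIZE * HIDDEN_SIZE
embedding_mb_fp32 = embedding_params * 4 / 1e6   # 4 bytes per float32
remaining_mb = PTE_SIZE_MB - embedding_mb_fp32

print(f"Embedding table: {embedding_params / 1e6:.0f}M params")
print(f"  as float32:    {embedding_mb_fp32:.1f} MB")
print(f"Everything else: {remaining_mb:.1f} MB")
```

Under that assumption, the float32 embedding table alone would account for most of the file, with the INT4-quantized Linear layers and other data making up the remainder.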
Load and execute with the ExecuTorch runtime¶
Free the Python-side model objects to reduce peak memory usage before loading
the .pte file. The runtime holds its own copy of the weights.
del et_program, edge_program, exported_program, exportable, sgt4t
gc.collect()
rt = runtime.Runtime.get()
program = rt.load_program(pte_path)
method = program.load_method("forward")
print("Loaded .pte into ExecuTorch runtime.")
[program.cpp:162] InternalConsistency verification requested but not available
[cpuinfo_utils.cpp:71] Reading file /sys/devices/soc0/image_version
[cpuinfo_utils.cpp:87] Failed to open midr file /sys/devices/soc0/image_version
[cpuinfo_utils.cpp:100] Reading file /sys/devices/system/cpu/cpu0/regs/identification/midr_el1
[cpuinfo_utils.cpp:109] Failed to open midr file /sys/devices/system/cpu/cpu0/regs/identification/midr_el1
[cpuinfo_utils.cpp:125] CPU info and manual query on # of cpus dont match.
Loaded .pte into ExecuTorch runtime.
def make_inputs(seq_len: int, vocab_size: int) -> list[torch.Tensor]:
    """Generate dummy inputs for testing the exported programs."""
    return [
        torch.randint(0, vocab_size, (1, seq_len)),
        torch.ones(1, seq_len, dtype=torch.bool),
        torch.ones(1, seq_len, dtype=torch.long),
    ]
VOCAB_SIZE = 128256  # Llama 3.1 vocab size
print("Running forward passes at different sequence lengths:\n")
for seq_len in [2, 8, 32, 128, 512]:
    input_ids, noise_mask, attention_mask = make_inputs(seq_len, VOCAB_SIZE)
    output = method.execute([input_ids, noise_mask, attention_mask])
    print(
        f" seq_len={seq_len:>4d} -> output shape={output[0].shape}, dtype={output[0].dtype}"
    )
Running forward passes at different sequence lengths:

 seq_len=   2 -> output shape=torch.Size([1, 2, 4096]), dtype=torch.float32
 seq_len=   8 -> output shape=torch.Size([1, 8, 4096]), dtype=torch.float32
 seq_len=  32 -> output shape=torch.Size([1, 32, 4096]), dtype=torch.float32
 seq_len= 128 -> output shape=torch.Size([1, 128, 4096]), dtype=torch.float32
 seq_len= 512 -> output shape=torch.Size([1, 512, 4096]), dtype=torch.float32
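Because the program was exported with a capped sequence-length dimension, inputs outside that range are expected to be rejected at execution time. A small guard on the caller's side can surface a clearer error before the runtime is invoked. The sketch below is a hypothetical helper, not part of stainedglass_core or ExecuTorch. It assumes a lower bound of 2 (the smallest length exercised above, and the min used in the library warning's suggested Dim) and the upper bound of 1024 chosen for this export.

```python
# Hypothetical caller-side guard; not part of stainedglass_core or ExecuTorch.
MIN_SEQ_LEN = 2      # assumption: smallest length exercised in the loop above
MAX_SEQ_LEN = 1024   # the cap passed to torch.export.Dim at export time


def check_seq_len(
    seq_len: int, min_len: int = MIN_SEQ_LEN, max_len: int = MAX_SEQ_LEN
) -> int:
    """Return seq_len unchanged, or raise if it falls outside the export range."""
    if not min_len <= seq_len <= max_len:
        raise ValueError(
            f"seq_len={seq_len} is outside the exported range "
            f"[{min_len}, {max_len}]; re-export with a larger max or shorten the input."
        )
    return seq_len


for candidate in (2, 512, 1024, 1025):
    try:
        check_seq_len(candidate)
        print(f"seq_len={candidate}: ok")
    except ValueError as err:
        print(f"seq_len={candidate}: rejected ({err})")
```

Calling such a check before method.execute keeps the failure mode in Python, where the message can point the user at the export-time cap.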
Clean up¶
del method, program, rt
shutil.rmtree(pte_dir)
print("Cleaned up .pte files and runtime.")
Cleaned up .pte files and runtime.