exportable_text_sgt
A torch.export-compliant Stained Glass Transform For Text client for use with ExecuTorch runtimes.
This module provides an export-compatible wrapper around StainedGlassTransformForText that
bypasses the non-exportable hook machinery in BaseNoiseLayer.__call__ and instead calls the
estimator submodules directly, inlining the noise formula and noise-mask application as pure tensor operations.
Note
The wrapper requires that the noise layer does not use a
PercentMasker (or any masker that relies on
data-dependent shapes like masked_select), because those operations cannot be
traced by torch.export. Most production SGT4T models use percent_to_mask=None.
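To make the data-dependent-shape restriction concrete, here is a minimal sketch. It is not taken from the library, and the tensors are purely illustrative. It contrasts torch.masked_select, whose output length depends on how many mask entries are true, with torch.where, which always returns a tensor of the input's shape.

```python
import torch

x = torch.arange(6.0).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]
mask = x > 2                         # True for the entries 3, 4 and 5

# The output length equals the number of True entries in the mask, so the
# shape is only known once the mask *values* are known.
selected = torch.masked_select(x, mask)

# The output has the same shape as the input regardless of the mask values,
# which is the property a static trace needs.
kept = torch.where(mask, x, torch.zeros_like(x))
```

Whether a given masker can be rewritten in the second style depends on its semantics; the sketch only shows why the first style is hard to trace.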
Note
Only SDPA-family attention implementations are compatible with torch.export. The wrapper
defaults to standard SDPA but can optionally use the optimum-executorch custom SDPA (a CPU Flash Attention implementation) by passing use_custom_sdpa=True.
See: https://github.com/pytorch/executorch/blob/a4322c71c3a97e79e0454a8223db214b010f1193/extension/llm/README.md?plain=1#L40
Note
When lowering to ExecuTorch via to_edge(), you must pass
EdgeCompileConfig(_check_ir_validity=False). HuggingFace transformers constructs
causal attention masks using torch.vmap (via masking_utils._vmap_for_bhqkv),
which traces into _vmap_increment_nesting / _vmap_decrement_nesting ops that
are not part of ExecuTorch's edge IR dialect. These ops execute correctly on the
portable runtime despite failing the strict edge IR validation.
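As a rough sketch of the lowering step described in this note, the helper below passes the relaxed compile config to to_edge(). The function name lower_to_edge is hypothetical, and the sketch assumes the executorch package is installed; the import is done lazily so the definition itself has no dependencies.

```python
def lower_to_edge(exported_program, constant_methods=None):
    """Lower an ExportedProgram to the edge dialect without IR validation.

    `lower_to_edge` is an illustrative helper, not part of any library.
    """
    # Lazy import: only needed when the helper is actually called.
    from executorch.exir import EdgeCompileConfig, to_edge

    return to_edge(
        exported_program,
        constant_methods=constant_methods,
        # Skip the strict edge IR check that rejects the vmap nesting ops.
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
    )
```

The optional constant_methods argument is where a metadata dictionary such as the one returned by get_metadata would typically be passed.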
Note
ExecuTorch's portable runtime uses upper-bound memory planning: it pre-allocates
buffers sized for the maximum dynamic sequence length at program load time. When
export() is called without explicit dynamic_shapes, the max is derived from the
model config (e.g. max_position_embeddings=131072 for Llama 3.1), which can cause
the runtime to OOM even on systems with ample RAM. Pass a dynamic_shapes argument
with a torch.export.Dim capped to your actual deployment needs to avoid this.
See: https://docs.pytorch.org/executorch/main/using-executorch-export.html#supporting-varying-input-sizes-dynamic-shapes
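One way to build such a capped specification is sketched below. The helper names are hypothetical, the input names are taken from the forward signature, and the sequence axis is assumed to be axis 1 (a batch-first layout), which this page does not confirm.

```python
from typing import Any

# Input names as they appear in the wrapper's forward signature.
INPUT_NAMES = ("input_ids", "noise_mask", "attention_mask")


def build_dynamic_shapes(seq_dim: Any, seq_axis: int = 1) -> dict[str, dict[int, Any]]:
    """Mark the same sequence axis as dynamic on every input.

    `seq_axis=1` is an assumption about the tensor layout.
    """
    return {name: {seq_axis: seq_dim} for name in INPUT_NAMES}


def capped_seq_dim(max_len: int) -> Any:
    """Create a torch.export.Dim whose upper bound is `max_len`."""
    # Lazy import so build_dynamic_shapes stays dependency-free.
    from torch.export import Dim

    return Dim("seq_len", max=max_len)
```

With these helpers, a call such as exportable.export(dynamic_shapes=build_dynamic_shapes(capped_seq_dim(2048))) would cap the planned sequence length at 2048 tokens; the value 2048 is only an example.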
Example
    >>> from stainedglass_core import transform as sg_transform
    >>> from stainedglass_core.integrations import executorch as sg_executorch
    >>> sgt4t = sg_transform.StainedGlassTransformForText.from_pretrained(
    ...     "MODEL_ID", noise_layer_attention="sdpa"
    ... ).eval()  # doctest: +SKIP
    >>> exportable = sg_executorch.ExportableStainedGlassTransformForText(
    ...     sgt4t
    ... )  # doctest: +SKIP
    >>> exported_program = exportable.export()  # doctest: +SKIP
Classes:
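| Name | Description |
|---|---|
| ExportableStainedGlassTransformForText | Torch export compliant wrapper for creating protected input embeddings using Stained Glass Transform For Text. |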
ExportableStainedGlassTransformForText
Bases: Module
Torch export compliant wrapper for creating protected input embeddings using Stained Glass Transform For Text.
This wrapper bypasses BaseNoiseLayer.__call__ which uses non-exportable hook
context managers (RNG forking, precision casting, runtime hook registration), and instead calls the mean/std estimators directly via
the standard nn.Module.__call__ path, then applies the affine SGT formula and noise-mask as inline tensor operations.
The exported graph contains:
- The token embedding table from the base model.
- The mean estimator transformer.
- The std estimator transformer.
Attributes:
| Name | Type | Description |
|---|---|---|
| embed_tokens | Module | Token embedding layer from the base model. |
| mean_estimator | Estimator | Estimator that computes the mean shift component. |
| std_estimator | Estimator | Estimator that computes the standard-deviation component. |
Warning
This API is experimental and subject to change. The SGT4T exportable wrapper for ExecuTorch may change in future releases.
Added in version v3.25.0. Added ExportableStainedGlassTransformForText wrapper for SGT4T support with torch.export and ExecuTorch runtimes.
Methods:
| Name | Description |
|---|---|
| __init__ | Initialize the ExportableStainedGlassTransformForText. |
| export | Export the module as a torch.export.ExportedProgram. |
| forward | Apply Stained Glass Transform to the embeddings of the input token IDs. |
| get_metadata | Build a metadata dictionary for ExecuTorch constant_methods. |
__init__
__init__(
sgt4t: StainedGlassTransformForText,
*,
use_custom_sdpa: bool = False
) -> None
Initialize the ExportableStainedGlassTransformForText.
Parameters:
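| Name | Type | Description | Default |
|---|---|---|---|
| sgt4t | StainedGlassTransformForText | A pretrained StainedGlassTransformForText instance. | required |
| use_custom_sdpa | bool | If True, use the optimum-executorch custom SDPA (a CPU Flash Attention implementation) instead of standard SDPA. | False |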
Raises:
| Type | Description |
|---|---|
| NotImplementedError | If the noise layer's std estimator uses a masker, which is not compatible with torch.export. |
export
export(
example_input_ids: Tensor | None = None,
example_noise_mask: Tensor | None = None,
example_attention_mask: Tensor | None = None,
dynamic_shapes: dict[str, dict[int, Any]] | None = None,
strict: bool = False,
prefer_deferred_runtime_asserts_over_guards: bool = False,
) -> torch.export.ExportedProgram
Export the module as a torch.export.ExportedProgram.
When called without arguments, deterministic example inputs are built from the
model config, and a single shared torch.export.Dim is used to mark the sequence-length
axis as dynamic across all inputs.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| example_input_ids | Tensor \| None | Example tensor of shape … | None |
| example_noise_mask | Tensor \| None | Example boolean tensor of shape … | None |
| example_attention_mask | Tensor \| None | Example tensor of shape … | None |
| dynamic_shapes | dict[str, dict[int, Any]] \| None | Dynamic shape specification for … | None |
| strict | bool | Whether to use strict export mode. Defaults to False. | False |
| prefer_deferred_runtime_asserts_over_guards | bool | If … This is required for models whose RoPE implementation produces guards that the symbolic solver cannot algebraically verify. For example, Llama 3.1 RoPE scaling generates the guard … | False |
Returns:
| Type | Description |
|---|---|
| torch.export.ExportedProgram | The exported program ready to be lowered to ExecuTorch. |
forward
forward(
input_ids: Tensor,
noise_mask: Tensor,
attention_mask: Tensor,
) -> torch.Tensor
Apply Stained Glass Transform to the embeddings of the input token IDs.
This method replicates the computation of BaseNoiseLayer.forward and the noise-mask
application hook, but using only export-compatible tensor operations.
Noise is sampled internally via torch.randn, which is supported by torch.export
and ExecuTorch's portable runtime (aten::randn.out).
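The following is a minimal sketch of what "affine formula plus noise mask as tensor operations" could look like. It is not the library's implementation: the function name apply_sgt_sketch is hypothetical, the affine form embeds + mean + std * noise is an assumption, and mean and std stand in for the estimator outputs.

```python
import torch


def apply_sgt_sketch(embeds, mean, std, noise_mask):
    """Illustrative only; the real formula may differ.

    embeds, mean, std: float tensors of shape (batch, seq, hidden).
    noise_mask: bool tensor of shape (batch, seq); True marks tokens to transform.
    """
    # Sample standard normal noise with the same shape as the embeddings.
    noise = torch.randn(embeds.shape, dtype=embeds.dtype, device=embeds.device)
    # Assumed affine form: shift by the mean, scale the noise by the std.
    transformed = embeds + mean + std * noise
    # Keep the original embedding wherever the mask is False. torch.where
    # preserves the input shape, so no data-dependent shapes are introduced.
    return torch.where(noise_mask.unsqueeze(-1), transformed, embeds)
```

Using torch.where rather than boolean indexing keeps every intermediate tensor at a shape determined by the inputs' shapes alone.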
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input_ids | Tensor | Token IDs of shape … | required |
| noise_mask | Tensor | Boolean tensor of shape … | required |
| attention_mask | Tensor | Tensor of shape … | required |
Returns:
| Type | Description |
|---|---|
| torch.Tensor | Tensor of shape … |
get_metadata
Build a metadata dictionary for ExecuTorch constant_methods.
Delegates to optimum.exporters.executorch.utils.save_config_to_constant_methods
to produce the standard model-architecture metadata (dtype, vocab size, head
dimensions, KV-cache settings, etc.), then extends it with SGT-specific entries.
Note: This metadata is recorded in the exported model file (.pte).
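The layering described above (standard metadata first, then SGT-specific entries, then caller overrides) can be pictured with plain dictionaries. This sketch does not call the real helper; merge_metadata and all keys shown are hypothetical placeholders.

```python
def merge_metadata(base, sgt_entries, **kwargs):
    """Layer three sources of metadata; later sources win on key clashes."""
    metadata = dict(base)         # standard model-architecture entries
    metadata.update(sgt_entries)  # SGT-specific extensions
    metadata.update(kwargs)       # caller-supplied additions or overrides
    return metadata


# Hypothetical keys, for illustration only.
base = {"get_vocab_size": 32000, "use_kv_cache": False}
sgt_entries = {"example_sgt_entry": "value"}
merged = merge_metadata(base, sgt_entries, get_vocab_size=128256)
```

Here the caller-supplied get_vocab_size replaces the base value, while the other entries pass through unchanged.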
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| | Any | Additional key-value pairs to include or override in the metadata. | required |
Returns:
| Type | Description |
|---|---|
| dict[str, int \| float \| bool \| str] | A dictionary mapping getter method names to their constant values. |