
exportable_text_sgt

Torch export compliant Stained Glass Transform For Text client for use with ExecuTorch runtimes.

This module provides an export-compatible wrapper around StainedGlassTransformForText that bypasses the non-exportable hook machinery in BaseNoiseLayer.__call__ and instead calls the estimator submodules directly, inlining the noise formula and noise-mask application as pure tensor operations.

Note

The wrapper requires that the noise layer does not use a PercentMasker (or any masker that relies on data-dependent shapes like masked_select), because those operations cannot be traced by torch.export. Most production SGT4T models use percent_to_mask=None.
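The shape problem is visible without exporting anything. In the minimal sketch below, which uses toy tensors and not a real SGT4T model, torch.masked_select returns a flat tensor whose length depends on how many mask entries are True. torch.where keeps the input shape no matter what the mask holds, and that is the kind of operation torch.export can trace with symbolic shapes.

```python
import torch

embeds = torch.randn(1, 4, 8)  # (batch_size, seq_length, embedding_dim)
noise_mask = torch.tensor([[True, False, True, True]])  # (batch_size, seq_length)

# Data-dependent shape: one output element per selected entry, so the length
# (3 selected tokens * 8 dims = 24 here) is only known once the mask values are.
selected = torch.masked_select(embeds, noise_mask.unsqueeze(-1))

# Static shape: the output always matches embeds, whatever the mask contains.
blended = torch.where(noise_mask.unsqueeze(-1), embeds + 1.0, embeds)

print(selected.shape)  # torch.Size([24])
print(blended.shape)   # torch.Size([1, 4, 8])
```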

Note

Only SDPA-family attention implementations are compatible with torch.export. The wrapper defaults to standard SDPA but can optionally use the optimum-executorch custom SDPA (a CPU Flash Attention implementation) by passing use_custom_sdpa=True. See: https://github.com/pytorch/executorch/blob/a4322c71c3a97e79e0454a8223db214b010f1193/extension/llm/README.md?plain=1#L40

Note

When lowering to ExecuTorch via to_edge(), you must pass EdgeCompileConfig(_check_ir_validity=False). HuggingFace transformers constructs causal attention masks using torch.vmap (via masking_utils._vmap_for_bhqkv), which traces into _vmap_increment_nesting / _vmap_decrement_nesting ops that are not part of ExecuTorch's edge IR dialect. These ops execute correctly on the portable runtime despite failing the strict edge IR validation.
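A minimal lowering helper, assuming the executorch package is installed and that exported_program was produced by export(). It calls executorch.exir.to_edge with the relaxed EdgeCompileConfig described above, can optionally record a metadata dictionary (for example, the output of get_metadata()) as constant methods, and writes the serialized program to a .pte file. The helper name lower_to_pte is hypothetical.

```python
from __future__ import annotations

from typing import Any

import torch


def lower_to_pte(
    exported_program: torch.export.ExportedProgram,
    path: str,
    constant_methods: dict[str, Any] | None = None,
) -> None:
    """Lower an ExportedProgram to ExecuTorch and save it as a .pte file."""
    # Imported lazily so this sketch can be defined without executorch installed.
    from executorch.exir import EdgeCompileConfig, to_edge

    edge = to_edge(
        exported_program,
        constant_methods=constant_methods,
        # Skip strict edge IR validation: the vmap-based causal mask in
        # HuggingFace transformers emits _vmap_increment_nesting /
        # _vmap_decrement_nesting ops that are outside the edge dialect.
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
    )
    executorch_program = edge.to_executorch()
    with open(path, "wb") as f:
        f.write(executorch_program.buffer)
```

For example, lower_to_pte(exportable.export(), "sgt4t.pte", constant_methods=exportable.get_metadata()) would write a program that also carries the metadata entries.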

Note

ExecuTorch's portable runtime uses upper-bound memory planning: it pre-allocates buffers sized for the maximum dynamic sequence length at program load time. When export() is called without explicit dynamic_shapes, the max is derived from the model config (e.g. max_position_embeddings=131072 for Llama 3.1), which can cause the runtime to OOM even on systems with ample RAM. Pass a dynamic_shapes argument with a torch.export.Dim capped to your actual deployment needs to avoid this. See: https://docs.pytorch.org/executorch/main/using-executorch-export.html#supporting-varying-input-sizes-dynamic-shapes

Example

>>> from stainedglass_core import transform as sg_transform
>>> from stainedglass_core.integrations import executorch as sg_executorch
>>> sgt4t = sg_transform.StainedGlassTransformForText.from_pretrained(
...     "MODEL_ID", noise_layer_attention="sdpa"
... ).eval()  # doctest: +SKIP
>>> exportable = sg_executorch.ExportableStainedGlassTransformForText(
...     sgt4t
... )  # doctest: +SKIP
>>> exported_program = exportable.export()  # doctest: +SKIP

Classes:

  • ExportableStainedGlassTransformForText: Torch export compliant wrapper for creating protected input embeddings using Stained Glass Transform For Text.

ExportableStainedGlassTransformForText

Bases: Module

Torch export compliant wrapper for creating protected input embeddings using Stained Glass Transform For Text.

This wrapper bypasses BaseNoiseLayer.__call__, which uses non-exportable hook context managers (RNG forking, precision casting, runtime hook registration). Instead, it calls the mean and std estimators directly through the standard nn.Module.__call__ path, then applies the affine SGT formula and the noise mask as inline tensor operations.
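The inline computation can be pictured as a small pure-tensor function. This is a minimal sketch, not the library's code: it assumes the estimators return per-token mean and std tensors with the same shape as the embeddings, and that the protected embedding has the form embeds + mean + std * eps with eps drawn from a standard normal. The exact formula in BaseNoiseLayer.forward may differ, and the function name apply_sgt_inline is hypothetical.

```python
import torch


def apply_sgt_inline(
    embeds: torch.Tensor,      # (batch_size, seq_length, embedding_dim)
    mean: torch.Tensor,        # mean estimator output, same shape as embeds
    std: torch.Tensor,         # std estimator output, same shape as embeds
    noise_mask: torch.Tensor,  # (batch_size, seq_length), True = transform
) -> torch.Tensor:
    # Reparameterized Gaussian noise shifted by mean and scaled by std (assumed form).
    eps = torch.randn_like(embeds)
    transformed = embeds + mean + std * eps
    # Noise-mask application as a shape-preserving select instead of a hook.
    return torch.where(noise_mask.unsqueeze(-1), transformed, embeds)


embeds = torch.randn(2, 5, 8)
mean = torch.zeros_like(embeds)
std = torch.ones_like(embeds)
noise_mask = torch.tensor([[True] * 5, [False, True, True, True, False]])
protected = apply_sgt_inline(embeds, mean, std, noise_mask)
```

Because every step is an ordinary tensor operation with a fixed output shape, nothing in this path depends on hooks or on data-dependent sizes.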

The exported graph contains:

  • The token embedding table from the base model.
  • The mean estimator transformer.
  • The std estimator transformer.

Attributes:

  • embed_tokens (Module): Token embedding layer from the base model.
  • mean_estimator (Estimator): Estimator that computes the mean shift component.
  • std_estimator (Estimator): Estimator that computes the standard-deviation component.

Warning

This API is experimental: the SGT4T exportable wrapper for ExecuTorch may change in future releases.

Added in version v3.25.0. Added ExportableStainedGlassTransformForText wrapper for SGT4T support with torch.export and ExecuTorch runtimes.

Methods:

  • __init__: Initialize the ExportableStainedGlassTransformForText.
  • export: Export the module as a torch.export.ExportedProgram.
  • forward: Apply the Stained Glass Transform to the embeddings of the input token IDs.
  • get_metadata: Build a metadata dictionary for ExecuTorch constant_methods.

__init__

__init__(
    sgt4t: StainedGlassTransformForText,
    *,
    use_custom_sdpa: bool = False
) -> None

Initialize the ExportableStainedGlassTransformForText.

Parameters:

  • sgt4t (StainedGlassTransformForText, required): A pretrained StainedGlassTransformForText instance.
  • use_custom_sdpa (bool, default False): If True, use the optimum-executorch custom SDPA implementation (CPU Flash Attention) instead of PyTorch's built-in SDPA.

Raises:

  • NotImplementedError: If the noise layer's std estimator uses a masker, which is not compatible with torch.export.

export

export(
    example_input_ids: Tensor | None = None,
    example_noise_mask: Tensor | None = None,
    example_attention_mask: Tensor | None = None,
    dynamic_shapes: dict[str, dict[int, Any]] | None = None,
    strict: bool = False,
    prefer_deferred_runtime_asserts_over_guards: bool = False,
) -> torch.export.ExportedProgram

Export the module as a torch.export.ExportedProgram.

When called without arguments, deterministic example inputs are built from the model config, and a single shared torch.export.Dim is used to mark the sequence-length axis as dynamic across all inputs.

Parameters:

  • example_input_ids (Tensor | None, default None): Example tensor of shape (batch_size, seq_length). If None, a zero-filled tensor of shape (1, seq_length) is generated using the vocabulary size from the model config.
  • example_noise_mask (Tensor | None, default None): Example boolean tensor of shape (batch_size, seq_length). If None, an all-True tensor is generated.
  • example_attention_mask (Tensor | None, default None): Example tensor of shape (batch_size, seq_length). If None, an all-ones long tensor is generated.
  • dynamic_shapes (dict[str, dict[int, Any]] | None, default None): Dynamic shape specification for torch.export.export. Keys are argument names, and values are dicts mapping dimension indices to torch.export.Dim objects. If None, the sequence-length axis (dimension 1) is made dynamic for all inputs.
  • strict (bool, default False): Whether to use strict export mode, which traces the model with TorchDynamo. The default, non-strict mode traces the Python code directly.
  • prefer_deferred_runtime_asserts_over_guards (bool, default False): If True, symbolic guards that torch.export cannot prove statically are emitted as runtime assertions instead of raising ConstraintViolationError at export time.

    This is required for models whose RoPE implementation produces guards that the symbolic solver cannot verify algebraically. For example, Llama 3.1 RoPE scaling generates the guard Min(head_dim * seq_len, factor * seq_len) == head_dim * seq_len, which is trivially true for any positive seq_len at runtime but cannot be proved symbolically. The deferred assertion is still checked at execution time with the concrete seq_len.

Returns:

  • torch.export.ExportedProgram: The exported program, ready to be lowered to ExecuTorch.

forward

forward(
    input_ids: Tensor,
    noise_mask: Tensor,
    attention_mask: Tensor,
) -> torch.Tensor

Apply the Stained Glass Transform to the embeddings of the input token IDs.

This method replicates the computation of BaseNoiseLayer.forward and the noise-mask application hook, but uses only export-compatible tensor operations.

Noise is sampled internally via torch.randn, which is supported by torch.export and ExecuTorch's portable runtime (aten::randn.out).

Parameters:

  • input_ids (Tensor, required): Token IDs of shape (batch_size, seq_length).
  • noise_mask (Tensor, required): Boolean tensor of shape (batch_size, seq_length) indicating which tokens to apply noise to (True = transform, False = keep original).
  • attention_mask (Tensor, required): Tensor of shape (batch_size, seq_length) indicating which tokens should be attended to (1 = attend, 0 = ignore).

Returns:

  • torch.Tensor: Tensor of shape (batch_size, seq_length, embedding_dim) with transformed embeddings.
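A hypothetical way to build the two masks for a right-padded batch: attention_mask marks real tokens, and noise_mask protects every real token except the first position. Skipping the first position is a made-up policy (for example, to leave a BOS token untouched); the right masking policy depends on the application. Padding positions are False in noise_mask and 0 in attention_mask.

```python
import torch

# Right-padded batch of two sequences: 0 marks padding in the attention mask.
attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                               [1, 1, 1, 0, 0]])

# Transform only real tokens (.bool() returns a new tensor, so the
# attention mask itself is not modified below) ...
noise_mask = attention_mask.bool()
# ... and, as an illustrative policy, keep the first token of each row unchanged.
noise_mask[:, 0] = False
```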

get_metadata

get_metadata(
    **overrides: Any,
) -> dict[str, int | float | bool | str]

Build a metadata dictionary for ExecuTorch constant_methods.

Delegates to optimum.exporters.executorch.utils.save_config_to_constant_methods to produce the standard model-architecture metadata (dtype, vocab size, head dimensions, KV-cache settings, etc.), then extends it with SGT-specific entries.

Note: This metadata is recorded in the exported model file (.pte).

Parameters:

  • **overrides (Any, optional): Additional key-value pairs to include in, or override, the metadata.

Returns:

  • dict[str, int | float | bool | str]: A dictionary mapping getter method names to their constant values.
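The override behaviour can be pictured as a plain dictionary merge in which later entries win. This sketch assumes overrides are applied after both the optimum-derived entries and the SGT-specific entries, as the parameter description suggests. The helper name merge_metadata and all keys and values are illustrative and are not necessarily what get_metadata() emits.

```python
from typing import Any


def merge_metadata(
    base: dict[str, Any], sgt_entries: dict[str, Any], **overrides: Any
) -> dict[str, Any]:
    # Later mappings take precedence: base < SGT-specific entries < caller overrides.
    return {**base, **sgt_entries, **overrides}


base = {"get_max_seq_len": 131072, "use_kv_cache": False}  # illustrative keys
sgt = {"get_sgt_version": "example"}                       # illustrative key
metadata = merge_metadata(base, sgt, get_max_seq_len=2048)
print(metadata["get_max_seq_len"])  # 2048
```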