
executorch

ExecuTorch integration for Stained Glass Core SGT Models.

Provides utilities for creating torch.export- and ExecuTorch-compliant SGT modules for use with ExecuTorch runtimes.

Modules:

  • exportable_text_sgt: torch.export-compliant Stained Glass Transform For Text client for use with ExecuTorch runtimes.

Classes:

  • ExportableStainedGlassTransformForText: torch.export-compliant wrapper for creating protected input embeddings using Stained Glass Transform For Text.

ExportableStainedGlassTransformForText

Bases: Module

Torch export compliant wrapper for creating protected input embeddings using Stained Glass Transform For Text.

This wrapper bypasses BaseNoiseLayer.__call__, which relies on hook context managers that torch.export cannot trace (RNG forking, precision casting, and runtime hook registration). Instead, it calls the mean and std estimators directly through the standard nn.Module.__call__ path, then applies the affine SGT formula and the noise mask as inline tensor operations.

The exported graph contains:

  • The token embedding table from the base model.
  • The mean estimator transformer.
  • The std estimator transformer.
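To make the structure above concrete, here is a minimal, hypothetical sketch of a module holding the same three components with the transform inlined in forward(). It is not the library's implementation: ToyExportableSGT is an invented name, the estimators are plain nn.Linear layers standing in for the transformer estimators, the softplus that keeps std positive is an assumption, and the location-scale form mean + std * noise is only an assumed reading of "the affine SGT formula", whose exact definition is not given on this page.

```python
import torch
from torch import nn


class ToyExportableSGT(nn.Module):
    """Hypothetical stand-in for the wrapper's structure; NOT the library class.

    Assumes the affine form ``mean + std * noise``; the real SGT formula may differ.
    """

    def __init__(self, vocab_size: int = 32, embed_dim: int = 8) -> None:
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, embed_dim)
        # Hypothetical estimators: linear layers instead of transformers.
        self.mean_estimator = nn.Linear(embed_dim, embed_dim)
        self.std_estimator = nn.Linear(embed_dim, embed_dim)

    def forward(
        self,
        input_ids: torch.Tensor,
        noise_mask: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        embeds = self.embed_tokens(input_ids)  # (B, S, D)
        # Estimators are invoked through the ordinary nn.Module.__call__ path.
        # attention_mask is ignored here; real transformer estimators would consume it.
        mean = self.mean_estimator(embeds)
        std = nn.functional.softplus(self.std_estimator(embeds))  # assumed positivity
        noise = torch.randn_like(embeds)
        transformed = mean + std * noise  # assumed affine SGT form
        # Noise mask: True -> transformed embedding, False -> original embedding.
        return torch.where(noise_mask.unsqueeze(-1), transformed, embeds)
```

Applying the noise mask with torch.where keeps the masked-out positions bit-identical to the plain token embeddings, which is the behavior the forward() documentation describes for False entries.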

Attributes:

  • embed_tokens (Module): Token embedding layer from the base model.
  • mean_estimator (Estimator): Estimator that computes the mean shift component.
  • std_estimator (Estimator): Estimator that computes the standard-deviation component.

Warning

This API is experimental: the SGT4T exportable wrapper for ExecuTorch may change in future releases.

Added in version v3.25.0. Added ExportableStainedGlassTransformForText wrapper for SGT4T support with torch.export and ExecuTorch runtimes.

Methods:

  • __init__: Initialize the ExportableStainedGlassTransformForText.
  • export: Export the module as a torch.export.ExportedProgram.
  • forward: Apply the Stained Glass Transform to the embeddings of the input token IDs.
  • get_metadata: Build a metadata dictionary for ExecuTorch constant_methods.

__init__

__init__(
    sgt4t: StainedGlassTransformForText,
    *,
    use_custom_sdpa: bool = False
) -> None

Initialize the ExportableStainedGlassTransformForText.

Parameters:

  • sgt4t (StainedGlassTransformForText, required): A pretrained StainedGlassTransformForText instance.
  • use_custom_sdpa (bool, default False): If True, use the optimum-executorch custom SDPA implementation (CPU Flash Attention) instead of PyTorch's built-in SDPA.

Raises:

  • NotImplementedError: If the noise layer's std estimator uses a masker, which is not compatible with torch.export.

export

export(
    example_input_ids: Tensor | None = None,
    example_noise_mask: Tensor | None = None,
    example_attention_mask: Tensor | None = None,
    dynamic_shapes: dict[str, dict[int, Any]] | None = None,
    strict: bool = False,
    prefer_deferred_runtime_asserts_over_guards: bool = False,
) -> torch.export.ExportedProgram

Export the module as a torch.export.ExportedProgram.

When called without arguments, deterministic example inputs are built from the model config, and a single shared torch.export.Dim is used to mark the sequence-length axis as dynamic across all inputs.

Parameters:

  • example_input_ids (Tensor | None, default None): Example tensor of shape (batch_size, seq_length). If None, a zero-filled tensor of shape (1, seq_length) is generated using the vocabulary size from the model config.
  • example_noise_mask (Tensor | None, default None): Example boolean tensor of shape (batch_size, seq_length). If None, an all-True tensor is generated.
  • example_attention_mask (Tensor | None, default None): Example tensor of shape (batch_size, seq_length). If None, an all-ones long tensor is generated.
  • dynamic_shapes (dict[str, dict[int, Any]] | None, default None): Dynamic shape specification for torch.export.export. Keys are argument names and values are dicts mapping dimension indices to torch.export.Dim objects. If None, the sequence-length axis (dimension 1) is made dynamic for all inputs.
  • strict (bool, default False): Whether to use strict export mode, in which torch.export traces the module with TorchDynamo. The default, False, selects non-strict mode, which traces the module with the regular Python interpreter.
  • prefer_deferred_runtime_asserts_over_guards (bool, default False): If True, symbolic guards that torch.export cannot prove statically are emitted as runtime assertions instead of raising ConstraintViolationError at export time.

    This is required for models whose RoPE implementation produces guards that the symbolic solver cannot verify algebraically. For example, Llama 3.1 RoPE scaling generates the guard Min(head_dim * seq_len, factor * seq_len) == head_dim * seq_len, which is trivially true for any positive seq_len at runtime but cannot be proved symbolically. The deferred assertion is still checked at execution time with the concrete seq_len.

Returns:

  • torch.export.ExportedProgram: The exported program, ready to be lowered to ExecuTorch.

forward

forward(
    input_ids: Tensor,
    noise_mask: Tensor,
    attention_mask: Tensor,
) -> torch.Tensor

Apply the Stained Glass Transform to the embeddings of the input token IDs.

This method replicates the computation of BaseNoiseLayer.forward and the noise-mask application hook, but uses only export-compatible tensor operations.

Noise is sampled internally via torch.randn, which is supported by torch.export and by ExecuTorch's portable runtime (aten::randn.out).

Parameters:

  • input_ids (Tensor, required): Token IDs of shape (batch_size, seq_length).
  • noise_mask (Tensor, required): Boolean tensor of shape (batch_size, seq_length) indicating which tokens to apply noise to (True = transform, False = keep original).
  • attention_mask (Tensor, required): Tensor of shape (batch_size, seq_length) indicating which tokens should be attended to (1 = attend, 0 = ignore).

Returns:

  • torch.Tensor: Tensor of shape (batch_size, seq_length, embedding_dim) containing the transformed embeddings.

get_metadata

get_metadata(
    **overrides: Any,
) -> dict[str, int | float | bool | str]

Build a metadata dictionary for ExecuTorch constant_methods.

Delegates to optimum.exporters.executorch.utils.save_config_to_constant_methods to produce the standard model-architecture metadata (dtype, vocab size, head dimensions, KV-cache settings, etc.), then extends it with SGT-specific entries.

Note: This metadata is recorded in the exported model file (.pte).

Parameters:

  • **overrides (Any): Additional key-value pairs to include or override in the metadata.

Returns:

  • dict[str, int | float | bool | str]: A dictionary mapping getter method names to their constant values.