
torchao

torchao integration modules.

Modules:

Name Description
xnnpack_quantizer

XNNPACK-compatible quantization for Stained Glass Transform modules using torchao.

Classes:

Name Description
StainedGlassTransformForTextQuantizer

Quantizer for StainedGlassTransformForText modules.

Functions:

Name Description
build_embedding_config

Build an XNNPACK-compatible embedding quantization config from dtype parameters.

build_linear_config

Build an XNNPACK-compatible linear quantization config from string parameters.

quantize_module_

Quantize a module's embedding and linear layers in-place using torchao.

Attributes:

Name Type Description
WeightDtype TypeAlias

Allowed weight dtype strings for XNNPACK-compatible quantization configs.

WeightDtype module-attribute

WeightDtype: TypeAlias = Literal['int4', 'int8']

Allowed weight dtype strings for XNNPACK-compatible quantization configs.

"int4" and "int8" resolve to the corresponding torch.dtype integer types. Pass None to the build_*_config helpers to skip quantizing the corresponding layer type.
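As a minimal sketch of the WeightDtype contract, the helper below validates a dtype string against the Literal alias and maps it to a bit width, returning None to mean "skip this layer type". The helper name weight_bits and its bit-width table are hypothetical and exist only for illustration. The real build_*_config helpers return torchao configs, not integers.

```python
from typing import Literal, Optional, get_args

# Mirrors the documented alias.
WeightDtype = Literal["int4", "int8"]

# Hypothetical lookup table, for illustration only.
_BITS = {"int4": 4, "int8": 8}


def weight_bits(weight_dtype: Optional[WeightDtype]) -> Optional[int]:
    """Return the bit width for a WeightDtype string, or None to skip."""
    if weight_dtype is None:
        # None means "do not quantize this layer type".
        return None
    if weight_dtype not in get_args(WeightDtype):
        raise ValueError(
            f"unsupported weight_dtype {weight_dtype!r}; "
            f"expected one of {get_args(WeightDtype)}"
        )
    return _BITS[weight_dtype]


print(weight_bits("int4"), weight_bits("int8"), weight_bits(None))  # 4 8 None
```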

StainedGlassTransformForTextQuantizer

Quantizer for StainedGlassTransformForText modules.

Switch the transform to eval mode for quantization, then delegate to quantize_module_ for the actual quantization.

Example:

>>> from stainedglass_core import transform as sg_transform
>>> from stainedglass_core.integrations import torchao as sg_torchao

>>> sgt = sg_transform.StainedGlassTransformForText.from_pretrained(
...     "MODEL_ID", noise_layer_attention="sdpa"
... )  # doctest: +SKIP
>>> sgt = sg_torchao.StainedGlassTransformForTextQuantizer.quantize(sgt)  # doctest: +SKIP

Added in version v3.22.0: API for quantizing StainedGlassTransformForText modules using torchao.

Warning

This API is experimental and subject to change: StainedGlassTransformForTextQuantizer is still under development.

Methods:

Name Description
quantize

Quantize a StainedGlassTransformForText in-place.

quantize classmethod

quantize(
    sgt: StainedGlassTransformForText,
    *,
    embedding_config: IntxWeightOnlyConfig | _Unset | None = <_Unset object>,
    linear_config: Int8DynamicActivationIntxWeightConfig | _Unset | None = <_Unset object>,
    filter_fn: Callable[[Module, str], bool] | None = None,
) -> sg_text.StainedGlassTransformForText

Quantize a StainedGlassTransformForText in-place.

Quantize embedding and linear layers using XNNPACK-compatible torchao quantization schemes.

Parameters:

Name Type Description Default

sgt

StainedGlassTransformForText

The StainedGlassTransformForText to quantize.

required

embedding_config

IntxWeightOnlyConfig | _Unset | None

Config for nn.Embedding quantization. Pass None to skip. Defaults to IntxWeightOnlyConfig(weight_dtype=int8, granularity=PerAxis(0)).

_Unset sentinel (the default described above is used)

linear_config

Int8DynamicActivationIntxWeightConfig | _Unset | None

Config for nn.Linear quantization. Pass None to skip. Defaults to Int8DynamicActivationIntxWeightConfig(weight_dtype=int4, weight_granularity=PerGroup(32)).

_Unset sentinel (the default described above is used)

filter_fn

Callable[[Module, str], bool] | None

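Optional filter function for linear quantization. Receives (module, fully_qualified_name) and returns True to quantize. Forwarded to quantize_module_; see that function for the default behavior.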

None

Returns:

Type Description
sg_text.StainedGlassTransformForText

The same StainedGlassTransformForText reference, modified in-place.
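The class description says quantize switches the transform to eval mode, delegates to quantize_module_, and returns the same object it was given. A minimal sketch of that control flow follows. FakeModule, fake_quantize_module_, and quantize_sketch are hypothetical stand-ins so that it runs without torch or torchao.

```python
from typing import Any, Callable


class FakeModule:
    """Stand-in for an nn.Module; only tracks the training flag."""

    def __init__(self) -> None:
        self.training = True
        self.quantized = False

    def eval(self) -> "FakeModule":
        self.training = False
        return self


def fake_quantize_module_(module: FakeModule, **kwargs: Any) -> None:
    """Stand-in for quantize_module_: mutates the module in place."""
    module.quantized = True


def quantize_sketch(
    sgt: FakeModule,
    quantize_fn: Callable[..., None] = fake_quantize_module_,
    **kwargs: Any,
) -> FakeModule:
    # Switch to eval mode before quantizing, as the class description states.
    sgt.eval()
    # Delegate the actual quantization; the function works in place.
    quantize_fn(sgt, **kwargs)
    # Return the same reference that was passed in.
    return sgt


m = FakeModule()
out = quantize_sketch(m)
print(out is m, m.training, m.quantized)  # True False True
```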

build_embedding_config

build_embedding_config(
    weight_dtype: WeightDtype | None = "int8",
) -> (
    torchao.quantization.quant_api.IntxWeightOnlyConfig
    | None
)

Build an XNNPACK-compatible embedding quantization config from dtype parameters.

Create an IntxWeightOnlyConfig with PerAxis(0) granularity suitable for nn.Embedding layers in the XNNPACK backend.

Parameters:

Name Type Description Default

weight_dtype

WeightDtype | None

Torch integer dtype name ("int8" or "int4"). Pass None to skip embedding quantization.

'int8'

Returns:

Type Description
torchao.quantization.quant_api.IntxWeightOnlyConfig | None

An IntxWeightOnlyConfig, or None when weight_dtype is None.
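To make PerAxis(0) concrete: as we read torchao's granularity documentation, axis-0 granularity on an embedding weight of shape (num_embeddings, embedding_dim) computes one quantization scale per row, so each token's embedding vector is scaled independently. The pure-Python sketch below applies symmetric int8 weight-only quantization with one scale per row. It illustrates the idea only; torchao's actual scale formula, mapping type, and rounding may differ. quantize_rows_int8 and dequantize_rows are hypothetical helpers.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def quantize_rows_int8(weight: Matrix) -> Tuple[List[List[int]], List[float]]:
    """Symmetric int8 quantization with one scale per row (axis 0)."""
    q_rows: List[List[int]] = []
    scales: List[float] = []
    for row in weight:
        max_abs = max(abs(v) for v in row)
        # Map the largest magnitude in the row to 127; guard all-zero rows.
        scale = max_abs / 127 if max_abs > 0 else 1.0
        q_rows.append([max(-128, min(127, round(v / scale))) for v in row])
        scales.append(scale)
    return q_rows, scales


def dequantize_rows(q_rows: List[List[int]], scales: List[float]) -> Matrix:
    """Recover approximate float weights from int8 values and per-row scales."""
    return [[qv * s for qv in row] for row, s in zip(q_rows, scales)]


# A toy 2-token embedding table whose rows differ in magnitude by 20x.
table = [[0.5, -1.0, 0.25], [10.0, -20.0, 5.0]]
q, s = quantize_rows_int8(table)
approx = dequantize_rows(q, s)
```

Because each row has its own scale, the small-magnitude first row keeps its full int8 resolution even though the second row is 20 times larger.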

build_linear_config

build_linear_config(
    weight_dtype: WeightDtype | None = "int4",
    weight_group_size: int = 32,
) -> (
    torchao.quantization.quant_api.Int8DynamicActivationIntxWeightConfig
    | None
)

Build an XNNPACK-compatible linear quantization config from string parameters.

Create an Int8DynamicActivationIntxWeightConfig with PerGroup granularity suitable for nn.Linear layers in the XNNPACK backend.

Parameters:

Name Type Description Default

weight_dtype

WeightDtype | None

Torch integer dtype name ("int4" or "int8"). Pass None to skip linear quantization.

'int4'

weight_group_size

int

Group size for PerGroup weight granularity.

32

Returns:

Type Description
torchao.quantization.quant_api.Int8DynamicActivationIntxWeightConfig | None

An Int8DynamicActivationIntxWeightConfig, or None when weight_dtype is None.
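As we understand torchao's PerGroup granularity, PerGroup(32) quantizes the weight with one scale per contiguous group of 32 elements along each row (the input dimension of an nn.Linear weight) instead of one scale per row. The sketch below illustrates symmetric int4 group-wise quantization of a single weight row in pure Python. It rests on simplifying assumptions (symmetric mapping, round-to-nearest, a row length divisible by the group size) and is not torchao's kernel. quantize_groups_int4 is a hypothetical helper.

```python
from typing import List, Tuple

INT4_MIN, INT4_MAX = -8, 7


def quantize_groups_int4(
    row: List[float], group_size: int = 32
) -> Tuple[List[int], List[float]]:
    """Symmetric int4 quantization with one scale per group of `group_size`."""
    if len(row) % group_size != 0:
        # Simplifying assumption for this sketch.
        raise ValueError("row length must be a multiple of group_size")
    q: List[int] = []
    scales: List[float] = []
    for start in range(0, len(row), group_size):
        group = row[start : start + group_size]
        max_abs = max(abs(v) for v in group)
        # Map the group's largest magnitude to 7; guard all-zero groups.
        scale = max_abs / INT4_MAX if max_abs > 0 else 1.0
        q.extend(max(INT4_MIN, min(INT4_MAX, round(v / scale))) for v in group)
        scales.append(scale)
    return q, scales


# A 64-element row: the second half is 100x larger than the first half.
row = [0.01 * (i % 7 - 3) for i in range(32)] + [1.0 * (i % 7 - 3) for i in range(32)]
q, scales = quantize_groups_int4(row, group_size=32)
print(len(scales))  # 2
```

With a single per-row scale, the large second half would set the scale and round every value in the small first half to zero. Per-group scales avoid that, at the cost of storing one scale per 32 weights.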

quantize_module_

quantize_module_(
    module: Module,
    *,
    embedding_config: IntxWeightOnlyConfig | _Unset | None = <_Unset object>,
    linear_config: Int8DynamicActivationIntxWeightConfig | IntxWeightOnlyConfig | _Unset | None = <_Unset object>,
    filter_fn: Callable[[Module, str], bool] | None = None,
) -> None

Quantize a module's embedding and linear layers in-place using torchao.

Apply XNNPACK-compatible quantization schemes: IntxWeightOnlyConfig for nn.Embedding layers and Int8DynamicActivationIntxWeightConfig for nn.Linear layers.

The default linear filter function safely skips linear layers with None weights, which occur in StainedGlassTransformForText when unused base model layers are removed during serialization.
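In Int8DynamicActivationIntxWeightConfig, "dynamic activation" means activations are quantized to int8 at runtime, with quantization parameters computed from each input as it arrives, while the weights are quantized ahead of time. The pure-Python sketch below shows per-token asymmetric int8 activation quantization, with one scale and zero point per row of activations. The per-token, asymmetric scheme is an assumption for illustration; check torchao's config for the mapping type it actually uses. quantize_per_token_int8 and dequantize_per_token are hypothetical helpers.

```python
from typing import List, Tuple

INT8_MIN, INT8_MAX = -128, 127


def quantize_per_token_int8(
    activations: List[List[float]],
) -> Tuple[List[List[int]], List[float], List[int]]:
    """Asymmetric int8 quantization with one (scale, zero_point) per token row."""
    q_rows: List[List[int]] = []
    scales: List[float] = []
    zero_points: List[int] = []
    for row in activations:
        # Include 0.0 in the range so that zero is exactly representable.
        lo = min(min(row), 0.0)
        hi = max(max(row), 0.0)
        scale = (hi - lo) / (INT8_MAX - INT8_MIN) if hi > lo else 1.0
        # Choose the zero point so that `lo` maps to INT8_MIN.
        zero_point = max(INT8_MIN, min(INT8_MAX, INT8_MIN - round(lo / scale)))
        q_rows.append(
            [max(INT8_MIN, min(INT8_MAX, round(x / scale) + zero_point)) for x in row]
        )
        scales.append(scale)
        zero_points.append(zero_point)
    return q_rows, scales, zero_points


def dequantize_per_token(
    q_rows: List[List[int]], scales: List[float], zero_points: List[int]
) -> List[List[float]]:
    """Map int8 values back to floats using each row's scale and zero point."""
    return [
        [(qv - zp) * s for qv in row]
        for row, s, zp in zip(q_rows, scales, zero_points)
    ]


# Two "tokens" with different ranges; parameters are computed at call time.
acts = [[0.0, 0.2, 1.5, 3.0], [-4.0, -1.0, 0.0, 0.5]]
q, s, zp = quantize_per_token_int8(acts)
restored = dequantize_per_token(q, s, zp)
```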

Parameters:

Name Type Description Default

module

Module

The nn.Module to quantize in-place.

required

embedding_config

IntxWeightOnlyConfig | _Unset | None

Config for nn.Embedding quantization. Pass None to skip embedding quantization. Defaults to IntxWeightOnlyConfig(weight_dtype=int8, granularity=PerAxis(0)).

_Unset sentinel (the default described above is used)

linear_config

Int8DynamicActivationIntxWeightConfig | IntxWeightOnlyConfig | _Unset | None

Config for nn.Linear quantization. Pass None to skip linear quantization. Defaults to Int8DynamicActivationIntxWeightConfig(weight_dtype=int4, weight_granularity=PerGroup(32)).

_Unset sentinel (the default described above is used)

filter_fn

Callable[[Module, str], bool] | None

Optional filter function for linear quantization. Receives (module, fully_qualified_name) and returns True to quantize. Defaults to a filter that selects nn.Linear modules with non-None weights.

None
Example

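>>> from stainedglass_core.integrations import torchao as sg_torchao
>>> # Quantize both embedding and linear layers with the default configs.
>>> sg_torchao.quantize_module_(model)  # doctest: +SKIP
>>> # Quantize only embedding layers.
>>> sg_torchao.quantize_module_(model, linear_config=None)  # doctest: +SKIP
>>> # Quantize only linear layers.
>>> sg_torchao.quantize_module_(model, embedding_config=None)  # doctest: +SKIP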
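Passing a custom filter_fn replaces the default filter, so a custom filter should usually keep the default's check (nn.Linear modules with non-None weights) and add its own exclusions. The sketch below composes a base predicate with a name-prefix exclusion list. The helper exclude_by_prefix, the FakeLinear stand-in, and the "lm_head" prefix are hypothetical. With torch installed, the base predicate would be something like `lambda m, fqn: isinstance(m, torch.nn.Linear) and m.weight is not None`, mirroring the documented default.

```python
from typing import Any, Callable, Iterable

FilterFn = Callable[[Any, str], bool]


def exclude_by_prefix(base: FilterFn, prefixes: Iterable[str]) -> FilterFn:
    """Wrap `base` so modules whose fully qualified name starts with a prefix are skipped."""
    prefix_tuple = tuple(prefixes)

    def filter_fn(module: Any, fqn: str) -> bool:
        if fqn.startswith(prefix_tuple):
            return False
        return base(module, fqn)

    return filter_fn


class FakeLinear:
    """Stand-in for nn.Linear; weight may be None after serialization."""

    def __init__(self, weight: Any) -> None:
        self.weight = weight


def has_weight(module: Any, fqn: str) -> bool:
    # Mirrors the documented default: only linear modules with non-None weights.
    return isinstance(module, FakeLinear) and module.weight is not None


my_filter = exclude_by_prefix(has_weight, ["lm_head"])
print(my_filter(FakeLinear([1.0]), "layers.0.proj"))  # True
print(my_filter(FakeLinear([1.0]), "lm_head"))        # False
print(my_filter(FakeLinear(None), "layers.1.proj"))   # False
```

With the real library, the resulting function would be passed as quantize_module_(model, filter_fn=my_filter).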

Warning

This API is experimental and subject to change: quantize_module_ is still under development.