turboquant

TurboQuant MSE algorithm (arXiv:2504.19874) with encode helpers for the proxy.

TurboQuant is a WHT-rotated Lloyd-Max vector quantizer that achieves near-optimal MSE with a fast GPU encode and decode path.

Note
  1. TurboQuant uses the Walsh-Hadamard Transform, which requires a power-of-2 embedding dimension. Almost all production transformer models satisfy this (e.g. 4096, 8192). For models with non-power-of-2 hidden sizes, set SGP_EMBEDDING_COMPRESSION=none.

  2. Future non-power-of-2 support: for non-power-of-2 hidden sizes (e.g. Phi-3.5, dim=3072), the rotation would need to be a full dense random orthogonal matrix instead of a WHT. If that path is ever added, keep the rotation as a plain matmul (x @ R.T) rather than embedding it in a custom kernel: cuBLAS beats in-kernel GEMM for this operation, as confirmed by vLLM's TurboQuant Triton launcher (triton_turboquant_store.py).
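The power-of-2 requirement described in the note can be checked before compression is enabled. The helper below is a hypothetical sketch and is not part of this module; only the SGP_EMBEDDING_COMPRESSION=none setting comes from the note itself.

```python
def is_power_of_two(n: int) -> bool:
    """Return True if n is a positive power of 2 (1, 2, 4, 8, ...)."""
    # A power of 2 has exactly one set bit, so n & (n - 1) clears it to zero.
    return n > 0 and (n & (n - 1)) == 0


# Hypothetical startup check over a few hidden sizes.
for hidden_size in (4096, 8192, 3072):
    if is_power_of_two(hidden_size):
        print(hidden_size, "is WHT-compatible")
    else:
        print(hidden_size, "needs SGP_EMBEDDING_COMPRESSION=none")
```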

Classes:

Name Description
TurboQuantMSE

TurboQuant MSE-optimal online vector quantizer (arXiv:2504.19874).

TurboQuantPayload

Wire format produced by the proxy and consumed by the vLLM TurboQuant plugin.

Functions:

Name Description
encode_embeddings

Encode a 2-D embedding tensor using a module-level cached quantizer.

packed_dim

Return the number of uint8 elements in the packed tensor for a given (dim, bits) pair.

warmup_encode

Trigger torch.compile for the given dimensions before the first real request.

TurboQuantMSE

TurboQuant MSE-optimal online vector quantizer (arXiv:2504.19874).

Encode: normalize → WHT-rotate → Lloyd-Max assignment → bit-pack. Decode: unpack → centroid lookup → inverse-rotate → rescale.

Tensor operations run on whatever device the input lives on. The rotation matrix and codebook are stored on CPU and transferred lazily on first use per device; subsequent calls on the same device reuse the cached copy.
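The rotate and inverse-rotate steps of the pipeline above can be illustrated with a small pure-Python sketch. It assumes the rotation is an orthonormal Walsh-Hadamard transform preceded by seeded random sign flips; the sign-flip randomisation and all function names here are assumptions made for illustration, not the module's actual implementation.

```python
import math
import random


def fwht(values):
    """Orthonormal fast Walsh-Hadamard transform of a power-of-2-length list."""
    n = len(values)
    if n == 0 or n & (n - 1):
        raise ValueError("length must be a power of 2")
    out = list(values)
    h = 1
    while h < n:
        # Butterfly step: combine pairs of entries that are h apart.
        for start in range(0, n, 2 * h):
            for i in range(start, start + h):
                a, b = out[i], out[i + h]
                out[i], out[i + h] = a + b, a - b
        h *= 2
    scale = 1.0 / math.sqrt(n)  # makes the transform orthonormal
    return [v * scale for v in out]


def make_signs(dim, seed=42):
    """Assumed randomisation: one seeded random +/-1 sign per coordinate."""
    rng = random.Random(seed)
    return [rng.choice((-1.0, 1.0)) for _ in range(dim)]


def rotate(x, signs):
    """Flip signs, then apply the orthonormal WHT."""
    return fwht([s * v for s, v in zip(signs, x)])


def inverse_rotate(y, signs):
    """The orthonormal WHT is its own inverse; undo the sign flips afterwards."""
    return [s * v for s, v in zip(signs, fwht(y))]


x = [3.0, -1.0, 0.5, 2.0, 0.0, 4.0, -2.5, 1.0]
signs = make_signs(len(x))
y = rotate(x, signs)
x_back = inverse_rotate(y, signs)
```

Because the transform is orthonormal, the rotated vector keeps the norm of the input and the inverse recovers the input up to floating-point error, which is why only the norm has to be stored next to the packed indices.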

Parameters:

Name Type Description Default

dim

int

Vector dimension. Must be a power of 2.

required

bits

int

Bits per coordinate. Must be one of {1, 2, 4, 8}.

required

seed

int

RNG seed for the rotation matrix.

42

rotate

bool

If False, skip the WHT rotation (ablation baseline).

True

Examples:

>>> q = TurboQuantMSE(dim=128, bits=4)
>>> x = torch.randn(64, 128)
>>> packed, norms = q.encode(x)
>>> x_hat = q.decode(packed, norms)

Methods:

Name Description
__init__

Initialize the quantizer with a Lloyd-Max codebook and WHT rotation.

decode

Reconstruct vectors from packed indices and stored norms.

encode

Quantize vectors to packed indices plus stored norms.

lower_bound

Yao's minimax lower bound: 4^{-b} (Theorem 3, arXiv:2504.19874).

mse

Average per-vector squared L2 reconstruction error.

upper_bound

TurboQuant MSE guarantee: (sqrt(3*pi)/2) * 4^{-b} (Theorem 1, arXiv:2504.19874).

__init__

__init__(
    dim: int, bits: int, seed: int = 42, rotate: bool = True
) -> None

Initialize the quantizer with a Lloyd-Max codebook and WHT rotation.

Raises:

Type Description
ValueError

If bits is not one of the supported bit-widths, or (when rotate=True) dim is not a power of 2.
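For readers unfamiliar with Lloyd-Max codebooks, the following is a minimal sketch of Lloyd's algorithm on one-dimensional samples. It is illustrative only: this page does not say how the module builds its codebook, and the function name, the sample-based fit, and the standard-normal test data are all assumptions.

```python
import bisect
import random


def lloyd_max_1d(samples, levels, iterations=50):
    """Fit `levels` scalar centroids to 1-D samples with Lloyd's algorithm."""
    data = sorted(samples)
    # Initialise centroids at evenly spaced quantiles of the data.
    centroids = [data[(2 * k + 1) * len(data) // (2 * levels)] for k in range(levels)]
    for _ in range(iterations):
        # Nearest-centroid cells are intervals split at the midpoints.
        boundaries = [(a + b) / 2 for a, b in zip(centroids, centroids[1:])]
        sums = [0.0] * levels
        counts = [0] * levels
        for v in data:
            k = bisect.bisect_right(boundaries, v)
            sums[k] += v
            counts[k] += 1
        # Move each centroid to the mean of its cell (keep it if the cell is empty).
        centroids = [
            sums[k] / counts[k] if counts[k] else centroids[k]
            for k in range(levels)
        ]
    return centroids


rng = random.Random(0)
samples = [rng.gauss(0.0, 1.0) for _ in range(10000)]
centroids = lloyd_max_1d(samples, levels=4)  # 4 levels corresponds to bits=2
print([round(c, 3) for c in centroids])
```

With bits bits per coordinate the codebook would hold 2**bits centroids, so the 4-level fit above corresponds to bits=2.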

decode

decode(packed: Tensor, norms: Tensor) -> torch.Tensor

Reconstruct vectors from packed indices and stored norms.

Parameters:

Name Type Description Default

packed

Tensor

Shape (..., packed_dim(dim, bits)), dtype uint8. See packed_dim for the expected width for each (dim, bits) combination.

required

norms

Tensor

Shape (...,), dtype float32. Norms from encode().

required

Returns:

Type Description
torch.Tensor

Reconstructed tensor of shape (..., dim), dtype float32.

encode

encode(x: Tensor) -> tuple[torch.Tensor, torch.Tensor]

Quantize vectors to packed indices plus stored norms.

Parameters:

Name Type Description Default

x

Tensor

Input tensor of shape (..., dim).

required

Returns:

Type Description
tuple[torch.Tensor, torch.Tensor]

A tuple (packed, norms). packed has shape (..., packed_dim(dim, bits)) and dtype uint8; norms has shape (...,) and dtype float32. See packed_dim for the mapping from (dim, bits) to packed width.
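To make the packed layout more concrete, here is a hedged sketch of packing 4-bit indices two per byte. The nibble order (low nibble first) and the helper names are assumptions; the module's actual byte layout is not specified on this page.

```python
def pack_4bit(indices):
    """Pack 4-bit indices (0..15) into bytes, two per byte, low nibble first."""
    values = list(indices)
    if any(not 0 <= i <= 15 for i in values):
        raise ValueError("indices must fit in 4 bits")
    padded = values + [0] * (len(values) % 2)  # pad to an even count
    return bytes(lo | (hi << 4) for lo, hi in zip(padded[0::2], padded[1::2]))


def unpack_4bit(packed, count):
    """Inverse of pack_4bit; `count` trims the padding index, if any."""
    out = []
    for byte in packed:
        out.append(byte & 0x0F)  # low nibble holds the first index
        out.append(byte >> 4)    # high nibble holds the second index
    return out[:count]


indices = [1, 2, 3, 4, 15]
packed = pack_4bit(indices)
restored = unpack_4bit(packed, len(indices))
```

For dim=128 and bits=4 this layout yields 64 bytes per vector, which matches ceil(dim / 2).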

lower_bound

lower_bound() -> float

Yao's minimax lower bound: 4^{-b} (Theorem 3, arXiv:2504.19874).

Returns:

Type Description
float

The information-theoretic lower bound on per-vector MSE for unit vectors.

mse

mse(x: Tensor) -> float

Average per-vector squared L2 reconstruction error.

Parameters:

Name Type Description Default

x

Tensor

Input tensor of shape (..., dim), dtype float32.

required

Returns:

Type Description
float

Scalar mean squared error over the batch.
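The quantity described here can be written out in plain Python as a sketch. The function name and the list-of-lists input are illustrative assumptions; the real method takes a tensor and runs the encode and decode round trip internally.

```python
def mean_squared_l2_error(original, reconstructed):
    """Average over the batch of the squared L2 distance between each pair."""
    if not original or len(original) != len(reconstructed):
        raise ValueError("inputs must be non-empty batches of equal length")
    per_vector = [
        sum((a - b) ** 2 for a, b in zip(x, x_hat))
        for x, x_hat in zip(original, reconstructed)
    ]
    return sum(per_vector) / len(per_vector)


batch = [[1.0, 0.0], [0.0, 1.0]]
batch_hat = [[1.0, 0.0], [0.0, 0.0]]
error = mean_squared_l2_error(batch, batch_hat)
```

Since lower_bound and upper_bound are stated for unit vectors, a comparison against them is presumably only meaningful for unit-norm inputs.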

upper_bound

upper_bound() -> float

TurboQuant MSE guarantee: (sqrt(3*pi)/2) * 4^{-b} (Theorem 1, arXiv:2504.19874).

Returns:

Type Description
float

The theoretical upper bound on per-vector MSE, approximately 1.54 * 4^{-b}.
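As a quick arithmetic check, the two bounds can be evaluated directly from the formulas as they are stated on this page. The standalone functions below are a sketch that mirrors those formulas and are not the class methods themselves.

```python
import math


def lower_bound(bits: int) -> float:
    """Lower bound as stated on this page: 4^{-b}."""
    return 4.0 ** -bits


def upper_bound(bits: int) -> float:
    """Upper bound as stated on this page: (sqrt(3*pi)/2) * 4^{-b}."""
    return math.sqrt(3 * math.pi) / 2 * 4.0 ** -bits


for bits in (1, 2, 4, 8):
    print(f"bits={bits}: lower={lower_bound(bits):.3e}, upper={upper_bound(bits):.3e}")
```

Both bounds shrink by a factor of 4 for every extra bit per coordinate, so their ratio stays constant across bit-widths.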

TurboQuantPayload

Bases: TypedDict

Wire format produced by the proxy and consumed by the vLLM TurboQuant plugin.

Serialised with torch.save and sent as a base64-encoded prompt_embeds field.

encode_embeddings

encode_embeddings(
    x: Tensor, bits: int = 4
) -> tuple[torch.Tensor, torch.Tensor]

Encode a 2-D embedding tensor using a module-level cached quantizer.

The codebook and rotation matrix are initialized once per (hidden_size, bits) pair and reused across calls. The input must be 2-D (num_tokens, hidden_size); reshape before calling if needed.

Parameters:

Name Type Description Default

x

Tensor

Input tensor of shape (num_tokens, hidden_size).

required

bits

int

Bits per coordinate. Defaults to 4.

4

Returns:

Type Description
tuple[torch.Tensor, torch.Tensor]

A tuple (packed, norms) ready to include in the TurboQuant wire payload.

Raises:

Type Description
ValueError

If x is not a 2-D tensor.
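One way a module-level cache keyed by (hidden_size, bits) could look is sketched below with functools.lru_cache. The stand-in class and the get_quantizer name are hypothetical and exist only so the sketch runs without torch; the module's real caching mechanism may differ.

```python
from functools import lru_cache


class StandInQuantizer:
    """Placeholder for TurboQuantMSE so the sketch runs without torch."""

    def __init__(self, dim: int, bits: int) -> None:
        self.dim = dim
        self.bits = bits


@lru_cache(maxsize=None)
def get_quantizer(hidden_size: int, bits: int) -> StandInQuantizer:
    """Build the quantizer once per (hidden_size, bits) pair, then reuse it."""
    return StandInQuantizer(hidden_size, bits)


first = get_quantizer(4096, 4)
second = get_quantizer(4096, 4)   # cache hit: same object as `first`
other = get_quantizer(4096, 2)    # different bits: separate instance
```

Keying on both values matters because the codebook size depends on bits and the rotation depends on the hidden size.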

packed_dim

packed_dim(dim: int, bits: int) -> int

Return the number of uint8 elements in the packed tensor for a given (dim, bits) pair.

Uses the general formula ceil(dim * bits / 8), which is exact for all supported bit-widths:

- bits=1: ceil(dim / 8)
- bits=2: ceil(dim / 4)
- bits=3: ceil(dim * 3 / 8), i.e. 8 indices per 3 bytes (24-bit groups)
- bits=4: ceil(dim / 2)
- bits=8: dim (one index per byte, no packing)

Parameters:

Name Type Description Default

dim

int

Embedding dimension.

required

bits

int

Bits per coordinate.

required

Returns:

Type Description
int

Number of uint8 elements per embedding vector in the packed representation.

Raises:

Type Description
ValueError

If bits is not one of the supported bit-widths.
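The formula ceil(dim * bits / 8) can be sketched with integer arithmetic as follows. The function name packed_dim_sketch is hypothetical, and the accepted bit-widths are an assumption taken from the list in this function's description, which includes bits=3 even though the TurboQuantMSE parameter table lists only {1, 2, 4, 8}.

```python
SUPPORTED_BITS = (1, 2, 3, 4, 8)  # assumption: widths listed in the description


def packed_dim_sketch(dim: int, bits: int) -> int:
    """Number of uint8 elements needed to store `dim` indices of `bits` bits."""
    if bits not in SUPPORTED_BITS:
        raise ValueError(f"unsupported bit-width: {bits}")
    # Integer form of ceil(dim * bits / 8), avoiding floating-point rounding.
    return (dim * bits + 7) // 8


for bits in SUPPORTED_BITS:
    print(f"dim=4096, bits={bits}: {packed_dim_sketch(4096, bits)} bytes")
```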

warmup_encode

warmup_encode(embedding_dim: int, bits: int = 4) -> None

Trigger torch.compile for the given dimensions before the first real request.

Running a dummy encode at startup moves the Triton kernel compilation (~5 s on GPU) into initialization instead of the first live request. Subsequent calls for the same (embedding_dim, bits) pair are instant cache hits.

Parameters:

Name Type Description Default

embedding_dim

int

Embedding dimension to compile for.

required

bits

int

Bits per coordinate. Defaults to 4.

4