dtypes

Module for PyTorch dtype utilities.

Functions:

Name                  Description
compute_in_precision  Create a decorator to compute a function in the specified dtype.
default_dtype         Context manager to temporarily change the default dtype of torch.
from_str              Convert a string representation of a dtype to a torch.dtype.

compute_in_precision

compute_in_precision(
    dtype: dtype,
) -> collections.abc.Callable[
    [collections.abc.Callable[~P, torch.Tensor]],
    collections.abc.Callable[~P, torch.Tensor],
]

Create a decorator to compute a function in the specified dtype.

Parameters:

Name   Type   Description                                         Default
dtype  dtype  The dtype in which the decorated function is computed.  required

Returns:

Type Description
collections.abc.Callable[[collections.abc.Callable[~P, torch.Tensor]], collections.abc.Callable[~P, torch.Tensor]]

A decorator that wraps a tensor-returning function so that it is computed in the provided dtype.

default_dtype

default_dtype(
    dtype: dtype,
) -> collections.abc.Generator[None]

Context manager to temporarily change the default dtype of torch.

Parameters:

Name   Type   Description                                                                                              Default
dtype  dtype  The dtype to temporarily set as the default. All dtypes supported by torch.set_default_dtype are valid.  required

Yields:

Type Description
collections.abc.Generator[None]

None.

Examples:

New tensors created within the context will have the new default dtype (unless explicitly overridden by the dtype argument).

>>> torch.tensor([1.0]).dtype
torch.float32
>>> with default_dtype(torch.float64):
...     torch.tensor([1.0]).dtype
torch.float64

The default dtype is restored after the context exits.

>>> torch.get_default_dtype()
torch.float32
>>> with default_dtype(torch.float64):
...     torch.get_default_dtype()
torch.float64
>>> torch.get_default_dtype()
torch.float32
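
The save, set, and restore behaviour shown in the examples above can be sketched with `contextlib.contextmanager` together with `torch.get_default_dtype` and `torch.set_default_dtype`. This is an illustrative sketch under the assumption that the real context manager works along these lines; the name `default_dtype_sketch` is hypothetical.

```python
import contextlib
from collections.abc import Generator

import torch


@contextlib.contextmanager
def default_dtype_sketch(dtype: torch.dtype) -> Generator[None, None, None]:
    """Hypothetical stand-in for default_dtype (implementation is assumed)."""
    previous = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Restore the previous default even if the body raised an exception.
        torch.set_default_dtype(previous)


before = torch.get_default_dtype()
with default_dtype_sketch(torch.float64):
    inside = torch.tensor([1.0]).dtype  # created under the temporary default
after = torch.get_default_dtype()
```

The `try`/`finally` is the key design choice in this sketch: without it, an exception raised inside the `with` block would leave the global default dtype changed for the rest of the process.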

from_str

from_str(dtype_str: str) -> torch.dtype

Convert a string representation of a dtype to a torch.dtype.

Parameters:

Name       Type  Description                              Default
dtype_str  str   The string representation of the dtype.  required

Returns:

Type Description
torch.dtype

The corresponding torch.dtype.
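
The accepted string formats are not listed here. As a rough sketch, assuming names such as `"float32"` and `"torch.float32"` should both resolve, a lookup on the `torch` module could look like the following. The name `from_str_sketch`, the optional `torch.` prefix, and the `ValueError` on unknown names are assumptions, not documented behaviour.

```python
import torch


def from_str_sketch(dtype_str: str) -> torch.dtype:
    """Hypothetical stand-in for from_str (accepted formats are assumed)."""
    # Accept both "float32" and "torch.float32" (an assumed convention).
    name = dtype_str.removeprefix("torch.")
    candidate = getattr(torch, name, None)
    # Reject attributes of torch that exist but are not dtypes, e.g. "tensor".
    if not isinstance(candidate, torch.dtype):
        raise ValueError(f"Unknown torch dtype: {dtype_str!r}")
    return candidate


half = from_str_sketch("torch.bfloat16")
single = from_str_sketch("float32")
```

The `isinstance` check matters because `getattr(torch, name)` succeeds for many names that are functions or classes rather than dtypes.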