dtypes
Module for PyTorch dtype utilities.
Functions:
| Name | Description |
|---|---|
| `compute_in_precision` | Create a decorator to compute a function in the specified dtype. |
| `default_dtype` | Context manager to temporarily change the default dtype of torch. |
| `from_str` | Convert a string representation of a dtype to a |
compute_in_precision
```python
compute_in_precision(
    dtype: torch.dtype,
) -> collections.abc.Callable[
    [collections.abc.Callable[P, torch.Tensor]],
    collections.abc.Callable[P, torch.Tensor],
]
```
Create a decorator to compute a function in the specified dtype.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `dtype` | `dtype` | The dtype in which the decorated function is computed. | *required* |
Returns:
| Type | Description |
|---|---|
| `collections.abc.Callable[[collections.abc.Callable[P, torch.Tensor]], collections.abc.Callable[P, torch.Tensor]]` | A decorator that takes a function returning a `torch.Tensor` and returns a wrapped function with the same signature that is computed in `dtype`. |
default_dtype
```python
default_dtype(
    dtype: torch.dtype,
) -> collections.abc.Generator[None]
```
Context manager to temporarily change the default dtype of torch.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `dtype` | `dtype` | The dtype to temporarily set as the default. Any dtype accepted by `torch.set_default_dtype` is valid. | *required* |
Yields:
| Type | Description |
|---|---|
| `collections.abc.Generator[None]` | None. |
Examples:
New floating-point tensors created within the context use the new default dtype, unless it is explicitly overridden with the `dtype` argument.
```python
>>> torch.tensor([1.0]).dtype
torch.float32
>>> with default_dtype(torch.float64):
...     torch.tensor([1.0]).dtype
torch.float64
```
The default dtype is restored after the context exits.
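The save-and-restore behavior described above can be sketched with `contextlib.contextmanager` and a `try`/`finally` block. This is a minimal sketch, assuming the context manager records `torch.get_default_dtype()` on entry and passes it back to `torch.set_default_dtype()` on exit; `default_dtype_sketch` is a hypothetical name, not this module's function.

```python
from collections.abc import Iterator
from contextlib import contextmanager

import torch


@contextmanager
def default_dtype_sketch(dtype: torch.dtype) -> Iterator[None]:
    """Hypothetical stand-in for default_dtype (not the real implementation)."""
    # Save the current global default so it can be restored on exit.
    previous = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Restore the previous default even if the body raised an exception.
        torch.set_default_dtype(previous)


# Usage: the new default applies only where no dtype is given explicitly.
with default_dtype_sketch(torch.float64):
    a = torch.tensor([1.0])                       # uses the temporary default
    b = torch.tensor([1.0], dtype=torch.float32)  # explicit dtype wins
```

Putting the restore in `finally` means an exception inside the `with` body cannot leave the process-wide default changed.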