tensor

Module for PyTorch tensor utilities.

Functions:

cast_to_device: Make a deep copy of value, casting all tensors to the given device and dtype.
collect_devices: Collect all devices in the given value.
collect_floating_point_dtypes: Collect all floating point dtypes in the given value.
hash_tensor_data: Compute the hash of the tensor's data represented as a string.
prepare_for_safe_division: Prepare a tensor for safe division by adding to it a small value that is relative to the scale of the tensor itself.

cast_to_device

cast_to_device(
    value: T,
    *,
    device: str | int | device | None = None,
    dtype: dtype | None = None
) -> T

Make a deep copy of value, casting all tensors to the given device and dtype.

Adapted from: https://github.com/pytorch/pytorch/blob/49444c3e546bf240bed24a101e747422d1f8a0ee/torch/optim/optimizer.py#L209C1-L225C29.

Parameters:

value (T): The value to recursively copy and cast. Required.
device (str | int | device | None): The device to cast tensors to. Default: None.
dtype (dtype | None): The dtype to cast tensors to. Only applied to floating point tensors. Default: None.

Returns:

T: The copied and cast value.
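
The recursive copy-and-cast described above can be sketched as follows. This is a hypothetical reimplementation (cast_to_device_sketch), not the module's own code. It assumes the same dispatch order as the linked PyTorch optimizer helper (tensors, then dicts, then other iterables); as additional assumptions, it rebuilds any Mapping (including UserDict) as a plain dict and returns str/bytes unchanged so they are not rebuilt character by character.

```python
from collections.abc import Iterable, Mapping
from typing import Any, Optional, Union

import torch


def cast_to_device_sketch(
    value: Any,
    *,
    device: Optional[Union[str, int, torch.device]] = None,
    dtype: Optional[torch.dtype] = None,
) -> Any:
    """Hypothetical sketch: recursively copy `value`, casting every tensor."""
    if isinstance(value, torch.Tensor):
        # The dtype is only applied to floating point tensors; others keep their dtype.
        target_dtype = dtype if value.is_floating_point() else None
        # copy=True guarantees a new tensor even if device and dtype already match.
        return value.to(device=device, dtype=target_dtype, copy=True)
    if isinstance(value, Mapping):
        # Assumption: mappings (dict, UserDict, ...) are rebuilt as plain dicts.
        return {k: cast_to_device_sketch(v, device=device, dtype=dtype) for k, v in value.items()}
    if isinstance(value, (str, bytes)):
        # Strings are iterable but must not be rebuilt element by element.
        return value
    if isinstance(value, Iterable):
        # Rebuild lists, tuples, sets, ... with the same container type.
        return type(value)(cast_to_device_sketch(v, device=device, dtype=dtype) for v in value)
    # Any other leaf (int, float, None, ...) is returned unchanged.
    return value
```

Passing copy=True to Tensor.to makes every tensor in the result a fresh copy even when no cast is needed, which matches the "deep copy" wording; immutable leaves such as ints are simply reused.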

collect_devices

collect_devices(
    value: (
        Tensor
        | dict[Any, Any]
        | UserDict[Any, Any]
        | Iterable[Any]
        | Any
    ),
) -> set[torch.device]

Collect all devices in the given value.

Parameters:

value (Tensor | dict[Any, Any] | UserDict[Any, Any] | Iterable[Any] | Any): The value to recursively collect devices from. Required.

Returns:

set[torch.device]: The set of all devices in the given value.
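
A minimal sketch of the recursive device collection, under stated assumptions: the name collect_devices_sketch is hypothetical, only mapping values (not keys) are searched, and str/bytes are treated as leaves. The tensor check must come first because a tensor is itself iterable.

```python
from collections.abc import Iterable, Mapping
from typing import Any

import torch


def collect_devices_sketch(value: Any) -> set[torch.device]:
    """Hypothetical sketch: gather the device of every tensor nested in `value`."""
    if isinstance(value, torch.Tensor):
        # Checked first: tensors are iterable too, but we want the tensor's own device.
        return {value.device}
    if isinstance(value, Mapping):
        # dict and UserDict are both Mappings; only their values are searched.
        children: Iterable[Any] = value.values()
    elif isinstance(value, Iterable) and not isinstance(value, (str, bytes)):
        children = value
    else:
        # Leaves that are not tensors contribute no devices.
        return set()
    devices: set[torch.device] = set()
    for child in children:
        devices |= collect_devices_sketch(child)
    return devices
```

A result with more than one element signals that the nested tensors live on different devices, which is a typical reason to call a function like this before a cross-device operation.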

collect_floating_point_dtypes

collect_floating_point_dtypes(
    value: (
        Tensor
        | dict[Any, Any]
        | UserDict[Any, Any]
        | Iterable[Any]
        | Any
    ),
) -> set[torch.dtype]

Collect all floating point dtypes in the given value.

Parameters:

value (Tensor | dict[Any, Any] | UserDict[Any, Any] | Iterable[Any] | Any): The value to recursively collect floating point dtypes from. Required.

Returns:

set[torch.dtype]: The set of all floating point dtypes in the given value.
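
The following sketch factors the recursive traversal into a generator and filters on Tensor.is_floating_point(). Both iter_tensors and collect_floating_point_dtypes_sketch are hypothetical names; the traversal rules (mapping values only, str/bytes as leaves) are assumptions, not the module's documented behavior.

```python
from collections.abc import Iterable, Iterator, Mapping
from typing import Any

import torch


def iter_tensors(value: Any) -> Iterator[torch.Tensor]:
    """Hypothetical helper: yield every tensor nested in mappings and iterables."""
    if isinstance(value, torch.Tensor):
        yield value
    elif isinstance(value, Mapping):
        for child in value.values():
            yield from iter_tensors(child)
    elif isinstance(value, Iterable) and not isinstance(value, (str, bytes)):
        for child in value:
            yield from iter_tensors(child)


def collect_floating_point_dtypes_sketch(value: Any) -> set[torch.dtype]:
    """Hypothetical sketch: the set of floating point dtypes among all nested tensors."""
    # Integer, bool and complex tensors are skipped by is_floating_point().
    return {t.dtype for t in iter_tensors(value) if t.is_floating_point()}
```

Separating the walk from the filter means the same generator could also back a device collector, at the cost of one extra function call per nesting level.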

hash_tensor_data

hash_tensor_data(tensor: Tensor) -> int

Compute the hash of the tensor's data represented as a string.

Note:

Since 0.0 and -0.0 have different byte representations, tensors that differ only in the sign of a zero will produce different hash values, even though they compare equal.

Parameters:

tensor (Tensor): The tensor whose data to hash. Required.

Returns:

int: The hash of the tensor's data.
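
To illustrate the note about signed zeros, here is a sketch that hashes a tensor's raw element bytes. The name hash_tensor_data_sketch is hypothetical, and the choice of SHA-256 truncated to 64 bits is an assumption made so the result is stable across processes; the real function may use a different hash.

```python
import hashlib

import torch


def hash_tensor_data_sketch(tensor: torch.Tensor) -> int:
    """Hypothetical sketch: hash the raw bytes of a tensor's elements."""
    # Move to CPU, lay the elements out contiguously, and flatten so view() also works on 0-d tensors.
    flat = tensor.detach().cpu().contiguous().reshape(-1)
    # Reinterpret the element bytes as uint8 (no numeric conversion), then build a bytes object.
    raw = bytes(flat.view(torch.uint8).tolist())
    # Assumption: SHA-256 truncated to 64 bits; the real function may use another hash.
    return int.from_bytes(hashlib.sha256(raw).digest()[:8], "little")
```

With this sketch, torch.tensor([0.0, 1.0]) and torch.tensor([-0.0, 1.0]) compare equal under torch.equal yet hash differently, because the sign bit of -0.0 changes its bytes. Since only bytes are hashed, tensors with identical bytes but different shapes also collide; the documentation does not say whether hash_tensor_data takes shape or dtype into account.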

prepare_for_safe_division

prepare_for_safe_division(
    tensor: Tensor, safety_factor: float = 100.0
) -> torch.Tensor

Prepare a tensor for safe division by adding to it a small value that is relative to the scale of the tensor itself.

Parameters:

tensor (Tensor): The tensor to prepare for safe division. Required.
safety_factor (float): A factor to multiply the small value by. Increasing it adds a larger value to the tensor, which improves numerical stability at the cost of potentially more bias. Default: 100.0.

Returns:

torch.Tensor: The tensor with a small value added to it for safe division.
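
The documentation does not specify which statistic defines "the scale" or how the small value is derived. The sketch below adopts one plausible reading, labeled as an assumption throughout: the offset is safety_factor × the dtype's machine epsilon × the tensor's largest absolute value, floored at torch.finfo(dtype).tiny so that an all-zero tensor still receives a nonzero offset. The name prepare_for_safe_division_sketch is hypothetical.

```python
import torch


def prepare_for_safe_division_sketch(tensor: torch.Tensor, safety_factor: float = 100.0) -> torch.Tensor:
    """Hypothetical sketch: add a scale-relative epsilon so that `x / result` avoids division by zero."""
    finfo = torch.finfo(tensor.dtype)  # raises for non-floating-point dtypes
    # Assumption: the tensor's scale is its largest absolute value (0 for an empty tensor).
    # detach(): the offset is treated as a constant by autograd.
    scale = tensor.detach().abs().amax() if tensor.numel() > 0 else tensor.new_zeros(())
    # Relative epsilon, floored at the smallest positive normal number so all-zero inputs still shift.
    small = torch.clamp(scale * finfo.eps, min=finfo.tiny)
    # A larger safety_factor means more stability but more bias in the quotient.
    return tensor + safety_factor * small
```

For float32 input [0.0, 1.0, 2.0], this sketch adds about 100 × 1.19e-7 × 2 ≈ 2.4e-5 to every element. A positive offset protects exact zeros and non-negative denominators, but a negative entry close to minus the offset could still land on zero; a sign-aware variant would be needed in that case.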