
_mps

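Patched versions of torch.mps functions, taken from https://github.com/pytorch/pytorch/pull/124676, that make the MPS backend compatible with torch.random.fork_rng. When called with device_type="mps", fork_rng looks up the device count with device_count() (if no devices are given). It then saves and restores each device's state with get_rng_state(device) and set_rng_state(new_state, device), so these functions must exist and accept a device argument.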
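A minimal usage sketch follows. It assumes PyTorch 2.1 or later, which provides the device_type argument of torch.random.fork_rng. It also assumes that the patched torch.mps functions are in effect whenever MPS is available. On machines without MPS, the sketch forks only the CPU generator so that it still runs. It checks that random draws made inside the fork do not advance the generator outside it.

```python
import torch

use_mps = torch.backends.mps.is_available()
device = "mps" if use_mps else "cpu"

# With MPS, fork_rng discovers devices via torch.mps.device_count() and
# saves/restores them with torch.mps.get_rng_state / set_rng_state (assumed
# to be the patched versions). Without MPS, devices=[] makes fork_rng save
# and restore only the CPU generator.
fork_kwargs = {"device_type": "mps"} if use_mps else {"devices": []}

torch.manual_seed(0)
with torch.random.fork_rng(**fork_kwargs):
    torch.rand(3, device=device)  # consumes randomness only inside the fork
after_fork = torch.rand(3, device=device)

torch.manual_seed(0)
no_fork = torch.rand(3, device=device)  # the same draw with no fork at all

# The fork restored the generator, so both draws start from seed 0.
print(torch.equal(after_fork, no_fork))
```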

Functions:

Name            Description
device_count    Returns the number of available MPS devices.
get_rng_state   Returns the random number generator state as a ByteTensor.
set_rng_state   Sets the random number generator state.

device_count

device_count() -> int

Returns the number of available MPS devices.
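The sketch below shows how such a count could be computed from public PyTorch API. It is not this module's code. The function name device_count_sketch is hypothetical. The sketch assumes PyTorch exposes at most one MPS device, so the count is either 0 or 1.

```python
import torch


def device_count_sketch() -> int:
    """Hypothetical stand-in for the patched device_count().

    Assumes PyTorch exposes at most one MPS device, so the result is 1 when
    the MPS backend is usable on this machine and 0 otherwise.
    """
    return int(torch.backends.mps.is_available())


n_mps = device_count_sketch()
print(f"available MPS devices: {n_mps}")
```

When fork_rng is called without an explicit devices list, it turns this count into the device indices list(range(count)). A count of 0 therefore makes it skip the MPS generator entirely.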

get_rng_state

get_rng_state(
    device: int | str | device = "mps",
) -> torch.Tensor

Returns the random number generator state as a ByteTensor.

Parameters:

Name    Type                       Default  Description
device  int, str, or torch.device  'mps'    The device whose RNG state is returned. 'mps' means torch.device('mps'), the current MPS device.
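One way a patched get_rng_state could satisfy fork_rng is sketched below, under stated assumptions. The function name get_rng_state_sketch is hypothetical. The sketch assumes there is a single MPS generator, so it accepts the device argument and ignores it, delegating to torch.mps.get_rng_state() called without arguments. Without MPS, the example reads the CPU generator instead so that it runs anywhere. Either way, the state comes back as a uint8 (ByteTensor) tensor.

```python
from typing import Union

import torch


def get_rng_state_sketch(device: Union[int, str, torch.device] = "mps") -> torch.Tensor:
    """Hypothetical illustration of a device-aware get_rng_state.

    The device argument is accepted (torch.random.fork_rng passes one) but is
    not used: this sketch assumes a single MPS generator and reads it through
    torch.mps.get_rng_state() called with no arguments.
    """
    return torch.mps.get_rng_state()


if torch.backends.mps.is_available():
    state = get_rng_state_sketch(0)  # an int index, as fork_rng would pass
else:
    state = torch.get_rng_state()  # CPU generator, so the example runs anywhere

print(state.dtype, state.numel())
```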

set_rng_state

set_rng_state(
    new_state: Tensor, device: int | str | device = "mps"
) -> None

Sets the random number generator state.

Parameters:

Name       Type                       Default   Description
new_state  ByteTensor                 required  The desired RNG state.
device     int, str, or torch.device  'mps'     The device whose RNG state is set. 'mps' means torch.device('mps'), the current MPS device.
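The save/restore round trip that fork_rng relies on is sketched below. The helpers save_state and restore_state are hypothetical. They call torch.mps.get_rng_state and torch.mps.set_rng_state when MPS is available, and they fall back to the CPU generator otherwise. Restoring a saved state rewinds the generator, so the same draw repeats.

```python
import torch

use_mps = torch.backends.mps.is_available()
device = "mps" if use_mps else "cpu"


def save_state() -> torch.Tensor:
    # Read the generator that torch.rand(..., device=device) draws from.
    return torch.mps.get_rng_state() if use_mps else torch.get_rng_state()


def restore_state(state: torch.Tensor) -> None:
    # Write a previously saved ByteTensor back into that generator.
    if use_mps:
        torch.mps.set_rng_state(state)
    else:
        torch.set_rng_state(state)


saved = save_state()
first = torch.rand(4, device=device)
restore_state(saved)  # rewind the generator to where it was before `first`
second = torch.rand(4, device=device)

print(torch.equal(first, second))
```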