distributed

Modules:

    fsdp
    sharding

Classes:

    DistModule
        A module that wraps another module and distributes its state_dict tensors across multiple processes when loading.

Functions:

    apply_fsdp2
        Apply FSDP2 to the model.

    load_state_dict
        Patch accelerate.utils.modeling.load_state_dict to only load the state_dict if the global rank of the current process is 0.

DistModule

Bases: Module

A module that wraps another module and distributes its state_dict tensors across multiple processes when loading.

Parameters:

    module (Module, required)
        The module to wrap and distribute the state_dict tensors for.
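A minimal sketch of the wrapping behaviour described above. The real DistModule additionally distributes state_dict tensors across ranks on load, and its import path depends on the package layout; `WrapSketch` below is an illustrative stand-in, not the actual class.

```python
import torch
import torch.nn as nn


class WrapSketch(nn.Module):
    """Illustrative stand-in for DistModule's delegation behaviour."""

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module  # registered as a submodule

    def __getattr__(self, name: str):
        # nn.Module.__getattr__ resolves registered params/buffers/submodules;
        # anything it cannot find is delegated to the wrapped module.
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(super().__getattr__("module"), name)

    def forward(self, *args, **kwargs):
        # Forward the inputs straight to the wrapped module.
        return self.module(*args, **kwargs)


wrapped = WrapSketch(nn.Linear(4, 2))
print(wrapped.in_features)               # delegated to the wrapped Linear -> 4
print(wrapped(torch.zeros(1, 4)).shape)  # torch.Size([1, 2])
```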

Methods:

    __getattr__
        Delegate attribute access to the wrapped module.

    __setattr__
        Delegate attribute setting to the wrapped module.

    forward
        Forward the inputs to the wrapped module.

    load_state_dict
        Distribute and broadcast state_dict tensors from rank 0 to all other ranks.

    state_dict
        Return the state_dict of the wrapped module.

__getattr__

__getattr__(name: str) -> Any

Delegate attribute access to the wrapped module.

Parameters:

    name (str, required)
        The attribute name to access.

Returns:

    Any
        The attribute of the wrapped module.

__setattr__

__setattr__(name: str, value: Any) -> None

Delegate attribute setting to the wrapped module.

Parameters:

    name (str, required)
        The attribute name to set.

    value (Any, required)
        The value to set the attribute to.

forward

forward(*args: Any, **kwargs: Any) -> Any

Forward the inputs to the wrapped module.

Parameters:

    args (Any, required)
        The positional arguments to pass to the wrapped module.

    kwargs (Any, required)
        The keyword arguments to pass to the wrapped module.

Returns:

    Any
        The output of the wrapped module.

load_state_dict

load_state_dict(
    state_dict: Mapping[str, Any],
    strict: bool = True,
    assign: bool = False,
) -> nn.modules.module._IncompatibleKeys

Distribute and broadcast state_dict tensors from rank 0 to all other ranks.

Parameters:

    state_dict (Mapping[str, Any], required)
        A mapping containing parameters and persistent buffers. To avoid unnecessary reads, writes, and memory allocation, state_dict can and should be an empty dict on all processes except rank 0.

    strict (bool, default=True)
        Whether to strictly enforce that the keys in state_dict match the keys returned by this module's state_dict method.

    assign (bool, default=False)
        When False, the properties of the tensors in the current module are preserved; when True, the properties of the tensors in the state dict are preserved. The only exception is the requires_grad field of nn.Parameters, for which the value from the module is preserved.

Returns:

    nn.modules.module._IncompatibleKeys
        NamedTuple with missing_keys and unexpected_keys fields:
        * missing_keys is a list of str containing any keys that are expected by this module but missing from the provided state_dict.
        * unexpected_keys is a list of str containing the keys that are not expected by this module but present in the provided state_dict.
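The rank-0 contract above can be sketched as follows. `state_dict_for_rank` is a hypothetical helper for illustration, not part of this API:

```python
import torch


def state_dict_for_rank(rank: int, checkpoint: dict) -> dict:
    # Hypothetical helper illustrating the contract: only rank 0 supplies
    # the real tensors; every other rank passes an empty dict and receives
    # the values via the broadcast from rank 0.
    return checkpoint if rank == 0 else {}


ckpt = {"weight": torch.ones(2, 2)}

# Each rank would then call, e.g.:
#   dist_module.load_state_dict(state_dict_for_rank(rank, ckpt))
print(state_dict_for_rank(3, ckpt))  # {} -- no reads or allocation off rank 0
```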

state_dict

state_dict(*args: Any, **kwargs: Any) -> dict[str, Any]

Return the state_dict of the wrapped module.

Parameters:

    args (Any, required)
        The positional arguments to pass to the wrapped module's state_dict method.

    kwargs (Any, required)
        The keyword arguments to pass to the wrapped module's state_dict method.

Returns:

    dict[str, Any]
        The state_dict of the wrapped module.

apply_fsdp2

apply_fsdp2(model: Module, mesh: DeviceMesh) -> None

Apply FSDP2 to the model.

Parameters:

    model (Module, required)
        The model to apply FSDP2 to.

    mesh (DeviceMesh, required)
        The device mesh to shard the model's parameters over.

load_state_dict

load_state_dict(
    checkpoint_file: str,
    device_map: int
    | str
    | device
    | dict[str, int | str | device]
    | None = None,
) -> dict[str, Any]

Patch accelerate.utils.modeling.load_state_dict to only load the state_dict if the global rank of the current process is 0.

Parameters:

    checkpoint_file (str, required)
        The file to load the state_dict from.

    device_map (int | str | device | dict[str, int | str | device] | None, default=None)
        A map that specifies where each submodule should go. It does not need to be refined down to each parameter/buffer name; once a given module name is in the map, every submodule of it will be sent to the same device.

Returns:

    dict[str, Any]
        The loaded state_dict if the global rank is 0, otherwise an empty dict.
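The behaviour of this patch can be sketched as follows; `rank_gated_load` is an illustrative stand-in, not the patched function itself:

```python
import os
import tempfile

import torch


def rank_gated_load(checkpoint_file: str, global_rank: int) -> dict:
    # Stand-in for the patched behaviour: only global rank 0 reads the
    # checkpoint from disk; every other rank returns an empty dict and
    # later receives the tensors via DistModule.load_state_dict's broadcast.
    if global_rank != 0:
        return {}
    return torch.load(checkpoint_file, map_location="cpu")


path = os.path.join(tempfile.mkdtemp(), "ckpt.pt")
torch.save({"w": torch.ones(2)}, path)

print(sorted(rank_gated_load(path, 0)))  # ['w'] -- rank 0 reads the file
print(rank_gated_load(path, 1))          # {}   -- other ranks skip the read
```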