distributed
Modules:
Name | Description |
---|---|
fsdp | |
sharding | |
Classes:
Name | Description |
---|---|
DistModule | A module that wraps another module and distributes its state_dict tensors across multiple processes when loading. |
Functions:
Name | Description |
---|---|
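apply_fsdp2 | Apply FSDP2 to the model. |
load_state_dict | Patch accelerate.utils.modeling.load_state_dict to only load the state_dict if the global rank of the current process is 0. |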
DistModule
Bases: Module
A module that wraps another module and distributes its state_dict tensors across multiple processes when loading.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
 | Module | The module to wrap and distribute the state_dict tensors for. | required |
Methods:
Name | Description |
---|---|
__getattr__ | Delegate attribute access to the wrapped module. |
__setattr__ | Delegate attribute setting to the wrapped module. |
forward | Forward the inputs to the wrapped module. |
load_state_dict | Distribute and broadcast state_dict tensors from rank 0 to all other ranks. |
state_dict | Return the state_dict of the wrapped module. |
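The delegation described by `__getattr__`, `__setattr__`, and `forward` can be illustrated with a minimal plain-Python sketch. This is not the library's implementation: `DelegatingWrapper` and `Inner` are hypothetical names, and the sketch deliberately avoids `torch.nn.Module` so that it runs without dependencies.

```python
class DelegatingWrapper:
    """Hypothetical stand-in for a wrapper that forwards attribute access."""

    def __init__(self, wrapped):
        # Bypass our own __setattr__ so the reference is stored on the wrapper.
        object.__setattr__(self, "_wrapped", wrapped)

    def __getattr__(self, name):
        # Invoked only when normal lookup on the wrapper fails.
        return getattr(object.__getattribute__(self, "_wrapped"), name)

    def __setattr__(self, name, value):
        # Every attribute write lands on the wrapped object instead.
        setattr(object.__getattribute__(self, "_wrapped"), name, value)

    def forward(self, *args, **kwargs):
        # Pass the inputs straight through to the wrapped object.
        return self._wrapped.forward(*args, **kwargs)


class Inner:
    """Hypothetical wrapped object with one attribute and a forward method."""

    def __init__(self):
        self.scale = 2

    def forward(self, x):
        return x * self.scale
```

With this pattern, reading or writing `wrapper.scale` operates on the wrapped object, so callers can treat the wrapper as if it were the original.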
__getattr__
__setattr__
forward
load_state_dict
load_state_dict(
state_dict: Mapping[str, Any],
strict: bool = True,
assign: bool = False,
) -> nn.modules.module._IncompatibleKeys
Distribute and broadcast state_dict tensors from rank 0 to all other ranks.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
state_dict | Mapping[str, Any] | A mapping containing parameters and persistent buffers. To avoid unnecessary reads, writes, and memory allocation, | required |
strict | bool | Whether to strictly enforce that the keys in | True |
assign | bool | When | False |
Returns:
Type | Description |
---|---|
nn.modules.module._IncompatibleKeys | |
state_dict
Return the state_dict of the wrapped module.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
 | Any | The positional arguments to pass to the wrapped module's | required |
 | Any | The keyword arguments to pass to the wrapped module's | required |
Returns:
Type | Description |
---|---|
dict[str, Any] | The state_dict of the wrapped module. |
apply_fsdp2
apply_fsdp2(model: Module, mesh: DeviceMesh) -> None
Apply FSDP2 to the model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | Module | The model to apply FSDP2 to. | required |
mesh | DeviceMesh | The device mesh. | required |
load_state_dict
load_state_dict(
checkpoint_file: str,
device_map: int | str | device | dict[str, int | str | device] | None = None,
) -> dict[str, Any]
Patch accelerate.utils.modeling.load_state_dict to only load the state_dict if the global rank of the current process is 0.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
checkpoint_file | str | The file to load the state_dict from. | required |
device_map | int | str | device | dict[str, int | str | device] | None | A map that specifies where each submodule should go. It doesn't need to be refined to each parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the same device. | None |
Returns:
Type | Description |
---|---|
dict[str, Any] | The loaded state_dict if the global rank is 0, otherwise an empty |
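The rank-0-only behaviour described above can be sketched as follows. This is an illustration under stated assumptions, not the patched function itself: `load_state_dict_rank_zero` is a hypothetical name, the global rank is assumed to come from the `RANK` environment variable (as exported by `torchrun`), a JSON file stands in for a real checkpoint, and the empty `dict` returned on other ranks is assumed from the declared return type.

```python
import json
import os
import tempfile


def load_state_dict_rank_zero(checkpoint_file, rank=None):
    """Return the checkpoint contents on rank 0 and an empty dict elsewhere."""
    if rank is None:
        # Assumption: the launcher exports RANK; default to 0 for single-process runs.
        rank = int(os.environ.get("RANK", "0"))
    if rank != 0:
        # Non-zero ranks skip the file read entirely.
        return {}
    with open(checkpoint_file) as f:
        return json.load(f)


# Demo: write a tiny JSON "checkpoint" and load it as two different ranks.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "checkpoint.json")
    with open(path, "w") as f:
        json.dump({"linear.weight": [1.0, 2.0]}, f)
    on_rank_zero = load_state_dict_rank_zero(path, rank=0)
    on_rank_one = load_state_dict_rank_zero(path, rank=1)
```

Only one process touches the file in this pattern, which is what makes a later broadcast from rank 0 (as in `DistModule.load_state_dict`) necessary for the other ranks to receive the tensors.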