sharding
Classes:
Name | Description |
---|---|
DistModule | A module that wraps another module and distributes its state_dict tensors across multiple processes when loading. |
Functions:
Name | Description |
---|---|
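broadcast_and_shard_state_dict | Broadcast the unsharded tensors in state_dict to all other processes in the group and shard the tensors for rank 0. |
load_state_dict | Patch accelerate.utils.modeling.load_state_dict to only load the state_dict if the global rank of the current process is 0. |
shard_broadcasted_state_dict | Receive the broadcasted, unsharded state_dict tensors from rank 0 and shard them for the current rank. |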
DistModule
Bases: Module
A module that wraps another module and distributes its state_dict tensors across multiple processes when loading.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
 | Module | The module to wrap and distribute the state_dict tensors for. | required |
Methods:
Name | Description |
---|---|
__getattr__ | Delegate attribute access to the wrapped module. |
__setattr__ | Delegate attribute setting to the wrapped module. |
forward | Forward the inputs to the wrapped module. |
load_state_dict | Distribute and broadcast state_dict tensors from rank 0 to all other ranks. |
state_dict | Return the state_dict of the wrapped module. |
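The delegation described in the method table can be illustrated with a minimal, torch-free sketch. `Inner` and `DelegatingWrapper` are hypothetical stand-ins invented for this example; they are not the DistModule implementation, which subclasses Module and may handle more cases.

```python
class Inner:
    """Hypothetical stand-in for the wrapped module."""

    def __init__(self):
        self.scale = 2

    def forward(self, x):
        return x * self.scale


class DelegatingWrapper:
    """Hypothetical wrapper that delegates attribute access to `wrapped`."""

    def __init__(self, wrapped):
        # Bypass our own __setattr__ so the reference is stored on the wrapper.
        object.__setattr__(self, "wrapped", wrapped)

    def __getattr__(self, name):
        # Only called when normal lookup on the wrapper itself fails.
        return getattr(object.__getattribute__(self, "wrapped"), name)

    def __setattr__(self, name, value):
        # Every assignment lands on the wrapped object, not on the wrapper.
        setattr(self.wrapped, name, value)

    def forward(self, *args, **kwargs):
        return self.wrapped.forward(*args, **kwargs)


inner = Inner()
wrapper = DelegatingWrapper(inner)
wrapper.scale = 3            # stored on `inner`
result = wrapper.forward(5)  # computed by `inner.forward`
```

With this pattern, code that reads or sets attributes on the wrapper behaves as if it were talking to the wrapped object directly.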
__getattr__
__setattr__
forward
load_state_dict
load_state_dict(
    state_dict: Mapping[str, Any],
    strict: bool = True,
    assign: bool = False,
) -> nn.modules.module._IncompatibleKeys
Distribute and broadcast state_dict tensors from rank 0 to all other ranks.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
state_dict | Mapping[str, Any] | A mapping containing parameters and persistent buffers. To avoid unnecessary reads, writes, and memory allocation, … | required |
strict | bool | Whether to strictly enforce that the keys in … | True |
assign | bool | When … | False |
Returns:
Type | Description |
---|---|
nn.modules.module._IncompatibleKeys | |
state_dict
Return the state_dict of the wrapped module.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
 | Any | The positional arguments to pass to the wrapped module's … | required |
 | Any | The keyword arguments to pass to the wrapped module's … | required |
Returns:
Type | Description |
---|---|
dict[str, Any] | The state_dict of the wrapped module. |
broadcast_and_shard_state_dict
broadcast_and_shard_state_dict(
    state_dict: Mapping[str, Any],
    meta_sharded_state_dict: Mapping[str, Any],
    sharded_state_dict: dict[str, Any],
) -> None
Broadcast the unsharded tensors in state_dict to all other processes in the group and shard the tensors for rank 0.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
state_dict | Mapping[str, Any] | The loaded state_dict containing the unsharded tensors to broadcast. | required |
meta_sharded_state_dict | Mapping[str, Any] | The state_dict containing the sharded meta tensors; used to determine how to distribute the tensors. | required |
sharded_state_dict | dict[str, Any] | The state_dict to store the sharded, materialized tensors in. This state_dict can be loaded by an … | required |
Raises:
Type | Description |
---|---|
ValueError | If called by any process other than rank 0. |
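To make the role of meta_sharded_state_dict more concrete, here is a small, torch-free sketch of how a sharded "shape" could drive slicing of an unsharded value. Everything in it is an assumption made for illustration: tensors are modeled as flat Python lists, the meta entry is reduced to a shard length, shards are taken as even contiguous slices, and `shard_for_rank` is a hypothetical helper. The actual function may shard differently.

```python
from typing import Mapping


def shard_for_rank(
    state_dict: Mapping[str, list],
    meta_shard_lengths: Mapping[str, int],
    rank: int,
    sharded_state_dict: dict,
) -> None:
    """Hypothetical helper: fill `sharded_state_dict` with this rank's slices."""
    for name, values in state_dict.items():
        shard_len = meta_shard_lengths[name]
        if shard_len == len(values):
            # Same length as the full value: treat the entry as replicated.
            sharded_state_dict[name] = list(values)
        else:
            # Assumption: even, contiguous shards ordered by rank.
            start = rank * shard_len
            sharded_state_dict[name] = list(values[start:start + shard_len])


full = {"weight": list(range(8)), "bias": [1, 2]}
meta = {"weight": 4, "bias": 2}  # stand-in for sharded meta tensor shapes
rank_one_shard = {}
shard_for_rank(full, meta, rank=1, sharded_state_dict=rank_one_shard)
```

As in the documented signature, the output dictionary is passed in and filled in place rather than returned.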
load_state_dict
load_state_dict(
    checkpoint_file: str,
    device_map: int | str | device | dict[str, int | str | device] | None = None,
) -> dict[str, Any]
Patch accelerate.utils.modeling.load_state_dict to only load the state_dict if the global rank of the current process is 0.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
checkpoint_file | str | The file to load the state_dict from. | required |
device_map | int \| str \| device \| dict[str, int \| str \| device] \| None | A map that specifies where each submodule should go. It does not need to list every parameter or buffer name: once a module name is in the map, every submodule of that module is sent to the same device. | None |
Returns:
Type | Description |
---|---|
dict[str, Any] | The loaded state_dict if the global rank is 0, otherwise an empty … |
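The rank-gated loading described here can be sketched without any distributed backend. This is an illustration under stated assumptions, not the patched function itself: `load_on_rank_zero` and `fake_loader` are hypothetical names, the global rank is read from the `RANK` environment variable (which launchers such as torchrun set) unless passed explicitly, and non-zero ranks are assumed to get an empty dict.

```python
import os
from typing import Any, Callable, Optional


def load_on_rank_zero(
    checkpoint_file: str,
    loader: Callable[[str], dict],
    rank: Optional[int] = None,
) -> dict:
    """Hypothetical helper: call `loader` only on global rank 0."""
    if rank is None:
        # Assumption: the launcher exports the global rank as RANK.
        rank = int(os.environ.get("RANK", "0"))
    if rank != 0:
        # Assumption: other ranks skip the read and get an empty dict.
        return {}
    return loader(checkpoint_file)


calls = []


def fake_loader(path: str) -> dict:
    """Stand-in for a real checkpoint reader; records each call."""
    calls.append(path)
    return {"weight": [1.0, 2.0]}


on_rank_zero = load_on_rank_zero("model.bin", fake_loader, rank=0)
on_rank_one = load_on_rank_zero("model.bin", fake_loader, rank=1)
```

Only one read is recorded in `calls`, which is the point of gating: the checkpoint is touched once, and the other ranks wait for the broadcast instead.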
shard_broadcasted_state_dict
shard_broadcasted_state_dict(
    meta_sharded_state_dict: Mapping[str, Any],
    sharded_state_dict: dict[str, Any],
) -> None
Receive the broadcasted, unsharded state_dict tensors from rank 0 and shard them for the current rank.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
meta_sharded_state_dict | Mapping[str, Any] | The state_dict containing the sharded meta tensors; used to determine how to distribute the tensors. | required |
sharded_state_dict | dict[str, Any] | The state_dict to store the sharded, materialized tensors in. This state_dict can be loaded by an … | required |
Raises:
Type | Description |
---|---|
ValueError | If called by the rank 0 process. |
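The two Raises tables imply a strict split of roles: one function for rank 0, the other for every other rank. The sketch below mirrors only that guard logic with hypothetical stand-ins (`rank_zero_step`, `other_rank_step`, `dispatch`); it performs no communication and is not how the library necessarily wires the calls together.

```python
def rank_zero_step(rank: int) -> str:
    """Hypothetical stand-in for the rank-0 side (broadcast, then shard)."""
    if rank != 0:
        raise ValueError(f"rank-0 step called on rank {rank}")
    return "broadcast+shard"


def other_rank_step(rank: int) -> str:
    """Hypothetical stand-in for the receiving side (receive, then shard)."""
    if rank == 0:
        raise ValueError("receiving step called on rank 0")
    return "receive+shard"


def dispatch(rank: int) -> str:
    """Pick the step that matches this process's global rank."""
    return rank_zero_step(rank) if rank == 0 else other_rank_step(rank)


roles = [dispatch(rank) for rank in range(3)]
```

Calling either step from the wrong rank raises ValueError immediately, which surfaces a mismatched call early instead of leaving processes waiting on a broadcast that never arrives.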