sharding
Classes:
| Name | Description | 
|---|---|
| DistModule | A module that wraps another module and distributes its state_dict tensors across multiple processes when loading. | 
Functions:
| Name | Description | 
|---|---|
| broadcast_and_shard_state_dict | Broadcast the unsharded tensors in state_dict to all other processes in the group and shard the tensors for rank 0. | 
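| load_state_dict | Patch accelerate.utils.modeling.load_state_dict to only load the state_dict if the global rank of the current process is 0. | 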
| shard_broadcasted_state_dict | Receive the broadcasted, unsharded state_dict tensors from rank 0 and shard them for the current rank. | 
    
DistModule
Bases: Module
A module that wraps another module and distributes its state_dict tensors across multiple processes when loading.
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
|                    | Module | The module to wrap and distribute the state_dict tensors for. | required | 
Methods:
| Name | Description | 
|---|---|
| __getattr__ | Delegate attribute access to the wrapped module. | 
| __setattr__ | Delegate attribute setting to the wrapped module. | 
| forward | Forward the inputs to the wrapped module. | 
| load_state_dict | Distribute and broadcast state_dict tensors from rank 0 to all other ranks. | 
| state_dict | Return the state_dict of the wrapped module. | 
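The delegation described for `__getattr__` and `__setattr__` can be illustrated with a minimal, torch-free sketch. `DelegatingWrapper` and `Scale` are hypothetical stand-ins invented for this example; they are not DistModule's actual implementation, which is not shown in this reference.

```python
class DelegatingWrapper:
    """Hypothetical stand-in for a wrapper such as DistModule (no torch involved)."""

    def __init__(self, wrapped):
        # Bypass our own __setattr__ so the reference is stored on the wrapper itself.
        object.__setattr__(self, "_wrapped", wrapped)

    def __getattr__(self, name):
        # Only reached when normal lookup on the wrapper fails.
        return getattr(object.__getattribute__(self, "_wrapped"), name)

    def __setattr__(self, name, value):
        # Every attribute write lands on the wrapped object.
        setattr(object.__getattribute__(self, "_wrapped"), name, value)

    def __call__(self, *args, **kwargs):
        # Analogous to forward(): pass the inputs straight through.
        return self._wrapped(*args, **kwargs)


class Scale:
    """Hypothetical wrapped object that multiplies its input by a factor."""

    def __init__(self, factor):
        self.factor = factor

    def __call__(self, x):
        return self.factor * x


inner = Scale(2)
wrapper = DelegatingWrapper(inner)
wrapper.factor = 3      # written to `inner`, not to the wrapper
result = wrapper(5)     # computed by `inner` with the updated factor
```

With this pattern, code that holds the wrapper can keep reading and writing attributes as if it held the wrapped object directly.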
    
    
    
load_state_dict(
    state_dict: Mapping[str, Any],
    strict: bool = True,
    assign: bool = False,
) -> nn.modules.module._IncompatibleKeys
Distribute and broadcast state_dict tensors from rank 0 to all other ranks.
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
| state_dict | Mapping[str, Any] | A mapping containing parameters and persistent buffers. To avoid unnecessary reads, writes, and memory allocation, | required | 
| strict | bool | Whether to strictly enforce that the keys in  | True | 
| assign | bool | When  | False | 
Returns:
| Type | Description | 
|---|---|
| nn.modules.module._IncompatibleKeys | A named tuple with missing_keys and unexpected_keys fields. | 
state_dict
Return the state_dict of the wrapped module.
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
|                    | Any | The positional arguments to pass to the wrapped module's  | required | 
|                    | Any | The keyword arguments to pass to the wrapped module's  | required | 
Returns:
| Type | Description | 
|---|---|
| dict[str, Any] | The state_dict of the wrapped module. | 
broadcast_and_shard_state_dict(
    state_dict: Mapping[str, Any],
    meta_sharded_state_dict: Mapping[str, Any],
    sharded_state_dict: dict[str, Any],
) -> None
Broadcast the unsharded tensors in state_dict to all other processes in the group and shard the tensors for rank 0.
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
| state_dict | Mapping[str, Any] | The loaded state_dict containing the unsharded tensors to broadcast. | required | 
| meta_sharded_state_dict | Mapping[str, Any] | The state_dict containing the sharded meta tensors; used to determine how to distribute the tensors. | required | 
| sharded_state_dict | dict[str, Any] | The state_dict to store the sharded, materialized tensors in. This state_dict can be loaded by an  | required | 
Raises:
| Type | Description | 
|---|---|
| ValueError | If called by any process other than rank 0. | 
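As a rough illustration of the broadcast-then-shard idea, the sketch below simulates several ranks in a single process using plain Python lists. Everything here is an assumption made for illustration: `shard_rows` and `simulate_broadcast_and_shard` are hypothetical helpers, the even split along the first dimension is an assumed layout, and no real `torch.distributed` communication takes place. The actual function uses `meta_sharded_state_dict` to decide how each tensor is distributed.

```python
from __future__ import annotations


def shard_rows(full: list[float], rank: int, world_size: int) -> list[float]:
    """Return the contiguous chunk of `full` owned by `rank` (assumed even split)."""
    if len(full) % world_size != 0:
        raise ValueError("length must be divisible by world_size in this sketch")
    chunk = len(full) // world_size
    return full[rank * chunk:(rank + 1) * chunk]


def simulate_broadcast_and_shard(
    state_dict: dict[str, list[float]], world_size: int
) -> list[dict[str, list[float]]]:
    """Give every simulated rank the full values, then keep only its own shard."""
    per_rank = []
    for rank in range(world_size):
        sharded = {}
        for key, full in state_dict.items():
            # "Broadcast": each rank sees the same unsharded values from rank 0.
            sharded[key] = shard_rows(full, rank, world_size)
        per_rank.append(sharded)
    return per_rank


full_state = {"weight": [0.0, 1.0, 2.0, 3.0], "bias": [10.0, 20.0]}
shards = simulate_broadcast_and_shard(full_state, world_size=2)
```

Each simulated rank ends up holding only its slice, which mirrors the role of `sharded_state_dict` as the container for the materialized shards.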
load_state_dict(
    checkpoint_file: str,
    device_map: int
    | str
    | device
    | dict[str, int | str | device]
    | None = None,
) -> dict[str, Any]
Patch accelerate.utils.modeling.load_state_dict to only load the state_dict if the global rank of the current process is 0.
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
| checkpoint_file | str | The file to load the state_dict from. | required | 
| device_map | int \| str \| device \| dict[str, int \| str \| device] \| None | A map that specifies where each submodule should go. It does not need to be refined to each parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the same device. | None | 
Returns:
| Type | Description | 
|---|---|
| dict[str, Any] | The loaded state_dict if the global rank is 0, otherwise an empty dict. | 
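The rank-0-only loading behaviour can be sketched as a small wrapper around a loader function. This is a simplified illustration under stated assumptions: `rank_zero_only` and `fake_loader` are hypothetical names, the rank is passed in explicitly instead of being queried from a process group, and the real accelerate loader is not called.

```python
from __future__ import annotations

from typing import Any, Callable


def rank_zero_only(
    loader: Callable[[str], dict[str, Any]], rank: int
) -> Callable[[str], dict[str, Any]]:
    """Wrap `loader` so that only rank 0 reads the checkpoint."""

    def patched(checkpoint_file: str) -> dict[str, Any]:
        if rank != 0:
            # Other ranks skip the read entirely and get an empty dict.
            return {}
        return loader(checkpoint_file)

    return patched


calls: list[str] = []


def fake_loader(checkpoint_file: str) -> dict[str, Any]:
    """Hypothetical loader that records each file it is asked to read."""
    calls.append(checkpoint_file)
    return {"weight": [1.0, 2.0]}


on_rank_0 = rank_zero_only(fake_loader, rank=0)("model.bin")
on_rank_1 = rank_zero_only(fake_loader, rank=1)("model.bin")
```

Only the rank 0 call reaches the underlying loader, so the checkpoint is read once no matter how many processes take part.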
shard_broadcasted_state_dict(
    meta_sharded_state_dict: Mapping[str, Any],
    sharded_state_dict: dict[str, Any],
) -> None
Receive the broadcasted, unsharded state_dict tensors from rank 0 and shard them for the current rank.
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
| meta_sharded_state_dict | Mapping[str, Any] | The state_dict containing the sharded meta tensors; used to determine how to distribute the tensors. | required | 
| sharded_state_dict | dict[str, Any] | The state_dict to store the sharded, materialized tensors in. This state_dict can be loaded by an  | required | 
Raises:
| Type | Description | 
|---|---|
| ValueError | If called by the rank 0 process. |
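The two Raises tables imply a complementary contract: rank 0 calls broadcast_and_shard_state_dict, and every other rank calls shard_broadcasted_state_dict. The sketch below captures only that dispatch rule. `choose_sharding_call` and `check_caller` are hypothetical helpers written for this illustration, and the rank is passed in explicitly; they are not part of the module.

```python
def choose_sharding_call(rank: int) -> str:
    """Return the name of the function a given rank is expected to call."""
    if rank == 0:
        return "broadcast_and_shard_state_dict"
    return "shard_broadcasted_state_dict"


def check_caller(function_name: str, rank: int) -> None:
    """Mirror the documented ValueError conditions for the two functions."""
    if function_name == "broadcast_and_shard_state_dict" and rank != 0:
        raise ValueError("must be called by rank 0")
    if function_name == "shard_broadcasted_state_dict" and rank == 0:
        raise ValueError("must not be called by rank 0")


# Every rank that follows the dispatch rule passes the check.
for rank in range(4):
    check_caller(choose_sharding_call(rank), rank)
```

Calling the wrong function for a rank raises ValueError in this sketch, matching the documented behaviour of both functions.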