output
ModelOutput (dataclass)
Bases: dict
Base class for all model outputs as a dataclass. Has a __getitem__ that allows indexing by integer or slice (like a tuple) or by string (like a dictionary), ignoring attributes that are None. Otherwise behaves like a regular Python dictionary.
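To make the indexing rules above concrete, here is a minimal pure-Python sketch, assuming only the standard library. SketchOutput and ClassifierOutput are hypothetical names used for illustration; this is a simplified imitation of the described behavior, not the library's implementation. Non-None dataclass fields are mirrored into the dict, string keys use dictionary lookup, and integers or slices index the tuple of non-None values.

```python
from dataclasses import dataclass, fields
from typing import Any, Optional


class SketchOutput(dict):
    """Hypothetical stand-in for ModelOutput (illustration only)."""

    def __post_init__(self) -> None:
        # Mirror only the non-None dataclass fields into the dict, in field order.
        for f in fields(self):
            value = getattr(self, f.name)
            if value is not None:
                self[f.name] = value

    def __getitem__(self, key: Any) -> Any:
        if isinstance(key, str):
            # String keys behave like a normal dictionary lookup.
            return dict.__getitem__(self, key)
        # Integers and slices index the tuple of non-None values.
        return self.to_tuple()[key]

    def to_tuple(self) -> tuple[Any, ...]:
        return tuple(dict.__getitem__(self, k) for k in self.keys())


@dataclass
class ClassifierOutput(SketchOutput):
    loss: Optional[float] = None
    logits: Optional[list[float]] = None
    hidden_states: Optional[list[float]] = None


out = ClassifierOutput(logits=[0.1, 0.9])
by_name = out["logits"]   # dictionary-style access
by_index = out[0]         # loss is None, so index 0 is the logits
by_slice = out[:1]        # slices return a tuple
```

Because loss is left as None, it never becomes a key, so positional indices count only the fields that were actually set.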
Note
You can't unpack a ModelOutput directly. Use the to_tuple method to convert it to a tuple first.
__init_subclass__
Register subclasses as pytree nodes.
This is necessary to synchronize gradients when using torch.nn.parallel.DistributedDataParallel(static_graph=True) with modules that output ModelOutput subclasses.
See: https://github.com/pytorch/pytorch/issues/106690.
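The idea of registering every subclass automatically can be sketched without PyTorch. The sketch below assumes a hypothetical module-level PYTREE_REGISTRY dict and hypothetical flatten/make_unflatten helpers standing in for a real pytree registry; it does not reproduce PyTorch's pytree API. It only shows that __init_subclass__ runs once per subclass definition, so each subclass receives a flatten function (split an output into leaf values plus key names) and an unflatten function (rebuild the same subclass).

```python
from typing import Any, Callable

# Hypothetical registry standing in for a real pytree registry (not PyTorch's API).
PYTREE_REGISTRY: dict[type, tuple[Callable, Callable]] = {}


def flatten(obj: dict) -> tuple[list[Any], list[str]]:
    # Split a dict-like output into leaf values and a context of key names.
    keys = list(obj.keys())
    return [obj[k] for k in keys], keys


def make_unflatten(cls: type) -> Callable[[list[Any], list[str]], Any]:
    def unflatten(values: list[Any], keys: list[str]) -> Any:
        # Rebuild an instance of the same subclass from leaves and context.
        return cls(**dict(zip(keys, values)))
    return unflatten


class OutputBase(dict):
    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        # Called once for each subclass definition (not for OutputBase itself),
        # so no per-class decorator or manual registration is needed.
        PYTREE_REGISTRY[cls] = (flatten, make_unflatten(cls))


class EncoderOutput(OutputBase):
    pass


out = EncoderOutput(last_hidden_state=[1.0], pooler_output=[2.0])
flatten_fn, unflatten_fn = PYTREE_REGISTRY[EncoderOutput]
leaves, context = flatten_fn(out)
rebuilt = unflatten_fn(leaves, context)
```

Roughly speaking, a registered flatten/unflatten pair lets a pytree-aware tool reach the values inside an output and reconstruct the original subclass, instead of treating the whole object as a single opaque leaf.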
to_tuple
Convert self to a tuple containing all the attributes/keys that are not None.
Returns:
| Type | Description |
|---|---|
| tuple[Any, ...] | A tuple of all attributes/keys that are not None. |
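A minimal sketch of to_tuple, assuming a simplified dict subclass in which only non-None dataclass fields are stored as keys; TupleableOutput and SeqOutput are hypothetical names, not the library's classes. It also illustrates the Note on unpacking: in this sketch, unpacking the dict directly yields its keys, while unpacking the result of to_tuple yields the values.

```python
from dataclasses import dataclass, fields
from typing import Any, Optional


class TupleableOutput(dict):
    """Hypothetical minimal stand-in illustrating to_tuple (not library code)."""

    def __post_init__(self) -> None:
        # Only non-None fields become dict keys, in declaration order.
        for f in fields(self):
            value = getattr(self, f.name)
            if value is not None:
                self[f.name] = value

    def to_tuple(self) -> tuple[Any, ...]:
        # Values of the stored (non-None) keys, in insertion order.
        return tuple(self[k] for k in self.keys())


@dataclass
class SeqOutput(TupleableOutput):
    loss: Optional[float] = None
    logits: Optional[list[float]] = None
    attentions: Optional[list[float]] = None


out = SeqOutput(logits=[0.2, 0.8], attentions=[1.0])

# Unpacking the dict directly iterates over its keys, not its values.
first, second = out
# Unpacking the tuple yields the non-None values; loss is skipped.
logits, attentions = out.to_tuple()
```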