sg_model
SGModel
Base class for all stained glass models.
__init__
forward
Delegate calls to the base model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
args | Any | Inputs to the base model. | required |
kwargs | Dict[str, Any] | Keyword arguments to the base model. | required |
Returns:
Type | Description |
---|---|
SGModelOutput[Any] | The result of the underlying model, with noise added to the output of the base model's target layer. |
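The delegation described for forward can be illustrated with a minimal, self-contained sketch. Everything here is hypothetical: ToyBaseModel, ToyNoisyModel, the list-of-layers representation, the target_layer index, and Gaussian noise with a noise_std parameter are assumptions chosen for illustration, and the sketch returns a plain float rather than an SGModelOutput. It only shows the idea of forwarding all arguments to the base model while perturbing one intermediate output; it is not the library's implementation.

```python
import random
from typing import Any, Callable


class ToyBaseModel:
    """Hypothetical base model: applies its layers to the input in order."""

    def __init__(self, *layers: Callable[[float], float]) -> None:
        self.layers = list(layers)

    def __call__(self, x: float) -> float:
        for layer in self.layers:
            x = layer(x)
        return x


class ToyNoisyModel:
    """Hypothetical stand-in for SGModel: delegates to the base model and adds
    Gaussian noise to the output of one target layer."""

    def __init__(self, base_model: ToyBaseModel, target_layer: int,
                 noise_std: float, seed: int = 0) -> None:
        self.base_model = base_model
        self.target_layer = target_layer
        self.noise_std = noise_std
        self.rng = random.Random(seed)

    def forward(self, *args: Any, **kwargs: Any) -> float:
        original = self.base_model.layers[self.target_layer]

        def noisy_layer(x: float) -> float:
            # Perturb the target layer's output with zero-mean Gaussian noise.
            return original(x) + self.rng.gauss(0.0, self.noise_std)

        # Temporarily swap in the noisy layer, delegate the call, then restore it.
        self.base_model.layers[self.target_layer] = noisy_layer
        try:
            return self.base_model(*args, **kwargs)
        finally:
            self.base_model.layers[self.target_layer] = original


base = ToyBaseModel(lambda x: 2 * x, lambda x: x + 1)
noisy = ToyNoisyModel(base, target_layer=0, noise_std=0.1, seed=42)
print(base(3.0))           # 7.0
print(noisy.forward(3.0))  # close to 7.0, perturbed by the noise
```

Restoring the original layer in a finally block keeps the base model unchanged even if the delegated call raises.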
SGModelOutput
dataclass
Bases: ModelOutput, Generic[T]
The output of SGModel.forward().
__init_subclass__
Register subclasses as pytree nodes.
This is necessary to synchronize gradients when using torch.nn.parallel.DistributedDataParallel(static_graph=True) with modules that output ModelOutput subclasses.
See: https://github.com/pytorch/pytorch/issues/106690.
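The registration pattern can be sketched with Python's __init_subclass__ hook, which runs once for every subclass at class-creation time. To keep the example runnable without PyTorch, PYTREE_REGISTRY is a hypothetical dict standing in for PyTorch's own pytree registry, and OutputBase, ToyOutput and the flatten/unflatten helpers are illustrative names, not this library's code.

```python
from dataclasses import dataclass, fields
from typing import Any, Callable, Dict, List, Tuple

# Hypothetical stand-in for a pytree registry: maps a type to its
# (flatten, unflatten) functions.
PYTREE_REGISTRY: Dict[type, Tuple[Callable[..., Any], Callable[..., Any]]] = {}


class OutputBase:
    """Hypothetical base class whose dataclass subclasses register themselves."""

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)

        def flatten(obj: Any) -> Tuple[List[Any], List[str]]:
            # Leaves are the field values; the "spec" is the list of field names.
            names = [f.name for f in fields(obj)]
            return [getattr(obj, n) for n in names], names

        def unflatten(values: List[Any], names: List[str]) -> Any:
            # Rebuild an instance of the same subclass from leaves and spec.
            return cls(**dict(zip(names, values)))

        PYTREE_REGISTRY[cls] = (flatten, unflatten)


@dataclass
class ToyOutput(OutputBase):
    logits: Any = None
    loss: Any = None


flatten, unflatten = PYTREE_REGISTRY[ToyOutput]
leaves, spec = flatten(ToyOutput(logits=[1.0], loss=0.5))
print(leaves, spec)             # [[1.0], 0.5] ['logits', 'loss']
print(unflatten(leaves, spec))  # ToyOutput(logits=[1.0], loss=0.5)
```

The hook fires before @dataclass decorates the subclass, which is why flatten calls fields() lazily on the instance rather than at registration time.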
to_tuple
Convert self to a tuple containing all the attributes/keys that are not None.
Returns:
Type | Description |
---|---|
tuple[Any, ...] | A tuple of all attributes/keys that are not None. |
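A minimal sketch of the to_tuple behaviour on a simplified dataclass. ToySGOutput and its fields base_output and noise_loss are hypothetical names chosen for illustration, not the actual fields of SGModelOutput; the point is only that fields set to None are skipped while declaration order is kept.

```python
from dataclasses import dataclass, fields
from typing import Any, Generic, Optional, Tuple, TypeVar

T = TypeVar("T")


@dataclass
class ToySGOutput(Generic[T]):
    """Hypothetical, simplified stand-in for SGModelOutput."""

    base_output: Optional[T] = None
    noise_loss: Optional[float] = None  # hypothetical field

    def to_tuple(self) -> Tuple[Any, ...]:
        # Keep only the fields whose value is not None, in declaration order.
        return tuple(
            getattr(self, f.name)
            for f in fields(self)
            if getattr(self, f.name) is not None
        )


print(ToySGOutput(base_output=[0.1, 0.9]).to_tuple())  # ([0.1, 0.9],)
```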
losses_from_criterion_output
losses_from_criterion_output(criterion_output: Tensor | dict[str, Tensor]) -> dict[str, torch.Tensor]
Format the output of a loss function into a dict containing the key 'model_loss'.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
criterion_output | Tensor | dict[str, Tensor] | The output of a loss function. | required |
Raises:
Type | Description |
---|---|
KeyError | If the loss function output is a dict that does not contain the key 'model_loss'. |
Returns:
Type | Description |
---|---|
dict[str, torch.Tensor] | A dict containing the key 'model_loss'. |
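The normalisation described above can be sketched as follows. This is an illustrative reimplementation under stated assumptions, not the library's source: TensorLike is a stand-in type variable for torch.Tensor so the example runs without PyTorch, the function is renamed toy_losses_from_criterion_output to avoid implying it is the real one, and returning a shallow copy of a dict input is a choice made here.

```python
from typing import Dict, TypeVar, Union

# Stand-in for torch.Tensor so the sketch runs without PyTorch installed.
TensorLike = TypeVar("TensorLike")


def toy_losses_from_criterion_output(
    criterion_output: Union[TensorLike, Dict[str, TensorLike]],
) -> Dict[str, TensorLike]:
    """Normalise a loss function's output into a dict with a 'model_loss' key."""
    if isinstance(criterion_output, dict):
        # A dict is accepted only if it already names the main loss.
        if "model_loss" not in criterion_output:
            raise KeyError("criterion output dict must contain the key 'model_loss'")
        return dict(criterion_output)  # shallow copy so callers can add keys safely
    # A bare loss value becomes the 'model_loss' entry.
    return {"model_loss": criterion_output}


print(toy_losses_from_criterion_output(0.25))  # {'model_loss': 0.25}
print(toy_losses_from_criterion_output({"model_loss": 0.25, "aux": 0.1}))  # {'model_loss': 0.25, 'aux': 0.1}
```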
sg_loss_wrapper
sg_loss_wrapper(criterion: Callable[Concatenate[T, CriterionP], Tensor | dict[str, Tensor]]) -> Callable[Concatenate[SGModelOutput[T], CriterionP], dict[str, torch.Tensor]]
Wrap a loss function to accept an SGModelOutput as its first argument.
Note
criterion must return either a torch.Tensor or a dict of torch.Tensor values; a returned dict must include the key 'model_loss'.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
criterion | Callable[Concatenate[T, CriterionP], Tensor | dict[str, Tensor]] | The loss function to wrap. | required |
Returns:
Type | Description |
---|---|
Callable[Concatenate[SGModelOutput[T], CriterionP], dict[str, torch.Tensor]] | A function that accepts an SGModelOutput as its first argument. |