
sg_model

SGModel

Bases: Module, Generic[M]

Base class for all stained glass models.

input_shape property

input_shape: tuple[int, ...]

The expected shape of the input to the base model.

__init__

__init__(base_model: M, input_shape: tuple[int, ...]) -> None

Initialize the model.

Parameters:

Name          Type             Description                                          Default
base_model    M                The base model.                                      required
input_shape   tuple[int, ...]  The expected shape of the input to the base model.   required

forward

forward(*args: Any, **kwargs: Any) -> SGModelOutput[Any]

Delegate calls to the base model.

Parameters:

Name    Type            Description                            Default
args    Any             Inputs to the base model.              required
kwargs  dict[str, Any]  Keyword arguments to the base model.   required

Returns:

Type                Description
SGModelOutput[Any]  The result of the base model, with noise added to the output of the base model's target layer.
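The constructor, the read-only input_shape property, and the delegating forward can be illustrated without PyTorch. This is a minimal sketch under stated assumptions, not the library's implementation: ToyModel and ToyModelOutput are hypothetical names, a plain callable stands in for the base Module, and the noise that SGModel adds at the base model's target layer is omitted entirely.

```python
from dataclasses import dataclass
from typing import Any, Callable, Generic, TypeVar

M = TypeVar("M", bound=Callable[..., Any])
T = TypeVar("T")


@dataclass
class ToyModelOutput(Generic[T]):
    # Hypothetical output container; it only holds the base model's result.
    base_model_output: T


class ToyModel(Generic[M]):
    """Hypothetical stand-in for the SGModel wrapper pattern."""

    def __init__(self, base_model: M, input_shape: tuple[int, ...]) -> None:
        self.base_model = base_model
        self._input_shape = input_shape

    @property
    def input_shape(self) -> tuple[int, ...]:
        # Read-only view of the expected input shape.
        return self._input_shape

    def forward(self, *args: Any, **kwargs: Any) -> ToyModelOutput[Any]:
        # Delegate all positional and keyword arguments to the base model.
        # The real SGModel also injects noise at a target layer; omitted here.
        return ToyModelOutput(base_model_output=self.base_model(*args, **kwargs))


model = ToyModel(lambda x, scale=1.0: [v * scale for v in x], input_shape=(3,))
out = model.forward([1.0, 2.0, 3.0], scale=2.0)
print(model.input_shape, out.base_model_output)  # (3,) [2.0, 4.0, 6.0]
```

Because forward accepts *args and **kwargs, the wrapper does not need to know the base model's call signature; it simply forwards whatever it receives.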

SGModelOutput dataclass

Bases: ModelOutput, Generic[T]

The output of SGModel.forward().

__init_subclass__

__init_subclass__() -> None

Register subclasses as pytree nodes.

This is necessary to synchronize gradients when using torch.nn.parallel.DistributedDataParallel(static_graph=True) with modules that output ModelOutput subclasses.

See: https://github.com/pytorch/pytorch/issues/106690.
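The registration mechanism relies on Python's __init_subclass__ hook, which runs once each time a subclass is defined. The sketch below shows only that hook; PYTREE_NODE_REGISTRY and the toy classes are hypothetical, and a plain dict stands in for PyTorch's pytree registry, since the exact registration call the library makes is not documented here.

```python
from dataclasses import dataclass
from typing import Any

# Hypothetical registry standing in for PyTorch's pytree node registry.
PYTREE_NODE_REGISTRY: dict[type, str] = {}


class ToyModelOutput:
    """Hypothetical base class whose subclasses register themselves on definition."""

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        # Runs at class-creation time for every subclass (not for the base).
        # A real implementation would register flatten/unflatten functions here.
        PYTREE_NODE_REGISTRY[cls] = cls.__qualname__


@dataclass
class ClassifierOutput(ToyModelOutput):
    logits: Any = None


@dataclass
class GenerationOutput(ToyModelOutput):
    sequences: Any = None


print(sorted(PYTREE_NODE_REGISTRY.values()))  # ['ClassifierOutput', 'GenerationOutput']
```

The advantage of this design is that users never call a registration function: defining a subclass is enough.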

to_tuple

to_tuple() -> tuple[Any, ...]

Convert self to a tuple containing all the attributes/keys that are not None.

Returns:

Type             Description
tuple[Any, ...]  A tuple of all attributes/keys that are not None.
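The None-filtering behavior of to_tuple can be sketched on a plain dataclass. The field names below are hypothetical, since the fields of SGModelOutput are not listed on this page; the point is only that field order is preserved and None values are dropped.

```python
from dataclasses import dataclass, fields
from typing import Any, Optional


@dataclass
class ToyOutput:
    # Hypothetical fields chosen for illustration.
    base_model_output: Any = None
    hidden_states: Optional[list[float]] = None
    loss: Optional[float] = None

    def to_tuple(self) -> tuple[Any, ...]:
        # Keep declaration order, skipping any attribute whose value is None.
        return tuple(
            getattr(self, f.name)
            for f in fields(self)
            if getattr(self, f.name) is not None
        )


out = ToyOutput(base_model_output="logits", loss=0.5)
print(out.to_tuple())  # ('logits', 0.5)
```

Note that the position of a value in the tuple depends on which other fields are set, so unpacking by index is only safe when the set of non-None fields is known.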

losses_from_criterion_output

losses_from_criterion_output(criterion_output: Tensor | dict[str, Tensor]) -> dict[str, torch.Tensor]

Format the output of a loss function into a dict containing the key 'model_loss'.

Parameters:

Name              Type                        Description                      Default
criterion_output  Tensor | dict[str, Tensor]  The output of a loss function.   required

Raises:

Type      Description
KeyError  If the loss function output is a dict and does not contain the key 'model_loss'.

Returns:

Type                     Description
dict[str, torch.Tensor]  A dict containing the key 'model_loss' and any other loss tensors returned by the criterion.
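The documented contract (a dict must contain 'model_loss'; otherwise a KeyError is raised) can be sketched as follows. This is a hypothetical re-implementation named toy_losses_from_criterion_output, not the library function; plain floats stand in for torch.Tensor so it runs without PyTorch, and it assumes that a bare tensor output is stored under the 'model_loss' key.

```python
from typing import Union

# Plain floats stand in for torch.Tensor so the sketch runs without PyTorch.
Loss = float


def toy_losses_from_criterion_output(
    criterion_output: Union[Loss, dict[str, Loss]],
) -> dict[str, Loss]:
    """Hypothetical sketch of the documented normalization behavior."""
    if isinstance(criterion_output, dict):
        if "model_loss" not in criterion_output:
            raise KeyError("Loss function output dict must contain the key 'model_loss'.")
        # Keep every returned loss, including the required 'model_loss'.
        return dict(criterion_output)
    # Assumption: a bare loss value is treated as the model loss itself.
    return {"model_loss": criterion_output}


print(toy_losses_from_criterion_output(0.25))  # {'model_loss': 0.25}
print(toy_losses_from_criterion_output({"model_loss": 0.25, "aux": 0.1}))  # {'model_loss': 0.25, 'aux': 0.1}
```

Normalizing to a dict lets downstream training code always read losses["model_loss"] regardless of which form the criterion returned.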

sg_loss_wrapper

sg_loss_wrapper(criterion: Callable[Concatenate[T, CriterionP], Tensor | dict[str, Tensor]]) -> Callable[Concatenate[SGModelOutput[T], CriterionP], dict[str, torch.Tensor]]

Wrap a loss function to accept an SGModelOutput as its first argument.

Note

criterion must return either a torch.Tensor or a dict of torch.Tensor values; if it returns a dict, that dict must include the key 'model_loss'.

Parameters:

Name       Type                                                               Description                  Default
criterion  Callable[Concatenate[T, CriterionP], Tensor | dict[str, Tensor]]  The loss function to wrap.   required

Returns:

Type                                                                          Description
Callable[Concatenate[SGModelOutput[T], CriterionP], dict[str, torch.Tensor]]  A function that accepts an SGModelOutput as its first argument and passes its base_model_output to the wrapped loss function.
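The wrapping pattern can be sketched as a decorator that unwraps base_model_output, forwards any remaining arguments to the criterion, and normalizes the result into a dict. All names here (toy_sg_loss_wrapper, ToyModelOutput, mse) are hypothetical, and floats stand in for tensors. The documented signature uses Concatenate with the ParamSpec CriterionP so that type checkers keep the criterion's extra parameters; this sketch uses Callable[..., ...] for brevity.

```python
import functools
from dataclasses import dataclass
from typing import Any, Callable, Union

Loss = float  # stand-in for torch.Tensor


@dataclass
class ToyModelOutput:
    # Hypothetical output type; only the field the wrapper reads is modeled.
    base_model_output: Any


def toy_sg_loss_wrapper(
    criterion: Callable[..., Union[Loss, dict[str, Loss]]],
) -> Callable[..., dict[str, Loss]]:
    @functools.wraps(criterion)
    def wrapped(output: ToyModelOutput, *args: Any, **kwargs: Any) -> dict[str, Loss]:
        # Unwrap the model output, then forward any extra criterion arguments.
        result = criterion(output.base_model_output, *args, **kwargs)
        if isinstance(result, dict):
            if "model_loss" not in result:
                raise KeyError("Criterion output must contain the key 'model_loss'.")
            return dict(result)
        # A bare loss value is stored under 'model_loss'.
        return {"model_loss": result}

    return wrapped


def mse(prediction: list[float], target: list[float]) -> float:
    # Mean squared error over paired elements.
    return sum((p - t) ** 2 for p, t in zip(prediction, target)) / len(target)


loss_fn = toy_sg_loss_wrapper(mse)
print(loss_fn(ToyModelOutput(base_model_output=[1.0, 2.0]), [1.0, 4.0]))  # {'model_loss': 2.0}
```

functools.wraps keeps the criterion's name and docstring on the wrapper, which helps when losses are logged by function name.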