
sg_model

Modules:

| Name | Description |
| --- | --- |
| `output` | |

Classes:

| Name | Description |
| --- | --- |
| `SGModel` | Base class for all stained glass models. |
| `SGModelOutput` | The output of SGModel.forward(). |

Functions:

| Name | Description |
| --- | --- |
| `losses_from_criterion_output` | Format the output of a loss function into a dict containing the key 'model_loss'. |
| `sg_loss_wrapper` | Wrap a loss function to accept an SGModelOutput as its first argument. |

SGModel

Bases: Module, Generic[M]

Base class for all stained glass models.

Methods:

| Name | Description |
| --- | --- |
| `__init__` | Initialize the model. |
| `forward` | Delegate calls to the base model. |

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `input_shape` | `tuple[int, ...]` | The expected shape of the input to the base model. |

input_shape property

input_shape: tuple[int, ...]

The expected shape of the input to the base model.

__init__

__init__(base_model: M, input_shape: tuple[int, ...]) -> None

Initialize the model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `base_model` | `M` | The base model. | required |
| `input_shape` | `tuple[int, ...]` | The expected shape of the input to the base model. | required |

forward

forward(*args: Any, **kwargs: Any) -> SGModelOutput[Any]

Delegate calls to the base model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `args` | `Any` | Inputs to the base model. | required |
| `kwargs` | `Dict[str, Any]` | Keyword arguments to the base model. | required |

Returns:

| Type | Description |
| --- | --- |
| `SGModelOutput[Any]` | The result of the underlying model with noise added to the output of the base model's target layer. |
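The delegation described above can be illustrated with a minimal, dependency-free sketch. To keep it runnable without PyTorch, it uses a plain Python class where the real SGModel subclasses torch.nn.Module, and it omits the noise injection mentioned in the return description. The names ToySGModel, ToySGModelOutput, and scale_all, and the field name base_model_output, are hypothetical stand-ins, not part of the library.

```python
from dataclasses import dataclass
from typing import Any, Callable, Generic, TypeVar

T = TypeVar("T")
M = TypeVar("M", bound=Callable[..., Any])


@dataclass
class ToySGModelOutput(Generic[T]):
    """Hypothetical stand-in for SGModelOutput holding the base model's result."""

    base_model_output: T


class ToySGModel(Generic[M]):
    """Hypothetical plain-Python analogue of SGModel (the real class is a torch.nn.Module)."""

    def __init__(self, base_model: M, input_shape: tuple[int, ...]) -> None:
        self.base_model = base_model
        self._input_shape = input_shape

    @property
    def input_shape(self) -> tuple[int, ...]:
        # Read-only access to the shape the base model expects.
        return self._input_shape

    def forward(self, *args: Any, **kwargs: Any) -> ToySGModelOutput[Any]:
        # Pass every positional and keyword argument straight to the base model,
        # then wrap the raw result in an output container (no noise is added here).
        return ToySGModelOutput(base_model_output=self.base_model(*args, **kwargs))

    # torch.nn.Module routes __call__ to forward; mimic that here.
    __call__ = forward


def scale_all(xs: list[float], factor: float = 2.0) -> list[float]:
    """Hypothetical base model: multiply every element by a factor."""
    return [x * factor for x in xs]


model = ToySGModel(scale_all, input_shape=(3,))
out = model([1.0, 2.0, 3.0], factor=10.0)
print(model.input_shape)  # (3,)
print(out.base_model_output)  # [10.0, 20.0, 30.0]
```

Because forward accepts `*args` and `**kwargs` unchanged, the wrapper places no constraints on the base model's call signature; only the return value is wrapped.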

SGModelOutput dataclass

Bases: ModelOutput, Generic[T]

The output of SGModel.forward().

Methods:

| Name | Description |
| --- | --- |
| `__init_subclass__` | Register subclasses as pytree nodes. |
| `to_tuple` | Convert self to a tuple containing all the attributes/keys that are not None. |

__init_subclass__

__init_subclass__() -> None

Register subclasses as pytree nodes.

This is necessary to synchronize gradients when using torch.nn.parallel.DistributedDataParallel(static_graph=True) with modules that output ModelOutput subclasses.

See: https://github.com/pytorch/pytorch/issues/106690.
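The registration pattern can be sketched with a toy registry. This is an illustrative assumption, not the library's code: the real class presumably registers with PyTorch's private `torch.utils._pytree` utilities, whereas PYTREE_REGISTRY, ToyModelOutput, and ClassifierOutput below are hypothetical names. The sketch only shows how `__init_subclass__` registers every subclass automatically at class-creation time, pairing a flatten function (instance to field values plus field names) with an unflatten function (the reverse).

```python
from dataclasses import dataclass, fields
from typing import Any, Callable

# Hypothetical toy registry: class -> (flatten, unflatten).
# This is NOT torch.utils._pytree; it only mirrors the idea.
PYTREE_REGISTRY: dict[type, tuple[Callable[..., Any], Callable[..., Any]]] = {}


def _flatten(obj: Any) -> tuple[list[Any], list[str]]:
    # Children are the dataclass field values; the context is the field names.
    names = [f.name for f in fields(obj)]
    return [getattr(obj, n) for n in names], names


class ToyModelOutput:
    """Hypothetical stand-in for ModelOutput."""

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)

        def _unflatten(values: list[Any], names: list[str]) -> Any:
            # Rebuild an instance of the registered subclass from its parts.
            return cls(**dict(zip(names, values)))

        # Runs once per subclass, when the class statement executes.
        PYTREE_REGISTRY[cls] = (_flatten, _unflatten)


@dataclass
class ClassifierOutput(ToyModelOutput):
    logits: Any = None
    loss: Any = None


original = ClassifierOutput(logits=[0.1, 0.9], loss=0.5)
flatten, unflatten = PYTREE_REGISTRY[ClassifierOutput]
values, names = flatten(original)
assert unflatten(values, names) == original  # round trip preserves every field
```

The design choice here is that registration is automatic: any subclass, wherever it is defined, picks up a flatten/unflatten pair without an explicit registration call.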

to_tuple

to_tuple() -> tuple[Any, ...]

Convert self to a tuple containing all the attributes/keys that are not None.

Returns:

| Type | Description |
| --- | --- |
| `tuple[Any, ...]` | A tuple of all attributes/keys that are not None. |
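A minimal sketch of the documented to_tuple behavior on a plain dataclass. ToyOutput and its fields are hypothetical; the sketch only assumes that fields are visited in declaration order and that None values are skipped.

```python
from dataclasses import dataclass, fields
from typing import Any, Optional


@dataclass
class ToyOutput:
    """Hypothetical output container; field names are illustrative only."""

    logits: Optional[list[float]] = None
    hidden_states: Optional[list[float]] = None
    loss: Optional[float] = None

    def to_tuple(self) -> tuple[Any, ...]:
        # Visit fields in declaration order and drop every value that is None.
        values = (getattr(self, f.name) for f in fields(self))
        return tuple(v for v in values if v is not None)


out = ToyOutput(logits=[0.2, 0.8], loss=0.3)
print(out.to_tuple())  # ([0.2, 0.8], 0.3)
```

In this sketch, a value's position in the returned tuple depends on which earlier fields are None, so index-based access is only stable when the same set of fields is populated.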

losses_from_criterion_output

losses_from_criterion_output(criterion_output: Tensor | dict[str, Tensor]) -> dict[str, torch.Tensor]

Format the output of a loss function into a dict containing the key 'model_loss'.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `criterion_output` | `Tensor \| dict[str, Tensor]` | The output of a loss function. | required |

Raises:

| Type | Description |
| --- | --- |
| `KeyError` | If the loss function output is a dict and does not contain the key 'model_loss'. |

Returns:

| Type | Description |
| --- | --- |
| `dict[str, torch.Tensor]` | A dict containing the key 'model_loss' and any other loss tensors returned by the criterion. |
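The documented contract of losses_from_criterion_output can be expressed as a short sketch. It is a hypothetical re-implementation named toy_losses_from_criterion_output; a bare value stands in for a torch.Tensor so the example runs without PyTorch, and returning a shallow copy of a dict input is a choice of this sketch, not documented behavior.

```python
from collections.abc import Mapping
from typing import Any


def toy_losses_from_criterion_output(criterion_output: Any) -> dict[str, Any]:
    """Hypothetical sketch of losses_from_criterion_output; plain values stand in for tensors."""
    if isinstance(criterion_output, Mapping):
        # A dict output must already name its primary loss.
        if "model_loss" not in criterion_output:
            raise KeyError("criterion output must contain the key 'model_loss'")
        # Shallow copy: a choice of this sketch, not documented behavior.
        return dict(criterion_output)
    # A single loss value becomes the 'model_loss' entry.
    return {"model_loss": criterion_output}


print(toy_losses_from_criterion_output(0.7))  # {'model_loss': 0.7}
print(toy_losses_from_criterion_output({"model_loss": 0.7, "aux": 0.1}))  # {'model_loss': 0.7, 'aux': 0.1}
```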

sg_loss_wrapper

sg_loss_wrapper(criterion: Callable[Concatenate[T, CriterionP], Tensor | dict[str, Tensor]]) -> Callable[Concatenate[SGModelOutput[T], CriterionP], dict[str, torch.Tensor]]

Wrap a loss function to accept an SGModelOutput as its first argument.

Note

criterion must return either a torch.Tensor or a dict of torch.Tensor values; a returned dict must include the key 'model_loss'.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `criterion` | `Callable[Concatenate[T, CriterionP], Tensor \| dict[str, Tensor]]` | The loss function to wrap. | required |

Returns:

Type Description
Callable[Concatenate[SGModelOutput[T], CriterionP], dict[str, torch.Tensor]]

A function that accepts an SGModelOutput as its first argument, and passes the base_model_output to the wrapped loss function.