# base

## BaseNoiseLayer

Bases: `Module`, `Generic[EstimatorModuleT, ParameterizationT, OptionalMaskerT]`

Base class for Stained Glass Transform layers.
### input_shape `property`

The shape of the expected input, including its batch dimension.
### mask `property` `writable`

`mask: Tensor | None`

The mask to apply, calculated from the parameters of the stochastic transformation computed during the most recent call to `forward`.
### mean `property` `writable`

`mean: Tensor`

The means of the stochastic transformation computed during the most recent call to `forward`.
### std `property` `writable`

`std: Tensor`

The standard deviations of the stochastic transformation computed during the most recent call to `forward`.
### __call__

Stochastically transform the input.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `input` | `Tensor` | The input to transform. | *required* |
| `noise_mask` | `Tensor \| None` | An optional mask that selects the elements of `input` to transform. | `None` |
| `**kwargs` | `Any` | Additional keyword arguments to the estimator modules. | *required* |
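The `noise_mask` semantics can be sketched as follows. This is an assumption drawn from the usage in the Examples section (masked-out elements pass through untransformed); it is not the library's implementation:

```python
import torch

# Hedged sketch of noise_mask semantics: a boolean mask selects which
# elements of the input receive the stochastic transformation; elements
# where the mask is False are passed through unchanged.
input = torch.ones(20)
noise_mask = torch.tensor(5 * [False] + 15 * [True])
noise = torch.randn(20)

output = torch.where(noise_mask, input + noise, input)
```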
### __init__

`__init__(input_shape: tuple[int, ...], seed: int | None, mean_estimator: Estimator[EstimatorModuleT, None, None], std_estimator: Estimator[EstimatorModuleT, ParameterizationT, OptionalMaskerT]) -> None`

Initialize the parameters necessary to use Stained Glass Transform layers.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `input_shape` | `tuple[int, ...]` | Shape of given inputs. The first dimension may be -1, meaning a variable batch size. | *required* |
| `seed` | `int \| None` | Seed for the random number generator used to generate the stochastic transformation. If `None`, the generators are seeded non-deterministically. | *required* |
| `mean_estimator` | `Estimator[EstimatorModuleT, None, None]` | The estimator to use to estimate the mean of the stochastic transformation. | *required* |
| `std_estimator` | `Estimator[EstimatorModuleT, ParameterizationT, OptionalMaskerT]` | The estimator to use to estimate the standard deviation and optional input mask of the stochastic transformation. | *required* |
### __init_subclass__

Set the default dtype to `torch.float32` inside all subclass `__init__` methods.
### __setstate__

Restore from a serialized copy of `self.__dict__`.
### forward `abstractmethod`

Transform the input data.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `input` | `Tensor` | The input to transform. | *required* |
| `noise_mask` | `Tensor \| None` | An optional mask that selects the elements of `input` to transform. | `None` |
| `**kwargs` | `Any` | Additional keyword arguments to the estimator modules. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `NoiseLayerOutput` | The transformed input data. |
### get_applied_transform_components_factory

Create a function that returns the elements of the transform components (`'mean'` and `'std'`) applied during the most recent forward pass.

Specifically, the applied elements are those selected by the noise mask (if supplied) and the standard deviation mask (if `std_estimator.masker` is not `None`). If no masks are used, all elements are returned. The applied transform components are returned flattened.

This function is intended to be used to log histograms of the transform components.

Returns:

| Type | Description |
| --- | --- |
| `Callable[[], dict[str, torch.Tensor]]` | A function that returns the elements of the transform components applied during the most recent forward pass. |
Examples:

```python
>>> import torch
>>> from torch import nn
>>> from stainedglass_core import model as sg_model, noise_layer as sg_noise_layer
>>> base_model = nn.Linear(20, 2)
>>> noisy_model = sg_model.NoisyModel(
...     sg_noise_layer.CloakNoiseLayer1,
...     base_model,
...     input_shape=(-1, 20),
... )
>>> get_applied_transform_components = (
...     noisy_model.noise_layer.get_applied_transform_components_factory()
... )
>>> input = torch.ones(1, 20)
>>> noise_mask = torch.tensor(5 * [False] + 15 * [True])
>>> output = noisy_model(input, noise_mask=noise_mask)
>>> applied_transform_components = get_applied_transform_components()
>>> applied_transform_components
{'mean': tensor(...), 'std': tensor(...)}
>>> {
...     component_name: component.shape
...     for component_name, component in applied_transform_components.items()
... }
{'mean': torch.Size([15]), 'std': torch.Size([15])}
```
### get_transformed_output_factory

Create a function that returns the transformed output from the most recent forward pass.

If super batching is active, only the transformed half of the super batch output is returned.

Returns:

| Type | Description |
| --- | --- |
| `Callable[[], torch.Tensor]` | A function that returns the transformed output from the most recent forward pass. |

Examples:

```python
>>> import torch
>>> from stainedglass_core import noise_layer as sg_noise_layer
>>> noise_layer = sg_noise_layer.CloakNoiseLayer1(input_shape=(-1, 3, 32, 32))
>>> get_transformed_output = noise_layer.get_transformed_output_factory()
>>> input = torch.ones(2, 3, 32, 32)
>>> output = noise_layer(input)
>>> transformed_output = get_transformed_output()
>>> assert output.output.equal(transformed_output)
```
### initial_seed

Return the initial seed of the CPU device's random number generator.
### manual_seed

`manual_seed(seed: int) -> None`

Seed each of the random number generators.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `seed` | `int` | The seed to set. | *required* |
### seed

Seed each of the random number generators using a non-deterministic random number.
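The reproducibility contract behind `manual_seed` can be illustrated with a plain `torch.Generator` (a sketch of the general torch seeding behavior, not the layer's internal generators):

```python
import torch

# Sketch: seeding a generator with the same value reproduces the same draws,
# which is the contract manual_seed provides for the layer's generators.
gen = torch.Generator()
gen.manual_seed(1234)
first = torch.randn(4, generator=gen)

gen.manual_seed(1234)  # re-seed with the same value
second = torch.randn(4, generator=gen)
```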
## NoiseLayerOutput `dataclass`

Bases: `ModelOutput`

The output of `BaseNoiseLayer.forward()`.
### __init_subclass__

Register subclasses as pytree nodes.

This is necessary to synchronize gradients when using `torch.nn.parallel.DistributedDataParallel(static_graph=True)` with modules that output `ModelOutput` subclasses.

See: https://github.com/pytorch/pytorch/issues/106690.
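Pytree registration of a dataclass can be sketched with torch's private `torch.utils._pytree` helpers. The dataclass, registration, and fallback below are illustrative assumptions, not the library's code; the helper's name has changed across torch versions:

```python
import dataclasses

import torch
import torch.utils._pytree as pytree


# Illustrative dataclass standing in for a ModelOutput subclass.
@dataclasses.dataclass
class MyOutput:
    output: torch.Tensor


# torch renamed this private helper across versions; fall back accordingly.
_register = getattr(pytree, "register_pytree_node", None) or pytree._register_pytree_node
_register(
    MyOutput,
    lambda o: ([o.output], None),                # flatten: (children, context)
    lambda children, _ctx: MyOutput(*children),  # unflatten from children
)

# Once registered, torch can flatten/unflatten instances like any container,
# which is what DDP's static-graph gradient synchronization relies on.
leaves, spec = pytree.tree_flatten(MyOutput(output=torch.ones(2)))
```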
### to_tuple

Convert `self` to a tuple containing all the attributes/keys that are not `None`.

Returns:

| Type | Description |
| --- | --- |
| `tuple[Any, ...]` | A tuple of all attributes/keys that are not `None`. |
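A minimal, illustrative implementation of this behavior for a plain dataclass (a sketch, not the library's code):

```python
import dataclasses
from typing import Any


# Illustrative sketch: collect every dataclass field whose value is not None,
# in declaration order, mirroring the documented to_tuple behavior.
@dataclasses.dataclass
class Output:
    output: Any
    mask: Any = None

    def to_tuple(self) -> tuple[Any, ...]:
        return tuple(
            getattr(self, field.name)
            for field in dataclasses.fields(self)
            if getattr(self, field.name) is not None
        )
```

For example, `Output(output=1).to_tuple()` returns `(1,)`, while `Output(output=1, mask=2).to_tuple()` returns `(1, 2)`.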