
tensorflow

The tensorflow integrations module contains TensorFlow Adapter classes. Each of these torch.nn.Module classes wraps a keras.Layer so that Stained Glass Transforms implemented in PyTorch can be trained while TensorFlow computes the model's forward and backward passes. For example usage, see the TensorFlow Adapter notebook.

Classes:

TensorFlowAdapter: Wraps a keras.Layer as a torch.nn.Module.

VisionTensorFlowAdapter: Wraps, as a torch.nn.Module, a keras.Layer that accepts a 4-D image tensor.

TensorFlowAdapter

Bases: Module, Generic[LayerT]

Wraps a keras.Layer as a torch.nn.Module.

The wrapped model accepts torch.Tensor inputs and returns torch.Tensor outputs, handling conversion to and from tensorflow.Tensor. Gradients computed by TensorFlow are likewise converted from tensorflow.Tensor to torch.Tensor.

Parameters:

tf_model (LayerT): The keras.Layer to wrap. Required.

VisionTensorFlowAdapter

Bases: TensorFlowAdapter[LayerT]

Wraps, as a torch.nn.Module, a keras.Layer that accepts a 4-D image tensor.

The wrapped model automatically permutes the input torch.Tensor from the PyTorch dimension order of (batch_size, color_channels, height, width) to the TensorFlow convention of (batch_size, height, width, color_channels).

Parameters:

tf_model (LayerT): The keras.Layer to wrap. Required.

Examples:

>>> import keras
>>> import torch
>>> from torch import nn
>>> def mini_vision_model(
...     input_shape: tuple[int, int, int] = (32, 32, 3), num_classes: int = 2
... ) -> keras.Model:
...     inputs = keras.Input(shape=input_shape)
...     hidden = keras.layers.Conv2D(
...         filters=4,
...         kernel_size=(3, 3),
...         activation="relu",
...         padding="valid",
...     )(inputs)
...     hidden = keras.layers.Flatten()(hidden)
...     outputs = keras.layers.Dense(num_classes)(hidden)
...     return keras.Model(inputs, outputs)
>>>
>>> tf_model = mini_vision_model()
>>> torch_vision_tf_model = VisionTensorFlowAdapter(tf_model)

Forward pass using PyTorch tensors:

>>> parameter = nn.Parameter(torch.randn(3, 32, 32))
>>> image = torch.randn((batch_size := 2, 3, 32, 32))
>>> output = torch_vision_tf_model(image + parameter)
>>> output
tensor(...)

Backward pass using PyTorch loss:

>>> label = torch.randint(0, 2, (batch_size,))
>>> loss = torch.nn.functional.cross_entropy(output, label)
>>> loss.backward()
>>> parameter.grad
tensor(...)
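The channel reordering described for VisionTensorFlowAdapter is a single axis permutation. A minimal PyTorch sketch, using a hypothetical helper named nchw_to_nhwc that is not part of this module:

```python
import torch


def nchw_to_nhwc(images: torch.Tensor) -> torch.Tensor:
    """Reorder (batch, channels, height, width) to (batch, height, width, channels)."""
    if images.ndim != 4:
        raise ValueError(f"expected a 4-D image tensor, got {images.ndim} dimensions")
    # permute() returns a strided view; contiguous() materializes the new layout.
    return images.permute(0, 2, 3, 1).contiguous()


images = torch.randn(2, 3, 32, 32)
channels_last = nchw_to_nhwc(images)
print(tuple(channels_last.shape))  # (2, 32, 32, 3)
```

Because permute is differentiable, autograd routes gradients back through the reordering into the original (batch_size, color_channels, height, width) layout.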

tf2torch

tf2torch(tf_tensor: Tensor) -> torch.Tensor

Convert a tensorflow.Tensor to a torch.Tensor using DLPack.

Parameters:

tf_tensor (tf.Tensor): The tensorflow.Tensor to convert. Required.

Returns:

torch.Tensor: The converted torch.Tensor.

torch2tf

torch2tf(torch_tensor: Tensor) -> tf.Tensor

Convert a torch.Tensor to a tensorflow.Tensor using DLPack.

Parameters:

torch_tensor (torch.Tensor): The torch.Tensor to convert. Required.

Returns:

tf.Tensor: The converted tensorflow.Tensor.

tree_tf2torch

tree_tf2torch(structure: T) -> T

Make a deep copy of structure, converting every nested tensorflow.Tensor to a torch.Tensor.

Parameters:

structure (T): The structure to traverse, such as a dict, list, or tuple. Required.

Returns:

T: A deep copy of structure with every nested tensorflow.Tensor converted to a torch.Tensor.
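The traversal can be sketched in plain Python. tree_convert and its is_leaf and convert parameters are hypothetical; the example converts ints to floats so it runs without either framework, whereas in practice is_leaf would test for tf.Tensor and convert would be tf2torch (or the reverse pair for tree_torch2tf). The sketch rebuilds plain dict, list, tuple, and namedtuple containers and deep-copies every other leaf.

```python
import copy
from typing import Any, Callable


def tree_convert(
    structure: Any, is_leaf: Callable[[Any], bool], convert: Callable[[Any], Any]
) -> Any:
    """Hypothetical helper: rebuild containers, converting every leaf that matches."""
    if is_leaf(structure):
        return convert(structure)
    if isinstance(structure, dict):
        return {k: tree_convert(v, is_leaf, convert) for k, v in structure.items()}
    if isinstance(structure, list):
        return [tree_convert(item, is_leaf, convert) for item in structure]
    if isinstance(structure, tuple):
        items = [tree_convert(item, is_leaf, convert) for item in structure]
        # namedtuples take their fields positionally; plain tuples take one iterable.
        return type(structure)(*items) if hasattr(structure, "_fields") else tuple(items)
    return copy.deepcopy(structure)  # non-matching leaves are copied unchanged


batch = {"pixels": [1, 2], "meta": ("id", 3), "tags": ["a"]}
converted = tree_convert(batch, lambda x: isinstance(x, int), float)
print(converted)  # {'pixels': [1.0, 2.0], 'meta': ('id', 3.0), 'tags': ['a']}
```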

tree_torch2tf

tree_torch2tf(structure: T) -> T

Make a deep copy of structure, converting every nested torch.Tensor to a tensorflow.Tensor.

Parameters:

structure (T): The structure to traverse, such as a dict, list, or tuple. Required.

Returns:

T: A deep copy of structure with every nested torch.Tensor converted to a tensorflow.Tensor.