tensorflow
The tensorflow integrations module contains TensorFlow Adapter classes. These torch.nn.Module classes wrap a keras.Layer, allowing Stained Glass Transforms implemented in PyTorch to be trained while TensorFlow serves as the backend that computes the model's forward and backward passes. For example usage, see the TensorFlow Adapter notebook.
Classes:
Name | Description |
---|---|
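TensorFlowAdapter | Wraps a keras.Layer as a torch.nn.Module. |
VisionTensorFlowAdapter | Wraps a keras.Layer that accepts a 4-D image tensor as a torch.nn.Module. |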
TensorFlowAdapter
Bases: Module, Generic[LayerT]
Wraps a keras.Layer as a torch.nn.Module.
The wrapped model accepts torch.Tensor inputs and returns torch.Tensor outputs, handling the conversion to and from tensorflow.Tensor. Gradients are likewise converted from tensorflow.Tensor to torch.Tensor during the backward pass.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
| | LayerT | The | required |
VisionTensorFlowAdapter
Bases: TensorFlowAdapter[LayerT]
Wraps a keras.Layer that accepts a 4-D image tensor as a torch.nn.Module.
The wrapped model automatically permutes the input torch.Tensor from the PyTorch dimension order (batch_size, color_channels, height, width) to the TensorFlow convention (batch_size, height, width, color_channels).
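In PyTorch, the reordering described above amounts to a single permute call. A minimal sketch, where nchw_to_nhwc is a hypothetical helper rather than part of this module:

```python
import torch


def nchw_to_nhwc(images: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: (batch, channels, height, width) -> (batch, height, width, channels)."""
    if images.ndim != 4:
        raise ValueError(f"expected a 4-D tensor, got {images.ndim}-D")
    # Move axis 1 (channels) to the end. permute returns a strided view,
    # so .contiguous() materializes the new memory layout.
    return images.permute(0, 2, 3, 1).contiguous()


batch = torch.randn(2, 3, 32, 32)
print(nchw_to_nhwc(batch).shape)  # torch.Size([2, 32, 32, 3])
```

Because permute is differentiable, gradients computed for the NHWC tensor flow back to the original NCHW tensor without any manual reordering.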
Parameters:
Name | Type | Description | Default |
---|---|---|---|
| | LayerT | The | required |
Examples:
>>> import keras
>>> import torch
>>> from torch import nn
>>> def mini_vision_model(
... input_shape: tuple[int, int, int] = (32, 32, 3), num_classes: int = 2
... ) -> keras.Model:
... inputs = keras.Input(shape=input_shape)
... hidden = keras.layers.Conv2D(
... filters=4,
... kernel_size=(3, 3),
... activation="relu",
... padding="valid",
... )(inputs)
... hidden = keras.layers.Flatten()(hidden)
... outputs = keras.layers.Dense(num_classes)(hidden)
... return keras.Model(inputs, outputs)
>>>
>>> tf_model = mini_vision_model()
>>> torch_vision_tf_model = VisionTensorFlowAdapter(tf_model)
Forward pass using PyTorch tensors:
>>> parameter = nn.Parameter(torch.randn(3, 32, 32))
>>> image = torch.randn((batch_size := 2, 3, 32, 32))
>>> output = torch_vision_tf_model(image + parameter)
>>> output
tensor(...)
Backward pass using PyTorch loss:
>>> label = torch.randint(0, 2, (batch_size,))
>>> loss = torch.nn.functional.cross_entropy(output, label)
>>> loss.backward()
>>> parameter.grad
tensor(...)
tf2torch
tf2torch(tf_tensor: Tensor) -> torch.Tensor
Convert a tensorflow.Tensor to a torch.Tensor using DLPack.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
| | Tensor | The | required |
Returns:
Type | Description |
---|---|
torch.Tensor | A |
torch2tf
torch2tf(torch_tensor: Tensor) -> tf.Tensor
tree_tf2torch
tree_tf2torch(structure: T) -> T
Make a deep copy of structure, converting every nested tensorflow.Tensor to a torch.Tensor.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
| | T | The structure ( | required |
Returns:
Type | Description |
---|---|
T | A deep copy of structure with every nested tensorflow.Tensor converted to a torch.Tensor. |
tree_torch2tf
tree_torch2tf(structure: T) -> T
Make a deep copy of structure, converting every nested torch.Tensor to a tensorflow.Tensor.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
| | T | The structure ( | required |
Returns:
Type | Description |
---|---|
T | A deep copy of structure with every nested torch.Tensor converted to a tensorflow.Tensor. |