# Copyright 2019 Zuru Tech HK Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""UNET implementations."""
import typing
import tensorflow as tf
from ashpy.layers import Attention, InstanceNormalization
from ashpy.models.convolutional.interfaces import Conv2DInterface
from tensorflow import keras
# Public API of this module; lowercase __all__ is the name Python's
# star-import machinery actually honors (__ALL__ had no effect).
__all__ = ["UNet", "SUNet", "FUNet"]
class UNet(Conv2DInterface):
    """
    UNet Architecture.

    Architecture similar to the one found in "Image-to-Image Translation
    with Conditional Adversarial Nets" [1]_.
    Originally proposed in "U-Net: Convolutional Networks for Biomedical
    Image Segmentation" [2]_.

    Examples:
        * Direct Usage:

            .. testcode::

                x = tf.ones((1, 512, 512, 3))
                u_net = UNet(input_res = 512,
                             min_res=4,
                             kernel_size=4,
                             initial_filters=64,
                             filters_cap=512,
                             channels=3)
                y = u_net(x)
                print(y.shape)
                print(len(u_net.trainable_variables)>0)

            .. testoutput::

                (1, 512, 512, 3)
                True

    .. [1] Image-to-Image Translation with Conditional Adversarial Nets -
        https://arxiv.org/abs/1611.07004
    .. [2] U-Net: Convolutional Networks for Biomedical Image Segmentation -
        https://arxiv.org/abs/1505.04597
    """

    def __init__(
        self,
        input_res: int,
        min_res: int,
        kernel_size: int,
        initial_filters: int,
        filters_cap: int,
        channels: int,
        use_dropout_encoder: bool = True,
        use_dropout_decoder: bool = True,
        dropout_prob: float = 0.3,
        encoder_non_linearity: typing.Type[keras.layers.Layer] = keras.layers.LeakyReLU,
        decoder_non_linearity: typing.Type[keras.layers.Layer] = keras.layers.ReLU,
        normalization_layer: typing.Type[keras.layers.Layer] = InstanceNormalization,
        last_activation: keras.activations = keras.activations.tanh,
        use_attention: bool = False,
    ):
        """
        Initialize the UNet.

        Args:
            input_res: input resolution.
            min_res: minimum resolution reached by the encoder before decoding.
            kernel_size: kernel size used in the network.
            initial_filters: number of filters of the initial convolution.
            filters_cap: maximum number of filters.
            channels: number of output channels.
            use_dropout_encoder: whether to use dropout in the encoder module.
            use_dropout_decoder: whether to use dropout in the decoder module.
            dropout_prob: probability of dropout.
            encoder_non_linearity: non linearity of encoder.
            decoder_non_linearity: non linearity of decoder.
            normalization_layer: layer class used for normalization inside
                each block (defaults to InstanceNormalization).
            last_activation: last activation function, tanh or softmax
                (softmax for semantic images).
            use_attention: whether to use attention.
        """
        super().__init__()

        # Layer specification shared by encoder and decoder blocks.
        self.use_dropout_encoder = use_dropout_encoder
        self.use_dropout_decoder = use_dropout_decoder
        self.dropout_probability = dropout_prob
        self.encoder_non_linearity = encoder_non_linearity
        self.decoder_non_linearity = decoder_non_linearity
        self.kernel_size = kernel_size
        self.use_attention = use_attention
        self.normalization = normalization_layer

        # encoder_layers is a list of lists; each inner list is a "block".
        # Keeping blocks explicit makes the decoder mirror easy to build.
        self.encoder_layers = []
        self.decoder_layers = []
        self.concat_layers = []

        # ########### Encoder creation
        encoder_layers_spec = self._get_layer_spec(
            initial_filters, filters_cap, input_res, min_res
        )
        # Materialize the generator so it can be iterated and reversed.
        encoder_layers_spec = list(encoder_layers_spec)

        decoder_layer_spec = []
        for i, filters in enumerate(encoder_layers_spec):
            # Prepend so the decoder spec mirrors the encoder spec.
            decoder_layer_spec.insert(0, filters)
            block = self.get_encoder_block(
                filters,
                # No normalization on the first and last encoder blocks.
                use_bn=(i not in (0, len(encoder_layers_spec) - 1)),
                use_attention=i == 2,
            )
            self.encoder_layers.append(block)

        # ############## Decoder creation
        # Drop the innermost resolution: the bottleneck is not mirrored.
        decoder_layer_spec = decoder_layer_spec[1:]
        for i, filters in enumerate(decoder_layer_spec):
            self.concat_layers.append(keras.layers.Concatenate())
            block = self.get_decoder_block(
                # Dropout only in the first 3 decoder blocks, attention in the 6th.
                filters, use_dropout=(i < 3), use_attention=i == 5
            )
            self.decoder_layers.append(block)

        # Final layer: upsample to the input resolution and map to `channels`.
        initializer = tf.random_normal_initializer(0.0, 0.02)
        self.final_layer = keras.layers.Conv2DTranspose(
            channels,
            self.kernel_size,
            strides=(2, 2),
            padding="same",
            activation=last_activation,
            kernel_initializer=initializer,
        )

    def _get_block(
        self,
        filters,
        conv_layer=None,
        use_bn=True,
        use_dropout=False,
        non_linearity=keras.layers.LeakyReLU,
        use_attention=False,
    ):
        """Build one [conv, norm?, dropout?, non-linearity, attention?] block."""
        initializer = tf.random_normal_initializer(0.0, 0.02)
        # Strided convolution: every block halves (encoder) or doubles
        # (decoder, via Conv2DTranspose) the spatial resolution.
        block = [
            conv_layer(
                filters,
                self.kernel_size,
                strides=(2, 2),
                padding="same",
                use_bias=False,
                kernel_initializer=initializer,
            )
        ]
        # Normalization (class chosen at construction time).
        if use_bn:
            block.append(self.normalization())
        # Dropout
        if use_dropout:
            block.append(keras.layers.Dropout(self.dropout_probability))
        # Non linearity
        block.append(non_linearity())
        # Attention
        if use_attention:
            block.append(Attention(filters))
        return block

    def get_encoder_block(self, filters, use_bn=True, use_attention=False):
        """
        Return a block to be used in the encoder part of the UNET.

        Args:
            filters: number of filters.
            use_bn: whether to use batch normalization.
            use_attention: whether to use attention.

        Returns:
            A block to be used in the encoder part.
        """
        return self._get_block(
            filters,
            conv_layer=keras.layers.Conv2D,
            use_bn=use_bn,
            use_dropout=self.use_dropout_encoder,
            non_linearity=self.encoder_non_linearity,
            use_attention=use_attention and self.use_attention,
        )

    def get_decoder_block(self, filters, use_bn=True, use_dropout=False, use_attention=False):
        """
        Return a block to be used in the decoder part of the UNET.

        Args:
            filters: number of filters.
            use_bn: whether to use batch normalization.
            use_dropout: whether to use dropout.
            use_attention: whether to use attention.

        Returns:
            A block to be used in the decoder part.
        """
        return self._get_block(
            filters,
            conv_layer=keras.layers.Conv2DTranspose,
            use_bn=use_bn,
            use_dropout=self.use_dropout_decoder and use_dropout,
            non_linearity=self.decoder_non_linearity,
            use_attention=use_attention and self.use_attention,
        )

    def call(self, inputs, training=False):
        """
        Forward pass of the UNet model.

        Encoder activations are stashed and concatenated back into the
        mirrored decoder stage (skip connections).
        """
        encoder_layer_eval = []
        x = inputs
        for block in self.encoder_layers:
            for layer in block:
                # BatchNorm/Dropout behave differently at train vs inference
                # time, so they need the explicit `training` flag.
                if isinstance(
                    layer, (keras.layers.BatchNormalization, keras.layers.Dropout)
                ):
                    x = layer(x, training=training)
                else:
                    x = layer(x)
            encoder_layer_eval.append(x)

        # The bottleneck output is the decoder input, not a skip connection.
        encoder_layer_eval = encoder_layer_eval[:-1]

        for i, block in enumerate(self.decoder_layers):
            for layer in block:
                if isinstance(
                    layer, (keras.layers.BatchNormalization, keras.layers.Dropout)
                ):
                    x = layer(x, training=training)
                else:
                    x = layer(x)
            # Skip connection with the mirrored encoder activation.
            x = self.concat_layers[i]([x, encoder_layer_eval[-1 - i]])

        x = self.final_layer(x)
        return x
class SUNet(UNet):
    """Semantic UNet: a :class:`UNet` whose last activation is fixed to softmax."""

    def __init__(
        self,
        input_res,
        min_res,
        kernel_size,
        initial_filters,
        filters_cap,
        channels,  # number of classes
        use_dropout_encoder=True,
        use_dropout_decoder=True,
        dropout_prob=0.3,
        encoder_non_linearity=keras.layers.LeakyReLU,
        decoder_non_linearity=keras.layers.ReLU,
        use_attention=False,
    ):
        """
        Build the Semantic UNet model.

        Args:
            input_res: input resolution.
            min_res: minimum resolution reached by the encoder before decoding.
            kernel_size: kernel size used in the network.
            initial_filters: number of filters of the initial convolution.
            filters_cap: maximum number of filters.
            channels: number of output channels (i.e. number of classes).
            use_dropout_encoder: whether to use dropout in the encoder module.
            use_dropout_decoder: whether to use dropout in the decoder module.
            dropout_prob: probability of dropout.
            encoder_non_linearity: non linearity of encoder.
            decoder_non_linearity: non linearity of decoder.
            use_attention: whether to use attention.
        """
        # Delegate everything to UNet; softmax makes the output a per-pixel
        # class-probability map.
        super().__init__(
            input_res,
            min_res,
            kernel_size,
            initial_filters,
            filters_cap,
            channels,
            use_dropout_encoder,
            use_dropout_decoder,
            dropout_prob,
            encoder_non_linearity,
            decoder_non_linearity,
            last_activation=keras.activations.softmax,
            use_attention=use_attention,
        )
def FUNet(
    input_res,
    min_res,
    kernel_size,
    initial_filters,
    filters_cap,
    channels,
    input_channels=3,
    use_dropout_encoder=True,
    use_dropout_decoder=True,
    dropout_prob=0.3,
    encoder_non_linearity=keras.layers.LeakyReLU,
    decoder_non_linearity=keras.layers.ReLU,
    last_activation=keras.activations.tanh,  # tanh or softmax (for semantic images)
    use_attention=False,
):
    """
    Functional UNET Implementation.

    Builds the same encoder/decoder-with-skips topology as :class:`UNet`
    using the Keras functional API and returns a ``tf.keras.Model``.

    Args:
        input_res: input resolution.
        min_res: minimum resolution reached by the encoder before decoding.
        kernel_size: kernel size used in the network.
        initial_filters: number of filters of the initial convolution.
        filters_cap: maximum number of filters.
        channels: number of output channels.
        input_channels: number of channels of the input tensor.
        use_dropout_encoder: whether to use dropout in the encoder module.
        use_dropout_decoder: whether to use dropout in the decoder module.
        dropout_prob: probability of dropout.
        encoder_non_linearity: non linearity of encoder.
        decoder_non_linearity: non linearity of decoder.
        last_activation: last activation function, tanh or softmax
            (softmax for semantic images).
        use_attention: whether to use attention.

    Returns:
        A ``tf.keras.Model`` mapping ``(input_res, input_res, input_channels)``
        images to ``(input_res, input_res, channels)`` outputs.
    """
    # ########### Encoder creation
    encoder_layers_spec = Conv2DInterface._get_layer_spec(
        initial_filters, filters_cap, input_res, min_res
    )
    # Materialize the generator so it can be iterated more than once.
    encoder_layers_spec = list(encoder_layers_spec)
    normalization = InstanceNormalization

    def get_block(
        kernel_size,
        filters,
        conv_layer,
        use_bn,
        use_dropout,
        non_linearity,
        use_attention,
        dropout_probability,
    ):
        """Build one [conv, norm?, dropout?, non-linearity, attention?] block."""
        initializer = tf.random_normal_initializer(0.0, 0.02)
        # Strided convolution halves/doubles the resolution per block.
        block = [
            conv_layer(
                filters,
                kernel_size,
                strides=(2, 2),
                padding="same",
                use_bias=False,
                kernel_initializer=initializer,
            )
        ]
        # Normalization
        if use_bn:
            block.append(normalization())
        # Dropout
        if use_dropout:
            block.append(keras.layers.Dropout(dropout_probability))
        # Non linearity
        block.append(non_linearity())
        # Attention
        if use_attention:
            block.append(Attention(filters))
        return block

    decoder_layer_spec = []
    encoder_layers = []
    decoder_layers = []
    for i, filters in enumerate(encoder_layers_spec):
        # Prepend so the decoder spec mirrors the encoder spec.
        decoder_layer_spec.insert(0, filters)
        block = get_block(
            kernel_size,
            filters,
            conv_layer=keras.layers.Conv2D,
            # No normalization on the first and last encoder blocks.
            use_bn=(i not in (0, len(encoder_layers_spec) - 1)),
            use_dropout=use_dropout_encoder,
            non_linearity=encoder_non_linearity,
            use_attention=(i == 2 and use_attention),
            dropout_probability=dropout_prob,
        )
        encoder_layers.append(block)

    # ############## Decoder creation
    # Drop the innermost resolution: the bottleneck is not mirrored.
    decoder_layer_spec = decoder_layer_spec[1:]
    for i, filters in enumerate(decoder_layer_spec):
        block = get_block(
            kernel_size,
            filters,
            conv_layer=keras.layers.Conv2DTranspose,
            use_bn=(i != 0),
            use_dropout=(i < 3) and use_dropout_decoder,
            non_linearity=decoder_non_linearity,
            use_attention=(i == 5 and use_attention),
            dropout_probability=dropout_prob,
        )
        decoder_layers.append(block)

    # Final layer: upsample to the input resolution and map to `channels`.
    initializer = tf.random_normal_initializer(0.0, 0.02)
    final_layer = keras.layers.Conv2DTranspose(
        channels,
        kernel_size,
        strides=(2, 2),
        padding="same",
        activation=last_activation,
        kernel_initializer=initializer,
    )

    inputs = tf.keras.layers.Input(shape=[input_res, input_res, input_channels])
    x = inputs

    # Downsampling, stashing each block's output for the skip connections.
    skips = []
    for block in encoder_layers:
        for layer in block:
            x = layer(x)
        skips.append(x)
    # The bottleneck output feeds the decoder directly; it is not a skip.
    skips = reversed(skips[:-1])

    # Upsampling and establishing the skip connections.
    for block, skip in zip(decoder_layers, skips):
        for layer in block:
            x = layer(x)
        x = tf.keras.layers.Concatenate()([x, skip])

    x = final_layer(x)
    return tf.keras.Model(inputs=inputs, outputs=x)