Source code for ashpy.models.convolutional.discriminators

# Copyright 2019 Zuru Tech HK Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Convolutional Discriminators."""
import typing
from typing import List, Tuple, Union

import tensorflow as tf
from ashpy.layers import Attention, InstanceNormalization
from ashpy.models.gans import ConvDiscriminator
from tensorflow import keras

# Public API of this module. Python only consults the lowercase ``__all__``
# (e.g. for ``from ... import *`` and by documentation tools); the uppercase
# spelling has no special meaning and is kept solely as a backward-compatible
# alias for any code that may have read it directly.
__all__ = ["PatchDiscriminator", "MultiScaleDiscriminator"]
__ALL__ = __all__

class PatchDiscriminator(ConvDiscriminator):
    """
    Pix2Pix discriminator.

    The last layer is an image in which each pixel is the (unnormalized, logit)
    score of the corresponding input patch being real or fake.

    Examples:
        .. testcode::

            x = tf.ones((1, 64, 64, 3))

            # instantiate the PatchDiscriminator
            patchDiscriminator = PatchDiscriminator(input_res=64,
                                                    min_res=16,
                                                    kernel_size=5,
                                                    initial_filters=64,
                                                    filters_cap=512,
                                                    )

            # evaluate passing x
            output = patchDiscriminator(x)

            # the output is a single-channel map with one score per patch
            print(output.shape)

        .. testoutput::

            (1, 12, 12, 1)

    """

    def __init__(
        self,
        input_res: int,
        min_res: int,
        kernel_size: int,
        initial_filters: int,
        filters_cap: int,
        use_dropout: bool = True,
        dropout_prob: float = 0.3,
        non_linearity: typing.Type[keras.layers.Layer] = keras.layers.LeakyReLU,
        normalization_layer: typing.Type[keras.layers.Layer] = InstanceNormalization,
        use_attention: bool = False,
    ):
        """
        Patch Discriminator used by pix2pix.

        When min_res=1 this is the same as a standard fully convolutional
        discriminator.

        Args:
            input_res (int): Input Resolution.
            min_res (int): Minimum Resolution reached by the discriminator.
            kernel_size (int): Kernel Size used in Conv Layer.
            initial_filters (int): number of filters in the first convolutional layer.
            filters_cap (int): Maximum number of filters.
            use_dropout (bool): whether to use dropout.
            dropout_prob (float): probability of dropout.
            non_linearity (:class:`tf.keras.layers.Layer`): non linearity used in the model.
            normalization_layer (:class:`tf.keras.layers.Layer`): normalization layer
                class used in the model (instantiated once per use).
            use_attention (bool): whether to use attention (an :class:`Attention`
                layer is inserted after the third building block).

        """
        # These attributes are read by ``_add_building_block`` /
        # ``_add_final_block``; they are set before ``super().__init__`` because
        # the parent constructor presumably invokes those hooks while building
        # the network (NOTE(review): confirm against ConvDiscriminator).
        self.use_attention = use_attention
        self.layer_count = 0
        self.normalization_layer = normalization_layer
        super().__init__(
            layer_spec_input_res=input_res,
            layer_spec_target_res=min_res,
            kernel_size=kernel_size,
            initial_filters=initial_filters,
            filters_cap=filters_cap,
            output_shape=None,
            use_dropout=use_dropout,
            dropout_prob=dropout_prob,
            non_linearity=non_linearity,
        )

        # concatenate inputs on channel dimension
        self.concatenate = keras.layers.Concatenate(axis=-1)
        # hack in order to accept two inputs
        self.inputs = [1, 1]

    def _add_final_block(self, output_shape):
        """
        Append the head that maps the last feature map to a one-channel score map.

        Layers appended, in order: ``ZeroPadding2D``, a stride-1 ``Conv2D`` with
        512 filters and no bias, the normalization layer, the non-linearity,
        another ``ZeroPadding2D`` and a final stride-1 ``Conv2D`` with 1 filter.

        Args:
            output_shape: Not used by this implementation; the output always
                has a single channel. Presumably kept to match the parent
                class hook signature.

        """
        initializer = tf.random_normal_initializer(0.0, 0.02)

        # last layer mapping to one channel with Linear activation
        # Notice: The activation is linear since we use the BCE from logits
        self.model_layers.append(tf.keras.layers.ZeroPadding2D())
        self.model_layers.append(
            tf.keras.layers.Conv2D(
                512,
                self.kernel_size,
                strides=1,
                kernel_initializer=initializer,
                use_bias=False,
            )
        )
        self.model_layers.append(self.normalization_layer())
        self.model_layers.append(self.non_linearity())
        self.model_layers.append(tf.keras.layers.ZeroPadding2D())
        self.model_layers.append(
            tf.keras.layers.Conv2D(
                1, self.kernel_size, strides=1, kernel_initializer=initializer
            )
        )

    def call(
        self, inputs: Union[List, tf.Tensor], training=False, return_features=False
    ):
        """
        Forward pass of the PatchDiscriminator.

        Args:
            inputs: either a single tensor, or a pair ``[x, condition]``
                (list or tuple) whose elements are concatenated along the last
                (channel) axis before being fed to the network.
            training (bool): whether the call is in training mode.
            return_features (bool): whether to also return intermediate features.

        Returns:
            The result of ``ConvDiscriminator.call`` on the (possibly
            concatenated) input.

        """
        # Check the container type explicitly: ``len()`` of a tensor is its
        # batch size, so a plain tensor with batch size 2 must not be mistaken
        # for an ``[x, condition]`` pair and sent to ``Concatenate``.
        if isinstance(inputs, (list, tuple)) and len(inputs) == 2:
            inputs = self.concatenate(list(inputs))
        return super().call(
            inputs=inputs,
            training=training,
            return_features=return_features,
        )

    def _add_building_block(self, filters, use_bn=False):
        """
        Construct the core of the :py:obj:`tf.keras.Model`.

        The layers specified here get added to the :py:obj:`tf.keras.Model`
        multiple times consuming the hyper-parameters generated in the
        :func:`_get_layer_spec`.

        Each call appends a stride-2 ``Conv2D`` followed (except for the very
        first layer of the model) by the normalization layer, optional dropout,
        the non-linearity and, on the third call when ``use_attention`` is set,
        an :class:`Attention` layer.

        Args:
            filters (int): Number of filters to use for this iteration of the
                Building Block.
            use_bn (bool): Not used by this implementation; normalization is
                controlled by ``normalization_layer`` instead.

        """
        initializer = tf.random_normal_initializer(0.0, 0.02)
        self.model_layers.append(
            keras.layers.Conv2D(
                filters,
                self.kernel_size,
                strides=(2, 2),
                padding="same",
                kernel_initializer=initializer,
            )
        )
        # Skip normalization when the conv just added is the only layer so far
        # (pix2pix does not normalize the first discriminator layer).
        if len(self.model_layers) > 1:
            self.model_layers.append(self.normalization_layer())
        if self.use_dropout:
            self.model_layers.append(keras.layers.Dropout(self.dropout_prob))
        self.model_layers.append(self.non_linearity())
        if self.layer_count == 2 and self.use_attention:
            self.model_layers.append(Attention(filters))
        self.layer_count += 1
class MultiScaleDiscriminator(tf.keras.Model):
    """
    Multi-Scale discriminator.

    This discriminator architecture is composed of multiple discriminators
    working at different scales. Each discriminator is a
    :py:class:`ashpy.models.convolutional.discriminators.PatchDiscriminator`.

    Examples:
        .. testcode::

            x = tf.ones((1, 256, 256, 3))

            # instantiate the MultiScaleDiscriminator
            multiScaleDiscriminator = MultiScaleDiscriminator(input_res=256,
                                                    min_res=16,
                                                    kernel_size=5,
                                                    initial_filters=64,
                                                    filters_cap=512,
                                                    n_discriminators=3
                                                    )

            # evaluate passing x
            outputs = multiScaleDiscriminator(x)

            # one single-channel score map per discriminator
            print(len(outputs))
            for output in outputs:
                print(output.shape)

        .. testoutput::

            3
            (1, 12, 12, 1)
            (1, 12, 12, 1)
            (1, 12, 12, 1)

    """

    def __init__(
        self,
        input_res: int,
        min_res: int,
        kernel_size: int,
        initial_filters: int,
        filters_cap: int,
        use_dropout: bool = True,
        dropout_prob: float = 0.3,
        non_linearity: typing.Type[keras.layers.Layer] = keras.layers.LeakyReLU,
        normalization_layer: typing.Type[keras.layers.Layer] = InstanceNormalization,
        use_attention: bool = False,
        n_discriminators: int = 1,
    ):
        """
        Multi Scale Discriminator.

        Uses a different discriminator for each scale of the input image.
        Used by Pix2PixHD [1]_ .

        Args:
            input_res (int): input resolution
            min_res (int): minimum resolution reached by the discriminators
            kernel_size (int): kernel size of discriminators
            initial_filters (int): number of initial filters in the first layer
                of the discriminators
            filters_cap (int): maximum number of filters in the discriminators
            use_dropout (bool): whether to use dropout
            dropout_prob (float): probability of dropout
            non_linearity (:class:`tf.keras.layers.Layer`): non linearity used
                in discriminators
            normalization_layer (:class:`tf.keras.layers.Layer`): normalization
                used by the discriminators
            use_attention (bool): whether to use attention
            n_discriminators (int): Number of discriminators; the i-th one
                works at resolution ``input_res / 2**i``.

        .. [1] High-Resolution Image Synthesis and Semantic Manipulation with
            Conditional GANs

        """
        super().__init__()
        self.n_discriminators = n_discriminators
        self.input_res = input_res
        self.min_res = min_res
        self.kernel_size = kernel_size
        self.initial_filters = initial_filters
        self.filters_cap = filters_cap
        self.dropout_prob = dropout_prob
        self.non_linearity = non_linearity
        self.use_dropout = use_dropout
        self.use_attention = use_attention
        self.normalization_layer = normalization_layer

        # instantiate the discriminators, halving the resolution at each scale
        self.discriminators = [
            self.build_discriminator(int(input_res / (2 ** i)))
            for i in range(self.n_discriminators)
        ]

        # subsampling operation (AvgPool2D with default 2x2 pool/stride)
        self.subsampling = tf.keras.layers.AvgPool2D()

        # hack in order to accept two inputs
        self.inputs = [1, 1]

    def build_discriminator(self, input_res) -> ConvDiscriminator:
        """
        Build a single discriminator using parameters defined in this object.

        Args:
            input_res: input resolution of the discriminator.

        Returns:
            A Discriminator (PatchDiscriminator).

        """
        return PatchDiscriminator(
            input_res=input_res,
            min_res=self.min_res,
            kernel_size=self.kernel_size,
            initial_filters=self.initial_filters,
            filters_cap=self.filters_cap,
            use_dropout=self.use_dropout,
            dropout_prob=self.dropout_prob,
            non_linearity=self.non_linearity,
            use_attention=self.use_attention,
            normalization_layer=self.normalization_layer,
        )

    def call(
        self, inputs: Union[List, tf.Tensor], training=True, return_features=False
    ) -> Union[List[tf.Tensor], Tuple[List[tf.Tensor], List[tf.Tensor]]]:
        """
        Forward pass of the Multi Scale Discriminator.

        Args:
            inputs (:py:class:`tf.Tensor`): input tensor, or a pair
                ``[fake_or_real, condition]`` (list or tuple) for conditioned
                discrimination.
            training (bool): whether is training or not.
            return_features (bool): whether to return features or not.

        Returns:
            ([:py:class:`tf.Tensor`]): A List of Tensors containing the
            value of D_i for each input.

            ([:py:class:`tf.Tensor`]): A flat list with the features of every
            discriminator, in order, if `return_features`.

        """
        # Accept tuples as well as lists, consistently with PatchDiscriminator.
        if isinstance(inputs, (list, tuple)):
            # inputs is a pair: (generated or real images, conditions)
            fake_or_real, condition = inputs
        else:
            fake_or_real = inputs
            condition = None

        outs = []
        features = []
        fake_or_real_i = fake_or_real
        condition_i = condition
        last_index = len(self.discriminators) - 1
        for i, discriminator in enumerate(self.discriminators):
            # compute value of the i-th discriminator
            out, feat = discriminator(
                [fake_or_real_i, condition_i]
                if condition_i is not None
                else fake_or_real_i,
                training=training,
                return_features=True,
            )
            # append output and features
            outs.append(out)
            features.extend(feat)

            # reduce input size for the next (coarser) discriminator
            if i != last_index:
                fake_or_real_i = self.subsampling(fake_or_real_i)
                condition_i = (
                    self.subsampling(condition_i) if condition_i is not None else None
                )

        # handle output values
        if return_features:
            return outs, features
        return outs