"""Convolutional Discriminators."""
import typing
from typing import List, Tuple, Union

import tensorflow as tf
from ashpy.layers import Attention, InstanceNormalization
from ashpy.models.gans import ConvDiscriminator
from tensorflow import keras

# Public API of this module. The name must be the lowercase ``__all__``:
# that is the only spelling ``from module import *`` and documentation
# tooling look at.
__all__ = ["PatchDiscriminator", "MultiScaleDiscriminator"]

class PatchDiscriminator(ConvDiscriminator):
    """
    Pix2Pix discriminator.

    The last layer is a one-channel image in which each pixel scores a patch of
    the input as fake or real. No activation is applied to it (the values are
    logits), since the loss is expected to be a BCE computed from logits.

    Examples:
        .. testcode::

            x = tf.ones((1, 64, 64, 3))

            # instantiate the PatchDiscriminator
            patchDiscriminator = PatchDiscriminator(input_res=64,
                                                    min_res=16,
                                                    kernel_size=5,
                                                    initial_filters=64,
                                                    filters_cap=512,
                                                    )

            # evaluate passing x
            output = patchDiscriminator(x)

            # the output is a one-channel grid of per-patch predictions
            print(output.shape)

        .. testoutput::

            (1, 12, 12, 1)

    """

    def __init__(
        self,
        input_res: int,
        min_res: int,
        kernel_size: int,
        initial_filters: int,
        filters_cap: int,
        use_dropout: bool = True,
        dropout_prob: float = 0.3,
        non_linearity: typing.Type[keras.layers.Layer] = keras.layers.LeakyReLU,
        normalization_layer: typing.Type[keras.layers.Layer] = InstanceNormalization,
        use_attention: bool = False,
    ):
        """
        Patch Discriminator used by pix2pix.

        When min_res=1 this is the same as a standard fully convolutional discriminator.

        Args:
            input_res (int): Input Resolution.
            min_res (int): Minimum Resolution reached by the discriminator.
            kernel_size (int): Kernel Size used in Conv Layer.
            initial_filters (int): number of filters in the first convolutional layer.
            filters_cap (int): Maximum number of filters.
            use_dropout (bool): whether to use dropout.
            dropout_prob (float): probability of dropout.
            non_linearity (:class:`tf.keras.layers.Layer`): non linearity used
                in the model.
            normalization_layer (:class:`tf.keras.layers.Layer`): normalization
                layer used in the model.
            use_attention (bool): whether to use attention.

        """
        # These attributes are read by `_add_building_block` and
        # `_add_final_block`, so they are assigned before the base constructor
        # runs. NOTE(review): presumably ConvDiscriminator.__init__ is what
        # invokes those hooks -- confirm against the base class.
        self.use_attention = use_attention
        self.layer_count = 0
        self.normalization_layer = normalization_layer

        super().__init__(
            layer_spec_input_res=input_res,
            layer_spec_target_res=min_res,
            kernel_size=kernel_size,
            initial_filters=initial_filters,
            filters_cap=filters_cap,
            output_shape=None,
            use_dropout=use_dropout,
            dropout_prob=dropout_prob,
            non_linearity=non_linearity,
        )

        # concatenate inputs on channel dimension
        self.concatenate = keras.layers.Concatenate(axis=-1)
        self.inputs = [1, 1]

    def _add_final_block(self, output_shape):
        """
        Append the closing layers of the discriminator.

        Two (zero-padding, convolution) stages with stride 1: the first one
        produces 512 feature maps and is followed by normalization and the
        non linearity, the second one maps to a single channel.

        Args:
            output_shape: not read by this implementation.

        """
        initializer = tf.random_normal_initializer(0.0, 0.02)

        self.model_layers.append(tf.keras.layers.ZeroPadding2D())
        self.model_layers.append(
            tf.keras.layers.Conv2D(
                512,
                self.kernel_size,
                strides=1,
                kernel_initializer=initializer,
                use_bias=False,
            )
        )
        self.model_layers.append(self.normalization_layer())
        self.model_layers.append(self.non_linearity())

        # last layer mapping to one channel with Linear activation
        # Notice: The activation is linear since we use the BCE from logits
        self.model_layers.append(tf.keras.layers.ZeroPadding2D())
        self.model_layers.append(
            tf.keras.layers.Conv2D(
                1, self.kernel_size, strides=1, kernel_initializer=initializer
            )
        )

    def call(
        self, inputs: Union[List, tf.Tensor], training=False, return_features=False
    ):
        """
        Forward pass of the PatchDiscriminator.

        Args:
            inputs: either a single tensor, or a two-element list/tuple whose
                elements are concatenated along the last (channel) axis before
                being fed to the network.
            training (bool): forwarded to the parent ``call``.
            return_features (bool): forwarded to the parent ``call``.

        Returns:
            Whatever the parent ``call`` returns for the (possibly
            concatenated) input.

        """
        # Check the container type explicitly. Testing only `len(inputs) == 2`
        # would also match a single tensor whose batch size is 2, and `len()`
        # is not defined for symbolic tensors with an unknown first dimension.
        if isinstance(inputs, (list, tuple)) and len(inputs) == 2:
            inputs = self.concatenate(inputs)

        return super().call(
            inputs=inputs, training=training, return_features=return_features,
        )

    def _add_building_block(self, filters, use_bn=False):
        """
        Construct the core of the :py:obj:`tf.keras.Model`.

        The layers specified here get added to the :py:obj:`tf.keras.Model` multiple times
        consuming the hyper-parameters generated in the :func:`_get_layer_spec`.

        Each block is a stride-2 convolution, followed by normalization (skipped
        for the very first block), optional dropout and the non linearity. When
        attention is enabled, an attention layer is appended after the third
        block.

        Args:
            filters (int): Number of filters to use for this iteration of the
                Building Block.
            use_bn (bool): not read by this implementation.

        """
        initializer = tf.random_normal_initializer(0.0, 0.02)

        self.model_layers.append(
            keras.layers.Conv2D(
                filters,
                self.kernel_size,
                strides=(2, 2),
                padding="same",
                kernel_initializer=initializer,
            )
        )

        # Right after the first convolution the list holds exactly one layer,
        # so the first block is left without normalization.
        if len(self.model_layers) > 1:
            self.model_layers.append(self.normalization_layer())

        if self.use_dropout:
            self.model_layers.append(keras.layers.Dropout(self.dropout_prob))

        self.model_layers.append(self.non_linearity())

        # layer_count is zero-based: 2 means this is the third building block.
        if self.layer_count == 2 and self.use_attention:
            self.model_layers.append(Attention(filters))

        self.layer_count += 1
class MultiScaleDiscriminator(tf.keras.Model):
    """
    Multi-Scale discriminator.

    This discriminator architecture is composed by multiple discriminators working at
    different scales. Each discriminator is a
    :py:class:`ashpy.models.convolutional.discriminators.PatchDiscriminator`.

    Examples:
        .. testcode::

            x = tf.ones((1, 256, 256, 3))

            # instantiate the MultiScaleDiscriminator
            multiScaleDiscriminator = MultiScaleDiscriminator(input_res=256,
                                                              min_res=16,
                                                              kernel_size=5,
                                                              initial_filters=64,
                                                              filters_cap=512,
                                                              n_discriminators=3
                                                              )

            # evaluate passing x
            outputs = multiScaleDiscriminator(x)

            # one output per discriminator, each a grid of
            # per-patch predictions
            print(len(outputs))
            for output in outputs:
                print(output.shape)

        .. testoutput::

            3
            (1, 12, 12, 1)
            (1, 12, 12, 1)
            (1, 12, 12, 1)

    """

    def __init__(
        self,
        input_res: int,
        min_res: int,
        kernel_size: int,
        initial_filters: int,
        filters_cap: int,
        use_dropout: bool = True,
        dropout_prob: float = 0.3,
        non_linearity: typing.Type[keras.layers.Layer] = keras.layers.LeakyReLU,
        normalization_layer: typing.Type[keras.layers.Layer] = InstanceNormalization,
        use_attention: bool = False,
        n_discriminators: int = 1,
    ):
        """
        Multi Scale Discriminator.

        Different discriminator for different scales of the input image.

        Used by Pix2PixHD [1]_ .

        Args:
            input_res (int): input resolution
            min_res (int): minimum resolution reached by the discriminators
            kernel_size (int): kernel size of discriminators
            initial_filters (int): number of initial filters in the first layer
                of the discriminators
            filters_cap (int): maximum number of filters in the discriminators
            use_dropout (bool): whether to use dropout
            dropout_prob (float): probability of dropout
            non_linearity (:class:`tf.keras.layers.Layer`): non linearity used
                in discriminators
            normalization_layer (:class:`tf.keras.layers.Layer`): normalization
                used by the discriminators
            use_attention (bool): whether to use attention
            n_discriminators (int): Number of discriminators

        .. [1] High-Resolution Image Synthesis and Semantic Manipulation with
            Conditional GANs

        """
        super().__init__()

        # Hyper-parameters shared by every per-scale discriminator; they are
        # read back by `build_discriminator`.
        self.n_discriminators = n_discriminators
        self.input_res = input_res
        self.min_res = min_res
        self.kernel_size = kernel_size
        self.initial_filters = initial_filters
        self.filters_cap = filters_cap
        self.dropout_prob = dropout_prob
        self.non_linearity = non_linearity
        self.use_dropout = use_dropout
        self.use_attention = use_attention
        self.normalization_layer = normalization_layer

        # instantiate the discriminators: the i-th one is built for the input
        # resolution halved i times
        self.discriminators = []
        for i in range(self.n_discriminators):
            self.discriminators.append(
                self.build_discriminator(int(input_res / (2 ** i)))
            )

        # subsampling operation applied between two consecutive scales
        self.subsampling = tf.keras.layers.AvgPool2D()

        # hack in order to accept two inputs
        self.inputs = [1, 1]

    def build_discriminator(self, input_res) -> ConvDiscriminator:
        """
        Build a single discriminator using parameters defined in this object.

        Args:
            input_res: input resolution of the discriminator.

        Returns:
            A Discriminator (PatchDiscriminator).

        """
        return PatchDiscriminator(
            input_res=input_res,
            min_res=self.min_res,
            kernel_size=self.kernel_size,
            initial_filters=self.initial_filters,
            filters_cap=self.filters_cap,
            use_dropout=self.use_dropout,
            dropout_prob=self.dropout_prob,
            non_linearity=self.non_linearity,
            use_attention=self.use_attention,
            normalization_layer=self.normalization_layer,
        )

    def call(
        self, inputs: Union[List, tf.Tensor], training=True, return_features=False
    ) -> Union[List[tf.Tensor], Tuple[List[tf.Tensor], List[tf.Tensor]]]:
        """
        Forward pass of the Multi Scale Discriminator.

        Args:
            inputs (:py:class:`tf.Tensor`): input tensor, or a two-element
                list/tuple ``(fake_or_real, condition)``.
            training (bool): whether is training or not.
            return_features (bool): whether to return features or not.

        Returns:
            ([:py:class:`tf.Tensor`]): A List of Tensors containing the
                value of D_i for each input.
            ([:py:class:`tf.Tensor`]): Only if `return_features`: a single flat
                List holding the features of every discriminator, in
                discriminator order.

        """
        # A list or tuple carries the generated/real images together with the
        # condition; anything else is treated as the images alone.
        if isinstance(inputs, (list, tuple)):
            fake_or_real_i, condition_i = inputs
        else:
            fake_or_real_i, condition_i = inputs, None

        outs = []
        features = []
        last_index = len(self.discriminators) - 1

        for i, discriminator in enumerate(self.discriminators):
            # compute value of the i-th discriminator; features are always
            # requested and simply dropped at the end when not wanted
            out, feat = discriminator(
                [fake_or_real_i, condition_i]
                if condition_i is not None
                else fake_or_real_i,
                training=training,
                return_features=True,
            )

            # append output and features
            outs.append(out)
            features.extend(feat)

            # reduce input size for the next scale (nothing to do after the
            # last discriminator)
            if i != last_index:
                fake_or_real_i = self.subsampling(fake_or_real_i)
                if condition_i is not None:
                    condition_i = self.subsampling(condition_i)

        # handle output values
        if return_features:
            return outs, features
        return outs