# Copyright 2019 Zuru Tech HK Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GAN losses."""
from __future__ import annotations
from abc import ABC, abstractmethod
from enum import Enum, auto
from typing import TYPE_CHECKING, List, Type, Union
import tensorflow as tf
from ashpy.contexts import GANContext, GANEncoderContext
from ashpy.keras.losses import L1, DHingeLoss, DLeastSquare, DMinMax, GHingeLoss
from ashpy.losses.executor import Executor, SumExecutor
if TYPE_CHECKING:
from ashpy.ashtypes import TWeight
class AdversarialLossType(Enum):
    """Enumeration for Adversarial Losses. Implemented: GAN and LSGAN."""

    # Explicit values match what auto() would assign in declaration order.
    GAN = 1  # classical gan loss (minmax)
    LSGAN = 2  # Least Square GAN
    HINGE_LOSS = 3  # Hinge loss
class GANExecutor(Executor, ABC):
    """
    Executor for GANs.

    Implements the basic functions needed by the GAN losses.
    """

    @abstractmethod
    def call(self, context, **kwargs):
        """
        Execute the function, using the information provided by the context.

        Args:
            context (:py:class:`ashpy.contexts.Context`): The function
                execution Context.

        Returns:
            :py:obj:`tf.Tensor`: Output Tensor.

        """
        # Zero-argument super() — consistent with the Python-3 style used by
        # every other super() call in this module.
        super().call(context, **kwargs)
class GeneratorAdversarialLoss(GANExecutor):
    r"""Base class for the adversarial loss of the generator."""

    def __init__(self, loss_fn: tf.keras.losses.Loss = None) -> None:
        """
        Initialize the Executor.

        Args:
            loss_fn (:py:class:`tf.keras.losses.Loss`): Keras Loss function to call
                passing (tf.ones_like(d_fake_i), d_fake_i).

        """
        super().__init__(loss_fn)

    @Executor.reduce_loss
    def call(
        self,
        context: GANContext,
        *,
        fake: tf.Tensor,
        condition: tf.Tensor,
        training: bool,
        **kwargs,
    ) -> tf.Tensor:
        r"""
        Configure the discriminator inputs and calls `loss_fn`.

        Args:
            context (:py:class:`ashpy.contexts.GANContext`): GAN Context.
            fake (:py:class:`tf.Tensor`): Fake images.
            condition (:py:class:`tf.Tensor`): Generator conditioning.
            training (bool): If training or evaluation.

        Returns:
            :py:class:`tf.Tensor`: The loss for each example.

        """
        fake_inputs = self.get_discriminator_inputs(
            context=context, fake_or_real=fake, condition=condition, training=training
        )
        d_fake = context.discriminator_model(fake_inputs, training=training)
        # Support for Multiscale discriminator
        # TODO: Improve
        if isinstance(d_fake, list):
            # A list output means one prediction per scale: the generator loss
            # pushes each scale toward the "real" label (ones), the per-scale
            # spatial means (axes 1, 2) are then summed across scales.
            value = tf.add_n(
                [
                    tf.reduce_mean(
                        self._fn(tf.ones_like(d_fake_i), d_fake_i), axis=[1, 2]
                    )
                    for d_fake_i in d_fake
                ]
            )
            return value
        value = self._fn(tf.ones_like(d_fake), d_fake)
        # When the discriminator output is not rank-4 (e.g. a flat logit rather
        # than a PatchGAN map), append two singleton axes so the reduce_mean
        # over axes [1, 2] below is well-defined in both cases.
        value = tf.cond(
            tf.equal(tf.rank(d_fake), tf.constant(4)),
            lambda: value,
            lambda: tf.expand_dims(tf.expand_dims(value, axis=-1), axis=-1),
        )
        return tf.reduce_mean(value, axis=[1, 2])
class GeneratorBCE(GeneratorAdversarialLoss):
    r"""
    The Binary CrossEntropy computed among the generator and the 1 label.

    .. math::
        L_{G} = E [\log (D( G(z))]

    """

    def __init__(self, from_logits: bool = True) -> None:
        """Initialize the BCE Loss for the Generator."""
        self.name = "GeneratorBCE"
        bce = tf.keras.losses.BinaryCrossentropy(from_logits=from_logits)
        super().__init__(bce)
class GeneratorLSGAN(GeneratorAdversarialLoss):
    r"""
    Least Square GAN Loss for generator.

    Reference: https://arxiv.org/abs/1611.04076

    .. note::
        Basically the Mean Squared Error between the discriminator output when evaluated
        in fake and 1.

    .. math::
        L_{G} = \frac{1}{2} E [(1 - D(G(z))^2]

    """

    def __init__(self) -> None:
        """Initialize the Least Square Loss for the Generator."""
        mse = tf.keras.losses.MeanSquaredError()
        super().__init__(mse)
        self.name = "GeneratorLSGAN"
class GeneratorL1(GANExecutor):
    r"""
    L1 loss between the generator output and the target.

    .. math::
        L_G = E ||x - G(z)||_1

    Where x is the target and G(z) is generated image.

    """

    def __init__(self) -> None:
        """Initialize the Executor."""
        super().__init__(L1())

    @Executor.reduce_loss
    def call(self, context: GANContext, *, fake: tf.Tensor, real: tf.Tensor, **kwargs):
        """
        Call the carried loss on `fake` and `real`.

        Args:
            context (:py:class:`ashpy.contexts.GANContext`): GAN Context.
            fake (:py:class:`tf.Tensor`): Fake data (generated).
            real (:py:class:`tf.Tensor`): Real data.

        Returns:
            :py:class:`tf.Tensor`: Output Tensor.

        """
        return self._fn(fake, real)
class GeneratorHingeLoss(GeneratorAdversarialLoss):
    r"""
    Hinge loss for the Generator.

    See Geometric GAN [1]_ for more details.

    .. [1] Geometric GAN https://arxiv.org/abs/1705.02894

    """

    def __init__(self) -> None:
        # Fixed copy-pasted docstring: this is the hinge loss, not least squares.
        """Initialize the Hinge loss for the Generator."""
        super().__init__(GHingeLoss())
        self.name = "GeneratorHingeLoss"
class FeatureMatchingLoss(GANExecutor):
    r"""
    Conditional GAN Feature matching loss.

    The loss is computed for each example and it's the L1 (MAE) of the feature difference.
    Implementation of pix2pix HD: https://github.com/NVIDIA/pix2pixHD

    .. math::
        \text{FM} = \sum_{i=0}^N \frac{1}{M_i} ||D_i(x, c) - D_i(G(c), c) ||_1

    Where:

    - D_i is the i-th layer of the discriminator
    - N is the total number of layer of the discriminator
    - M_i is the number of components for the i-th layer
    - x is the target image
    - c is the condition
    - G(c) is the generated image from the condition c
    - || ||_1 stands for norm 1.

    This is for a single example: basically for each layer
    of the discriminator we compute the absolute error between
    the layer evaluated in real examples and in fake examples.
    Then we average along the batch. In the case where D_i is
    a multidimensional tensor we simply calculate the mean
    over the axis 1,2,3.
    """

    def __init__(self) -> None:
        """Initialize the Executor."""
        super().__init__(L1())

    @Executor.reduce_loss
    def call(
        self,
        context: GANContext,
        *,
        fake: tf.Tensor,
        real: tf.Tensor,
        condition: tf.Tensor,
        training: bool,
        **kwargs,
    ) -> tf.Tensor:
        """
        Configure the discriminator inputs and calls `loss_fn`.

        Args:
            context (:py:class:`ashpy.contexts.GANContext`): GAN Context.
            fake (:py:class:`tf.Tensor`): Fake data.
            real (:py:class:`tf.Tensor`): Real data.
            condition (:py:class:`tf.Tensor`): Generator conditioning.
            training (bool): If training or evaluation.

        Returns:
            :py:class:`tf.Tensor`: The loss for each example.

        """
        fake_inputs = self.get_discriminator_inputs(
            context, fake_or_real=fake, condition=condition, training=training
        )
        real_inputs = self.get_discriminator_inputs(
            context, fake_or_real=real, condition=condition, training=training
        )
        # return_features=True makes the discriminator also return its
        # intermediate layer activations.
        _, fake_features = context.discriminator_model(
            fake_inputs, training=training, return_features=True
        )
        _, real_features = context.discriminator_model(
            real_inputs, training=training, return_features=True
        )
        # L1 distance between real and fake activations, layer by layer;
        # every call to _fn should return [batch_size, 1] (the mean L1).
        per_layer_l1 = []
        for real_feat, fake_feat in zip(real_features, fake_features):
            per_layer_l1.append(self._fn(real_feat, fake_feat))
        # Sum the per-layer distances into a single per-example loss.
        return tf.add_n(per_layer_l1)
class CategoricalCrossEntropy(Executor):
    r"""
    Categorical Cross Entropy between generator output and target.

    Useful when the output of the generator is a distribution over classes.

    .. note::
        The target must be represented in one hot notation.

    """

    def __init__(self) -> None:
        """Initialize the Categorical Cross Entropy Executor."""
        self.name = "CrossEntropy"
        super().__init__(tf.keras.losses.CategoricalCrossentropy())

    @Executor.reduce_loss
    def call(self, context: GANContext, *, fake: tf.Tensor, real: tf.Tensor, **kwargs):
        """
        Compute the categorical cross entropy loss.

        Args:
            context: Unused.
            fake (:py:class:`tf.Tensor`): Fake data G(condition).
            real (:py:class:`tf.Tensor`): Real data x(c).

        Returns:
            The categorical cross entropy loss for each example.

        """
        # NOTE(review): reducing over axes [1, 2] assumes spatial (4-D) inputs
        # so the cross entropy map is averaged per example — confirm callers
        # never pass flat class distributions.
        loss_value = tf.reduce_mean(self._fn(real, fake), axis=[1, 2])
        return loss_value
class Pix2PixLoss(SumExecutor):
    r"""
    Pix2Pix Loss.

    Weighted sum of :py:class:`ashpy.losses.gan.GeneratorL1`,
    :py:class:`ashpy.losses.gan.AdversarialLossG` and
    :py:class:`ashpy.losses.gan.FeatureMatchingLoss`.

    Used by Pix2Pix [1] and Pix2PixHD [2]

    .. [1] Image-to-Image Translation with Conditional Adversarial Networks
        https://arxiv.org/abs/1611.07004
    .. [2] High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs
        https://arxiv.org/abs/1711.11585

    """

    def __init__(
        self,
        l1_loss_weight: TWeight = 100.0,
        adversarial_loss_weight: TWeight = 1.0,
        feature_matching_weight: TWeight = 10.0,
        adversarial_loss_type: Union[
            AdversarialLossType, int
        ] = AdversarialLossType.GAN,
        use_feature_matching_loss: bool = False,
    ) -> None:
        r"""
        Initialize the loss.

        Weighted sum of :py:class:`ashpy.losses.gan.GeneratorL1`,
        :py:class:`ashpy.losses.gan.AdversarialLossG` and
        :py:class:`ashpy.losses.gan.FeatureMatchingLoss`.

        Args:
            l1_loss_weight (:py:obj:`ashpy.ashtypes.TWeight`): Weight of L1 loss.
            adversarial_loss_weight (:py:obj:`ashpy.ashtypes.TWeight`): Weight of adversarial loss.
            feature_matching_weight (:py:obj:`ashpy.ashtypes.TWeight`): Weight of the
                feature matching loss.
            adversarial_loss_type (:py:class:`ashpy.losses.gan.AdversarialLossType`): Adversarial
                loss type (:py:class:`ashpy.losses.gan.AdversarialLossType.GAN`
                or :py:class:`ashpy.losses.gan.AdversarialLossType.LSGAN`).
            use_feature_matching_loss (bool): if True use also uses
                :py:class:`ashpy.losses.gan.FeatureMatchingLoss`.

        """
        # Resolve the generator adversarial loss class for the requested type.
        adversarial_loss_class = get_adversarial_loss_generator(adversarial_loss_type)
        weighted_losses = [
            GeneratorL1() * l1_loss_weight,
            adversarial_loss_class() * adversarial_loss_weight,
        ]
        if use_feature_matching_loss:
            weighted_losses.append(FeatureMatchingLoss() * feature_matching_weight)
        super().__init__(weighted_losses)
class Pix2PixLossSemantic(SumExecutor):
    r"""
    Semantic Pix2Pix Loss.

    Weighted sum of :py:class:`ashpy.losses.gan.CategoricalCrossEntropy`,
    :py:class:`ashpy.losses.gan.AdversarialLossG` and
    :py:class:`ashpy.losses.gan.FeatureMatchingLoss`.
    """

    def __init__(
        self,
        cross_entropy_weight: TWeight = 100.0,
        adversarial_loss_weight: TWeight = 1.0,
        feature_matching_weight: TWeight = 10.0,
        adversarial_loss_type: AdversarialLossType = AdversarialLossType.GAN,
        use_feature_matching_loss: bool = False,
    ):
        r"""
        Initialize the Executor.

        Weighted sum of :py:class:`ashpy.losses.gan.CategoricalCrossEntropy`,
        :py:class:`ashpy.losses.gan.AdversarialLossG`
        and :py:class:`ashpy.losses.gan.FeatureMatchingLoss`

        Args:
            cross_entropy_weight (:py:obj:`ashpy.ashtypes.TWeight`): Weight of the categorical
                cross entropy loss.
            adversarial_loss_weight (:py:obj:`ashpy.ashtypes.TWeight`): Weight of the
                adversarial loss.
            feature_matching_weight (:py:obj:`ashpy.ashtypes.TWeight`): Weight of the
                feature matching loss.
            adversarial_loss_type (:py:class:`ashpy.losses.gan.AdversarialLossType`): type of
                adversarial loss, see :py:class:`ashpy.losses.gan.AdversarialLossType`
            use_feature_matching_loss (bool): whether to use feature matching loss or not

        """
        # Resolve the generator adversarial loss class for the requested type.
        adversarial_loss_class = get_adversarial_loss_generator(adversarial_loss_type)
        weighted_losses = [
            CategoricalCrossEntropy() * cross_entropy_weight,
            adversarial_loss_class() * adversarial_loss_weight,
        ]
        if use_feature_matching_loss:
            weighted_losses.append(FeatureMatchingLoss() * feature_matching_weight)
        super().__init__(weighted_losses)
# TODO: Check if this supports condition
class EncoderBCE(Executor):
    """The Binary Cross Entropy computed among the encoder and the 0 label."""

    def __init__(self, from_logits: bool = True) -> None:
        """Initialize the Executor."""
        bce = tf.keras.losses.BinaryCrossentropy(from_logits=from_logits)
        super().__init__(bce)

    @Executor.reduce_loss
    def call(
        self, context: GANEncoderContext, *, real: tf.Tensor, training: bool, **kwargs
    ):
        """
        Compute the Encoder BCE.

        Args:
            context (:py:class:`ashpy.contexts.GANEncoderContext`): GAN Context
                with Encoder support.
            real (:py:class:`tf.Tensor`): Real images.
            training (bool): If training or evaluation.

        Returns:
            :py:class:`tf.Tensor`: The loss for each example.

        """
        latent_code = context.encoder_model(real, training=training)
        d_output = context.discriminator_model([real, latent_code], training=training)
        # The encoder output is pushed toward the "fake" label (zeros).
        return self._fn(tf.zeros_like(d_output), d_output)
class DiscriminatorAdversarialLoss(GANExecutor):
    r"""Base class for the adversarial loss of the discriminator."""

    def __init__(self, loss_fn: tf.keras.losses.Loss = None) -> None:
        r"""
        Initialize the Executor.

        Args:
            loss_fn (:py:class:`tf.keras.losses.Loss`): Loss function call passing
                (d_real, d_fake).

        """
        super().__init__(loss_fn)

    @Executor.reduce_loss
    def call(
        self,
        context: GANContext,
        *,
        fake: tf.Tensor,
        real: tf.Tensor,
        condition: tf.Tensor,
        training: bool,
        **kwargs,
    ):
        r"""
        Call: setup the discriminator inputs and calls `loss_fn`.

        Args:
            context (:py:class:`ashpy.contexts.GANContext`): GAN Context.
            fake (:py:class:`tf.Tensor`): Fake images corresponding to the condition G(c).
            real (:py:class:`tf.Tensor`): Real images corresponding to the condition x(c).
            condition (:py:class:`tf.Tensor`): Condition for the generator and discriminator.
            training (bool): if training or evaluation

        Returns:
            :py:class:`tf.Tensor`: The loss for each example.

        """
        fake_inputs = self.get_discriminator_inputs(
            context, fake_or_real=fake, condition=condition, training=training
        )
        real_inputs = self.get_discriminator_inputs(
            context, fake_or_real=real, condition=condition, training=training
        )
        d_fake = context.discriminator_model(fake_inputs, training=training)
        d_real = context.discriminator_model(real_inputs, training=training)
        # Support for multiscale discriminators: a list output means one
        # prediction per scale, paired real/fake.
        if isinstance(d_fake, list):
            # Per scale: loss_fn(d_real_i, d_fake_i), spatial mean over axes
            # [1, 2], then sum across the scales.
            value = tf.add_n(
                [
                    tf.reduce_mean(self._fn(d_real_i, d_fake_i), axis=[1, 2])
                    for d_real_i, d_fake_i in zip(d_real, d_fake)
                ]
            )
            return value
        value = self._fn(d_real, d_fake)
        # When the discriminator output is not rank-4 (e.g. a flat logit rather
        # than a PatchGAN map), append two singleton axes so the reduce_mean
        # over axes [1, 2] below is well-defined in both cases.
        value = tf.cond(
            tf.equal(tf.rank(d_fake), tf.constant(4)),
            lambda: value,
            lambda: tf.expand_dims(tf.expand_dims(value, axis=-1), axis=-1),
        )
        return tf.reduce_mean(value, axis=[1, 2])
class DiscriminatorMinMax(DiscriminatorAdversarialLoss):
    r"""
    The min-max game played by the discriminator.

    .. math::
        L_{D} = - \frac{1}{2} E [\log(D(x)) + \log (1 - D(G(z))]

    """

    def __init__(self, from_logits=True, label_smoothing=0.0):
        """Initialize Loss."""
        minmax = DMinMax(from_logits=from_logits, label_smoothing=label_smoothing)
        super().__init__(minmax)
class DiscriminatorLSGAN(DiscriminatorAdversarialLoss):
    r"""
    Least square Loss for discriminator.

    Reference: Least Squares Generative Adversarial Networks [1]_ .

    Basically the Mean Squared Error between
    the discriminator output when evaluated in fake samples and 0
    and the discriminator output when evaluated in real samples and 1:

    For the unconditioned case this is:

    .. math::
        L_{D} = \frac{1}{2} E[(D(x) - 1)^2 + (0 - D(G(z))^2]

    where x are real samples and z is the latent vector.

    For the conditioned case this is:

    .. math::
        L_{D} = \frac{1}{2} E[(D(x, c) - 1)^2 + (0 - D(G(c), c)^2]

    where c is the condition and x are real samples.

    .. [1] Least Squares Generative Adversarial Networks
        https://arxiv.org/abs/1611.04076

    """

    def __init__(self) -> None:
        """Initialize loss."""
        least_square = DLeastSquare()
        super().__init__(least_square)
        self.name = "DiscriminatorLSGAN"
class DiscriminatorHingeLoss(DiscriminatorAdversarialLoss):
    r"""
    Hinge loss for the Discriminator.

    See Geometric GAN [1]_ for more details.

    .. [1] Geometric GAN https://arxiv.org/abs/1705.02894

    """

    def __init__(self) -> None:
        # Fixed copy-pasted docstring: this is the Discriminator hinge loss,
        # not the generator least-squares loss.
        """Initialize the Hinge loss for the Discriminator."""
        super().__init__(DHingeLoss())
        self.name = "DiscriminatorHingeLoss"
###
# Utility functions in order to get the correct loss
###
def get_adversarial_loss_discriminator(
    adversarial_loss_type: Union[AdversarialLossType, int] = AdversarialLossType.GAN
) -> Type[Executor]:
    r"""
    Return the correct loss for the Discriminator.

    Args:
        adversarial_loss_type (:py:class:`ashpy.losses.gan.AdversarialLossType`): Type of loss
            (:py:class:`ashpy.losses.gan.AdversarialLossType.GAN` or
            :py:class:`ashpy.losses.gan.AdversarialLossType.LSGAN`)

    Returns:
        The correct (:py:class:`ashpy.losses.executor.Executor`) (to be instantiated).

    Raises:
        ValueError: If the loss type is not supported.

    """
    # The signature advertises Union[AdversarialLossType, int], but a raw int
    # never compares equal to a plain Enum member; convert it so ints work.
    if isinstance(adversarial_loss_type, int):
        try:
            adversarial_loss_type = AdversarialLossType(adversarial_loss_type)
        except ValueError:
            pass  # Invalid ints fall through to the unified error below.
    if adversarial_loss_type == AdversarialLossType.GAN:
        return DiscriminatorMinMax
    if adversarial_loss_type == AdversarialLossType.LSGAN:
        return DiscriminatorLSGAN
    if adversarial_loss_type == AdversarialLossType.HINGE_LOSS:
        return DiscriminatorHingeLoss
    raise ValueError(
        "Loss type not supported, the implemented losses are GAN, LSGAN or HINGE_LOSS."
    )
def get_adversarial_loss_generator(
    adversarial_loss_type: Union[AdversarialLossType, int] = AdversarialLossType.GAN
) -> Type[Executor]:
    r"""
    Return the correct loss for the Generator.

    Args:
        adversarial_loss_type (:py:class:`ashpy.losses.AdversarialLossType`): Type of loss
            (:py:class:`ashpy.losses.AdversarialLossType.GAN` or
            :py:class:`ashpy.losses.AdversarialLossType.LSGAN`).

    Returns:
        The correct (:py:class:`ashpy.losses.executor.Executor`) (to be instantiated).

    Raises:
        ValueError: If the loss type is not supported.

    """
    # The signature advertises Union[AdversarialLossType, int], but a raw int
    # never compares equal to a plain Enum member; convert it so ints work.
    if isinstance(adversarial_loss_type, int):
        try:
            adversarial_loss_type = AdversarialLossType(adversarial_loss_type)
        except ValueError:
            pass  # Invalid ints fall through to the unified error below.
    if adversarial_loss_type == AdversarialLossType.GAN:
        return GeneratorBCE
    if adversarial_loss_type == AdversarialLossType.LSGAN:
        return GeneratorLSGAN
    if adversarial_loss_type == AdversarialLossType.HINGE_LOSS:
        return GeneratorHingeLoss
    raise ValueError(
        "Loss type not supported, the implemented losses are GAN, LSGAN or HINGE_LOSS."
    )