gan
GAN losses.
Functions
get_adversarial_loss_discriminator | Return the correct loss for the Discriminator.
get_adversarial_loss_generator | Return the correct loss for the Generator.
Classes
AdversarialLossType | Enumeration for Adversarial Losses.
CategoricalCrossEntropy | Categorical Cross Entropy between generator output and target.
DiscriminatorAdversarialLoss | Base class for the adversarial loss of the discriminator.
DiscriminatorHingeLoss | Hinge loss for the Discriminator.
DiscriminatorLSGAN | Least Squares loss for the Discriminator.
DiscriminatorMinMax | The min-max game played by the discriminator.
EncoderBCE | The Binary Cross Entropy computed between the encoder and the 0 label.
FeatureMatchingLoss | Conditional GAN Feature matching loss.
GANExecutor | Executor for GANs.
GeneratorAdversarialLoss | Base class for the adversarial loss of the generator.
GeneratorBCE | The Binary Cross Entropy computed between the generator and the 1 label.
GeneratorHingeLoss | Hinge loss for the Generator.
GeneratorL1 | L1 loss between the generator output and the target.
GeneratorLSGAN | Least Squares GAN loss for the Generator.
Pix2PixLoss | Pix2Pix Loss.
Pix2PixLossSemantic | Semantic Pix2Pix Loss.
class ashpy.losses.gan.AdversarialLossType
Bases: enum.Enum
Enumeration for Adversarial Losses. Implemented: GAN and LSGAN.
class ashpy.losses.gan.CategoricalCrossEntropy
Bases: ashpy.losses.executor.Executor
Categorical Cross Entropy between generator output and target.
Useful when the output of the generator is a distribution over classes.
Note: the target must be represented in one-hot notation.
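A minimal plain-Python sketch of the quantity this executor computes, assuming the generator output is already a probability distribution per example and that the loss is averaged over the batch. The function name and the `eps` clipping are illustrative assumptions, not ashpy's TensorFlow implementation. With a one-hot target, only the log-probability of the true class contributes.

```python
import math
from typing import Sequence


def categorical_cross_entropy(
    targets: Sequence[Sequence[float]],
    probs: Sequence[Sequence[float]],
    eps: float = 1e-7,
) -> float:
    """Mean categorical cross entropy over a batch of one-hot targets."""
    total = 0.0
    for target_row, prob_row in zip(targets, probs):
        # Clip probabilities away from 0 so log() stays finite.
        total -= sum(t * math.log(max(p, eps)) for t, p in zip(target_row, prob_row))
    return total / len(targets)


# Two examples over three classes; the true classes are 0 and 2.
targets = [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]
probs = [[0.7, 0.2, 0.1], [0.1, 0.1, 0.8]]
loss = categorical_cross_entropy(targets, probs)  # (-ln 0.7 - ln 0.8) / 2
```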
class ashpy.losses.gan.DiscriminatorAdversarialLoss(loss_fn=None)
Bases: ashpy.losses.gan.GANExecutor
Base class for the adversarial loss of the discriminator.
__init__(loss_fn=None)
Initialize the Executor.
Parameters: loss_fn (tf.keras.losses.Loss) – Loss function to call, passing (d_real, d_fake).
Return type: None
class ashpy.losses.gan.DiscriminatorHingeLoss
Bases: ashpy.losses.gan.DiscriminatorAdversarialLoss
Hinge loss for the Discriminator.
See Geometric GAN [1] for more details.
[1] Geometric GAN https://arxiv.org/abs/1705.02894
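The hinge formulation from the Geometric GAN paper can be sketched as follows. This is a plain-Python illustration on raw (unbounded) discriminator scores, assuming a batch mean for each expectation; the function name is hypothetical and ashpy's executor may reduce differently.

```python
from typing import Sequence


def discriminator_hinge_loss(d_real: Sequence[float], d_fake: Sequence[float]) -> float:
    """L_D = E[max(0, 1 - D(x))] + E[max(0, 1 + D(G(z)))] on raw scores."""
    real_term = sum(max(0.0, 1.0 - score) for score in d_real) / len(d_real)
    fake_term = sum(max(0.0, 1.0 + score) for score in d_fake) / len(d_fake)
    return real_term + fake_term


# Real scores >= 1 and fake scores <= -1 already satisfy the margin and cost nothing.
loss = discriminator_hinge_loss([2.0, 0.5], [-3.0, 0.0])  # 0.25 + 0.5
```

The margin is what distinguishes the hinge loss from a plain score difference: samples the discriminator already classifies confidently stop contributing gradient.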
class ashpy.losses.gan.DiscriminatorLSGAN
Bases: ashpy.losses.gan.DiscriminatorAdversarialLoss
Least Squares loss for the Discriminator.
Reference: Least Squares Generative Adversarial Networks [1].
It averages two Mean Squared Errors: the one between the discriminator output evaluated on fake samples and 0, and the one between the discriminator output evaluated on real samples and 1. For the unconditioned case this is:
\[L_{D} = \frac{1}{2} E[(D(x) - 1)^2 + (0 - D(G(z)))^2]\]
where x are real samples and z is the latent vector.
For the conditioned case this is:
\[L_{D} = \frac{1}{2} E[(D(x, c) - 1)^2 + (0 - D(G(c), c))^2]\]
where c is the condition and x are real samples.
[1] Least Squares Generative Adversarial Networks https://arxiv.org/abs/1611.04076
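The unconditioned formula can be sketched in plain Python, assuming `d_real` and `d_fake` hold raw discriminator outputs for equally sized batches and that E is a batch mean. The function name is hypothetical; this is an illustration of the formula, not ashpy's code.

```python
from typing import Sequence


def discriminator_lsgan_loss(d_real: Sequence[float], d_fake: Sequence[float]) -> float:
    """L_D = 1/2 * E[(D(x) - 1)^2 + (0 - D(G(z)))^2] on raw discriminator outputs."""
    real_mse = sum((score - 1.0) ** 2 for score in d_real) / len(d_real)
    fake_mse = sum((0.0 - score) ** 2 for score in d_fake) / len(d_fake)
    # The 1/2 factor averages the two MSE terms.
    return 0.5 * (real_mse + fake_mse)


perfect = discriminator_lsgan_loss([1.0, 1.0], [0.0, 0.0])  # optimal D
mixed = discriminator_lsgan_loss([1.0, 0.0], [0.0, 1.0])  # 0.5 * (0.5 + 0.5)
```

For the conditioned case the same function applies unchanged: pass D(x, c) as `d_real` and D(G(c), c) as `d_fake`.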
class ashpy.losses.gan.DiscriminatorMinMax(from_logits=True, label_smoothing=0.0)
Bases: ashpy.losses.gan.DiscriminatorAdversarialLoss
The min-max game played by the discriminator.
\[L_{D} = - \frac{1}{2} E[\log(D(x)) + \log(1 - D(G(z)))]\]
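A sketch of this loss computed from logits, mirroring the `from_logits=True` default, with D = sigmoid(logit). It uses the numerically stable BCE-with-logits identity. Applying the Keras-style label smoothing rule y * (1 - s) + 0.5 * s to both the real and the fake labels is an assumption about how `label_smoothing` is used here; function names are hypothetical.

```python
import math
from typing import Sequence


def bce_with_logits(label: float, logit: float) -> float:
    """Stable -[y * log(sigmoid(l)) + (1 - y) * log(1 - sigmoid(l))]."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def discriminator_min_max_loss(
    real_logits: Sequence[float],
    fake_logits: Sequence[float],
    label_smoothing: float = 0.0,
) -> float:
    """L_D = -1/2 E[log D(x) + log(1 - D(G(z)))] with D = sigmoid(logits)."""
    # Keras-style smoothing y -> y * (1 - s) + 0.5 * s, assumed for both labels.
    real_label = 1.0 * (1.0 - label_smoothing) + 0.5 * label_smoothing
    fake_label = 0.0 * (1.0 - label_smoothing) + 0.5 * label_smoothing
    real_term = sum(bce_with_logits(real_label, x) for x in real_logits) / len(real_logits)
    fake_term = sum(bce_with_logits(fake_label, x) for x in fake_logits) / len(fake_logits)
    return 0.5 * (real_term + fake_term)


# An undecided discriminator (all logits 0, i.e. D = 0.5) scores log(2).
undecided = discriminator_min_max_loss([0.0, 0.0], [0.0])
```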
class ashpy.losses.gan.EncoderBCE(from_logits=True)
Bases: ashpy.losses.executor.Executor
The Binary Cross Entropy computed between the encoder and the 0 label.
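Against the 0 label, Binary Cross Entropy on a logit l reduces to -log(1 - sigmoid(l)), which equals softplus(l). A plain-Python sketch, assuming raw logits as implied by `from_logits=True` and a batch mean; which tensor ashpy actually feeds to this executor is not specified here, and the function name is hypothetical.

```python
import math
from typing import Sequence


def bce_zero_label_from_logits(logits: Sequence[float]) -> float:
    """Mean of -log(1 - sigmoid(l)) over the batch, i.e. mean softplus(l)."""
    # softplus(l) = max(l, 0) + log1p(exp(-|l|)) never overflows, even for huge |l|.
    return sum(max(x, 0.0) + math.log1p(math.exp(-abs(x))) for x in logits) / len(logits)


confident_zero = bce_zero_label_from_logits([-20.0])  # close to 0
confident_one = bce_zero_label_from_logits([1000.0])  # about 1000, no overflow
```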
class ashpy.losses.gan.FeatureMatchingLoss
Bases: ashpy.losses.gan.GANExecutor
Conditional GAN Feature matching loss.
The loss is computed for each example and is the L1 distance (MAE) of the feature difference. The implementation follows pix2pixHD: https://github.com/NVIDIA/pix2pixHD
\[\text{FM} = \sum_{i=1}^{N} \frac{1}{M_i} \| D_i(x, c) - D_i(G(c), c) \|_1\]
where:
- D_i is the i-th layer of the discriminator
- N is the total number of layers of the discriminator
- M_i is the number of components of the i-th layer
- x is the target image
- c is the condition
- G(c) is the image generated from the condition c
- ||·||_1 is the L1 norm.
This is the loss for a single example: for each layer of the discriminator we compute the absolute error between the layer evaluated on real examples and on fake examples. Then we average along the batch. When D_i is a multidimensional tensor, the mean is taken over axes 1, 2, 3.
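The per-example formula and the batch average can be sketched in plain Python. Each layer's activations are assumed to be already flattened, so dividing by M_i plays the role of the mean over axes 1, 2, 3; the function names and the `Layers` alias are hypothetical.

```python
from typing import Sequence

Layers = Sequence[Sequence[float]]  # one flattened activation list per discriminator layer


def feature_matching_loss(real_layers: Layers, fake_layers: Layers) -> float:
    """Single example: sum_i (1 / M_i) * ||D_i(x, c) - D_i(G(c), c)||_1."""
    total = 0.0
    for real_act, fake_act in zip(real_layers, fake_layers):
        m_i = len(real_act)  # number of components of layer i
        total += sum(abs(r - f) for r, f in zip(real_act, fake_act)) / m_i
    return total


def batch_feature_matching_loss(
    real_batch: Sequence[Layers], fake_batch: Sequence[Layers]
) -> float:
    """Average the per-example loss along the batch."""
    losses = [feature_matching_loss(r, f) for r, f in zip(real_batch, fake_batch)]
    return sum(losses) / len(losses)


# One example, two layers: 2 / 2 from the first layer plus 2 / 4 from the second.
real = [[1.0, 2.0], [0.0, 0.0, 0.0, 3.0]]
fake = [[1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
single = feature_matching_loss(real, fake)
batch = batch_feature_matching_loss([real, real], [fake, real])
```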
class ashpy.losses.gan.GANExecutor(fn=None)
Bases: ashpy.losses.executor.Executor, abc.ABC
Executor for GANs.
Implements the basic functions needed by the GAN losses.
call(context, **kwargs)
Execute the function, using the information provided by the context.
Parameters: context (ashpy.contexts.Context) – The function execution Context.
Returns: tf.Tensor – Output Tensor.
static get_discriminator_inputs(context, fake_or_real, condition, training)
Return the discriminator inputs. If needed, it uses the encoder.
The current implementation uses the number of inputs to determine whether the discriminator is conditioned or not.
Parameters:
- context (ashpy.contexts.gan.GANContext) – Context for GAN models.
- fake_or_real (tf.Tensor) – Discriminator input tensor; it can be fake (generated) or real.
- condition (tf.Tensor) – Discriminator condition (it can also be the generator noise).
- training (bool) – Whether this is the training phase or not.
Returns: The discriminator inputs.
class ashpy.losses.gan.GeneratorAdversarialLoss(loss_fn=None)
Bases: ashpy.losses.gan.GANExecutor
Base class for the adversarial loss of the generator.
__init__(loss_fn=None)
Initialize the Executor.
Parameters: loss_fn (tf.keras.losses.Loss) – Keras Loss function to call, passing (tf.ones_like(d_fake_i), d_fake_i).
Return type: None
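The calling convention of `loss_fn` can be illustrated in plain Python. The `d_fake_i` notation suggests the discriminator may return several outputs, so this sketch returns one loss per output and leaves the aggregation to the caller, since how ashpy combines them is not stated here. All names are hypothetical, and a plain MSE stands in for a Keras loss.

```python
from typing import Callable, List, Sequence

# A loss takes (y_true, y_pred), mirroring the Keras Loss call convention.
LossFn = Callable[[Sequence[float], Sequence[float]], float]


def mean_squared_error(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Plain MSE standing in for a tf.keras.losses.Loss."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def generator_adversarial_losses(
    loss_fn: LossFn, d_fake_outputs: Sequence[Sequence[float]]
) -> List[float]:
    """Call loss_fn(ones_like(d_fake_i), d_fake_i) for each discriminator output."""
    losses = []
    for d_fake_i in d_fake_outputs:
        ones_like = [1.0] * len(d_fake_i)  # stand-in for tf.ones_like(d_fake_i)
        losses.append(loss_fn(ones_like, d_fake_i))
    return losses


per_output = generator_adversarial_losses(mean_squared_error, [[1.0, 0.0], [0.5]])
```

Because the target is always the 1 label, whatever loss is plugged in rewards the generator when the discriminator scores its samples as real.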
class ashpy.losses.gan.GeneratorBCE(from_logits=True)
Bases: ashpy.losses.gan.GeneratorAdversarialLoss
The Binary Cross Entropy computed between the discriminator output on generated samples and the 1 label.
\[L_{G} = - E[\log(D(G(z)))]\]
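With the 1 label, Binary Cross Entropy on a logit l reduces to -log(sigmoid(l)), which equals softplus(-l), so minimizing it pushes D(G(z)) toward 1. A plain-Python sketch, assuming raw logits (`from_logits=True`) and a batch mean; the function name is hypothetical.

```python
import math
from typing import Sequence


def generator_bce_from_logits(fake_logits: Sequence[float]) -> float:
    """Mean of -log(sigmoid(l)) = softplus(-l) over logits of generated samples."""
    # softplus(-l) = max(-l, 0) + log1p(exp(-|l|)), stable for large |l|.
    per_sample = [max(-x, 0.0) + math.log1p(math.exp(-abs(x))) for x in fake_logits]
    return sum(per_sample) / len(per_sample)


fooled = generator_bce_from_logits([3.0])  # D believes the sample is real: small loss
caught = generator_bce_from_logits([-3.0])  # D rejects the sample: large loss
```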
class ashpy.losses.gan.GeneratorHingeLoss
Bases: ashpy.losses.gan.GeneratorAdversarialLoss
Hinge loss for the Generator.
See Geometric GAN [1] for more details.
[1] Geometric GAN https://arxiv.org/abs/1705.02894
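In the Geometric GAN formulation the generator loss is the negated mean discriminator score on generated samples, L_G = -E[D(G(z))]. A plain-Python sketch, assuming ashpy follows that form with a batch mean; the function name is hypothetical.

```python
from typing import Sequence


def generator_hinge_loss(d_fake: Sequence[float]) -> float:
    """L_G = -E[D(G(z))] on raw discriminator scores of generated samples."""
    return -sum(d_fake) / len(d_fake)


loss = generator_hinge_loss([1.0, 3.0])  # -(1 + 3) / 2
```

Unlike the discriminator side, this term has no margin: the generator keeps being rewarded for raising the discriminator's scores on its samples.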
class ashpy.losses.gan.GeneratorL1
Bases: ashpy.losses.gan.GANExecutor
L1 loss between the generator output and the target.
\[L_G = E \| x - G(z) \|_1\]
where x is the target and G(z) is the generated image.
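A plain-Python sketch of the formula, treating each image as a flat list of pixel values: the L1 norm is taken per example and then averaged over the batch. Whether ashpy instead averages over pixels (a mean absolute error) is not stated here, so treat the reduction as an assumption; the function name and `Images` alias are hypothetical.

```python
from typing import Sequence

Images = Sequence[Sequence[float]]  # each image flattened to a list of pixel values


def generator_l1_loss(targets: Images, generated: Images) -> float:
    """E ||x - G(z)||_1: per-example L1 norm, averaged over the batch."""
    per_example = [
        sum(abs(x - g) for x, g in zip(target, fake))
        for target, fake in zip(targets, generated)
    ]
    return sum(per_example) / len(per_example)


loss = generator_l1_loss([[1.0, 2.0], [0.0, 0.0]], [[1.0, 1.0], [1.0, -1.0]])  # (1 + 2) / 2
```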
class ashpy.losses.gan.GeneratorLSGAN
Bases: ashpy.losses.gan.GeneratorAdversarialLoss
Least Squares GAN loss for the Generator.
Reference: https://arxiv.org/abs/1611.04076
Note: this is the Mean Squared Error between the discriminator output evaluated on fake samples and 1, scaled by 1/2:
\[L_{G} = \frac{1}{2} E[(1 - D(G(z)))^2]\]
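A plain-Python sketch of the formula on raw discriminator outputs for generated samples, assuming E is a batch mean; the function name is hypothetical.

```python
from typing import Sequence


def generator_lsgan_loss(d_fake: Sequence[float]) -> float:
    """L_G = 1/2 * E[(1 - D(G(z)))^2] on raw discriminator outputs."""
    return 0.5 * sum((1.0 - score) ** 2 for score in d_fake) / len(d_fake)


loss = generator_lsgan_loss([1.0, 0.0])  # 0.5 * (0 + 1) / 2
```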
class ashpy.losses.gan.Pix2PixLoss(l1_loss_weight=100.0, adversarial_loss_weight=1.0, feature_matching_weight=10.0, adversarial_loss_type=<AdversarialLossType.GAN: 1>, use_feature_matching_loss=False)
Bases: ashpy.losses.executor.SumExecutor
Pix2Pix Loss.
Weighted sum of ashpy.losses.gan.GeneratorL1, ashpy.losses.gan.AdversarialLossG and ashpy.losses.gan.FeatureMatchingLoss. Used by Pix2Pix [1] and Pix2PixHD [2].
[1] Image-to-Image Translation with Conditional Adversarial Networks https://arxiv.org/abs/1611.07004
[2] High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs https://arxiv.org/abs/1711.11585
__init__(l1_loss_weight=100.0, adversarial_loss_weight=1.0, feature_matching_weight=10.0, adversarial_loss_type=<AdversarialLossType.GAN: 1>, use_feature_matching_loss=False)
Initialize the loss.
Weighted sum of ashpy.losses.gan.GeneratorL1, ashpy.losses.gan.AdversarialLossG and ashpy.losses.gan.FeatureMatchingLoss.
Parameters:
- l1_loss_weight (ashpy.ashtypes.TWeight) – Weight of the L1 loss.
- adversarial_loss_weight (ashpy.ashtypes.TWeight) – Weight of the adversarial loss.
- feature_matching_weight (ashpy.ashtypes.TWeight) – Weight of the feature matching loss.
- adversarial_loss_type (ashpy.losses.gan.AdversarialLossType) – Adversarial loss type (ashpy.losses.gan.AdversarialLossType.GAN or ashpy.losses.gan.AdversarialLossType.LSGAN).
- use_feature_matching_loss (bool) – If True, also uses ashpy.losses.gan.FeatureMatchingLoss.
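Only the weighting step is sketched here: the three loss values are assumed to be already computed by the component executors, and the weights are treated as plain floats. Argument names and defaults mirror `Pix2PixLoss.__init__`, but the function itself is hypothetical. `Pix2PixLossSemantic` combines its terms the same way, with `cross_entropy_weight` in place of `l1_loss_weight`.

```python
def pix2pix_total_loss(
    l1: float,
    adversarial: float,
    feature_matching: float = 0.0,
    l1_loss_weight: float = 100.0,
    adversarial_loss_weight: float = 1.0,
    feature_matching_weight: float = 10.0,
    use_feature_matching_loss: bool = False,
) -> float:
    """Weighted sum of precomputed L1, adversarial and (optional) feature matching terms."""
    total = l1_loss_weight * l1 + adversarial_loss_weight * adversarial
    if use_feature_matching_loss:
        # The feature matching term is added only when explicitly enabled.
        total += feature_matching_weight * feature_matching
    return total


pix2pix = pix2pix_total_loss(0.25, 0.5)  # 100 * 0.25 + 1 * 0.5
pix2pix_hd = pix2pix_total_loss(0.25, 0.5, 0.125, use_feature_matching_loss=True)  # + 10 * 0.125
```

The large default L1 weight keeps the generator close to the target image, while the adversarial term contributes sharpness.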
class ashpy.losses.gan.Pix2PixLossSemantic(cross_entropy_weight=100.0, adversarial_loss_weight=1.0, feature_matching_weight=10.0, adversarial_loss_type=<AdversarialLossType.GAN: 1>, use_feature_matching_loss=False)
Bases: ashpy.losses.executor.SumExecutor
Semantic Pix2Pix Loss.
Weighted sum of ashpy.losses.gan.CategoricalCrossEntropy, ashpy.losses.gan.AdversarialLossG and ashpy.losses.gan.FeatureMatchingLoss.
__init__(cross_entropy_weight=100.0, adversarial_loss_weight=1.0, feature_matching_weight=10.0, adversarial_loss_type=<AdversarialLossType.GAN: 1>, use_feature_matching_loss=False)
Initialize the Executor.
Weighted sum of ashpy.losses.gan.CategoricalCrossEntropy, ashpy.losses.gan.AdversarialLossG and ashpy.losses.gan.FeatureMatchingLoss.
Parameters:
- cross_entropy_weight (ashpy.ashtypes.TWeight) – Weight of the categorical cross entropy loss.
- adversarial_loss_weight (ashpy.ashtypes.TWeight) – Weight of the adversarial loss.
- feature_matching_weight (ashpy.ashtypes.TWeight) – Weight of the feature matching loss.
- adversarial_loss_type (ashpy.losses.gan.AdversarialLossType) – Type of adversarial loss, see ashpy.losses.gan.AdversarialLossType.
- use_feature_matching_loss (bool) – Whether to use the feature matching loss or not.
ashpy.losses.gan.get_adversarial_loss_discriminator(adversarial_loss_type=<AdversarialLossType.GAN: 1>)
Return the correct loss for the Discriminator.
Parameters: adversarial_loss_type (ashpy.losses.gan.AdversarialLossType) – Type of loss (ashpy.losses.gan.AdversarialLossType.GAN or ashpy.losses.gan.AdversarialLossType.LSGAN).
Return type: Type[Executor]
Returns: The correct ashpy.losses.executor.Executor class (to be instantiated).
ashpy.losses.gan.get_adversarial_loss_generator(adversarial_loss_type=<AdversarialLossType.GAN: 1>)
Return the correct loss for the Generator.
Parameters: adversarial_loss_type (ashpy.losses.AdversarialLossType) – Type of loss (ashpy.losses.AdversarialLossType.GAN or ashpy.losses.AdversarialLossType.LSGAN).
Return type: Type[Executor]
Returns: The correct ashpy.losses.executor.Executor class (to be instantiated).
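The dispatch performed by these two functions can be sketched with a dictionary lookup. Assumptions, none confirmed by this page: `LSGAN = 2` (only `GAN: 1` appears in the signatures); the pairing GAN → DiscriminatorMinMax / GeneratorBCE and LSGAN → DiscriminatorLSGAN / GeneratorLSGAN; and returning class names as strings so the sketch stays self-contained, whereas the documented functions return classes (Type[Executor]) to be instantiated. `pick_loss` and the table names are hypothetical.

```python
from enum import Enum


class AdversarialLossType(Enum):
    GAN = 1  # matches the default <AdversarialLossType.GAN: 1>
    LSGAN = 2  # assumed value


# Assumed pairing of loss types with ashpy class names.
DISCRIMINATOR_LOSSES = {
    AdversarialLossType.GAN: "DiscriminatorMinMax",
    AdversarialLossType.LSGAN: "DiscriminatorLSGAN",
}
GENERATOR_LOSSES = {
    AdversarialLossType.GAN: "GeneratorBCE",
    AdversarialLossType.LSGAN: "GeneratorLSGAN",
}


def pick_loss(table: dict, adversarial_loss_type: AdversarialLossType) -> str:
    """Look up the loss for a type, failing loudly on unsupported values."""
    if adversarial_loss_type not in table:
        raise ValueError(f"Loss type not supported: {adversarial_loss_type}")
    return table[adversarial_loss_type]


d_loss = pick_loss(DISCRIMINATOR_LOSSES, AdversarialLossType.GAN)
g_loss = pick_loss(GENERATOR_LOSSES, AdversarialLossType.LSGAN)
```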