# Source code for ashpy.metrics.gan

"""GAN metrics."""
from __future__ import annotations

import operator
import os
import types
from typing import TYPE_CHECKING, Callable

import tensorflow as tf
import tensorflow_hub as hub
from ashpy.metrics import ClassifierMetric, Metric
from ashpy.modes import LogEvalMode

if TYPE_CHECKING:
    from ashpy.contexts import (  # pylint: disable=ungrouped-imports
        GANContext,
        GANEncoderContext,
    )

class DiscriminatorLoss(Metric):
    """Running mean of the Discriminator loss."""

    def __init__(
        self,
        name: str = "d_loss",
        model_selection_operator: Callable = None,
        logdir: str = os.path.join(os.getcwd(), "log"),
    ) -> None:
        """
        Initialize the Metric.

        Args:
            name (str): Name of the metric.
            model_selection_operator (:py:obj:`typing.Callable`): The operation that will
                be used when `model_selection` is triggered to compare the metrics,
                used by the `update_state`.
                Any :py:obj:`typing.Callable` behaving like an :py:mod:`operator` is accepted.

                .. note::
                    Model selection is done ONLY if an operator is specified here.

            logdir (str): Path to the log dir, defaults to a `log` folder in the current
                directory.

        """
        running_mean = tf.metrics.Mean(name=name, dtype=tf.float32)
        super().__init__(
            name=name,
            metric=running_mean,
            model_selection_operator=model_selection_operator,
            logdir=logdir,
        )

    def update_state(self, context: GANContext) -> None:
        """
        Accumulate the discriminator loss over every batch of ``context.dataset``.

        For each ``(real_xy, noise)`` element the generator produces fake samples
        (conditioned on the labels when the generator has two inputs), then
        ``context.discriminator_loss`` is evaluated and fed to the running mean.

        Args:
            context (:py:class:`ashpy.contexts.gan.GANContext`): An AshPy Context Object
                that carries all the information the Metric needs.

        """

        def make_step(loss_value):
            # The distribution strategy expects a zero-argument callable.
            def step():
                return self._metric.update_state(loss_value)

            return step

        for real_xy, noise in context.dataset:
            real_x, real_y = real_xy
            is_training = context.log_eval_mode == LogEvalMode.TRAIN
            if len(context.generator_model.inputs) == 2:
                # Conditional generator: feed (noise, label).
                generator_inputs = [noise, real_y]
            else:
                generator_inputs = noise
            fake = context.generator_model(generator_inputs, training=is_training)
            loss = context.discriminator_loss(
                context,
                fake=fake,
                real=real_x,
                condition=real_y,
                training=is_training,
            )
            self._distribute_strategy.experimental_run_v2(make_step(loss))
class GeneratorLoss(Metric):
    """Running mean of the Generator loss."""

    def __init__(
        self,
        name: str = "g_loss",
        model_selection_operator: Callable = None,
        logdir: str = os.path.join(os.getcwd(), "log"),
    ):
        """
        Initialize the Metric.

        Args:
            name (str): Name of the metric.
            model_selection_operator (:py:obj:`typing.Callable`): The operation that will
                be used when `model_selection` is triggered to compare the metrics,
                used by the `update_state`.
                Any :py:obj:`typing.Callable` behaving like an :py:mod:`operator` is accepted.

                .. note::
                    Model selection is done ONLY if an operator is specified here.

            logdir (str): Path to the log dir, defaults to a `log` folder in the current
                directory.

        """
        running_mean = tf.metrics.Mean(name=name, dtype=tf.float32)
        super().__init__(
            name=name,
            metric=running_mean,
            model_selection_operator=model_selection_operator,
            logdir=logdir,
        )

    def update_state(self, context: GANContext) -> None:
        """
        Accumulate the generator loss over every batch of ``context.dataset``.

        For each ``(real_xy, noise)`` element the generator produces fake samples
        (conditioned on the labels when the generator has two inputs), then
        ``context.generator_loss`` is evaluated and fed to the running mean.

        Args:
            context (:py:class:`ashpy.contexts.GANContext`): An AshPy Context Object
                that carries all the information the Metric needs.

        """

        def make_step(loss_value):
            # The distribution strategy expects a zero-argument callable.
            def step():
                return self._metric.update_state(loss_value)

            return step

        for real_xy, noise in context.dataset:
            real_x, real_y = real_xy
            is_training = context.log_eval_mode == LogEvalMode.TRAIN
            if len(context.generator_model.inputs) == 2:
                # Conditional generator: feed (noise, label).
                generator_inputs = [noise, real_y]
            else:
                generator_inputs = noise
            fake = context.generator_model(generator_inputs, training=is_training)
            loss = context.generator_loss(
                context,
                fake=fake,
                real=real_x,
                condition=real_y,
                training=is_training,
            )
            self._distribute_strategy.experimental_run_v2(make_step(loss))
class EncoderLoss(Metric):
    """Running mean of the Encoder loss."""

    def __init__(
        self,
        name: str = "e_loss",
        model_selection_operator: Callable = None,
        logdir: str = os.path.join(os.getcwd(), "log"),
    ) -> None:
        """
        Initialize the Metric.

        Args:
            name (str): Name of the metric.
            model_selection_operator (:py:obj:`typing.Callable`): The operation that will
                be used when `model_selection` is triggered to compare the metrics,
                used by the `update_state`.
                Any :py:obj:`typing.Callable` behaving like an :py:mod:`operator` is accepted.

                .. note::
                    Model selection is done ONLY if an operator is specified here.

            logdir (str): Path to the log dir, defaults to a `log` folder in the current
                directory.

        """
        running_mean = tf.metrics.Mean(name=name, dtype=tf.float32)
        super().__init__(
            name=name,
            metric=running_mean,
            model_selection_operator=model_selection_operator,
            logdir=logdir,
        )

    def update_state(self, context: GANEncoderContext) -> None:
        """
        Accumulate the encoder loss over every batch of ``context.dataset``.

        For each ``(real_xy, noise)`` element the generator produces fake samples
        (conditioned on the labels when the generator has two inputs), then
        ``context.encoder_loss`` is evaluated and fed to the running mean.

        Args:
            context (:py:class:`ashpy.contexts.gan.GANEncoderContext`): An AshPy Context
                Object that carries all the information the Metric needs.

        """

        def make_step(loss_value):
            # The distribution strategy expects a zero-argument callable.
            def step():
                return self._metric.update_state(loss_value)

            return step

        for real_xy, noise in context.dataset:
            real_x, real_y = real_xy
            is_training = context.log_eval_mode == LogEvalMode.TRAIN
            if len(context.generator_model.inputs) == 2:
                # Conditional generator: feed (noise, label).
                generator_inputs = [noise, real_y]
            else:
                generator_inputs = noise
            fake = context.generator_model(generator_inputs, training=is_training)
            loss = context.encoder_loss(
                context,
                fake=fake,
                real=real_x,
                condition=real_y,
                training=is_training,
            )
            self._distribute_strategy.experimental_run_v2(make_step(loss))
class InceptionScore(Metric):
    r"""
    Inception Score Metric.

    This class is an implementation of the Inception Score technique for evaluating a GAN.

    See Improved Techniques for Training GANs [1]_.

    .. [1] Improved Techniques for Training GANs https://arxiv.org/abs/1606.03498

    """

    def __init__(
        self,
        inception: tf.keras.Model,
        name: str = "inception_score",
        model_selection_operator: Callable = operator.gt,
        logdir=os.path.join(os.getcwd(), "log"),
    ):
        """
        Initialize the Metric.

        Args:
            inception (:py:class:`tf.keras.Model`): Keras Inception model. A
                :py:class:`tf.keras.layers.Softmax` is appended if its last layer is not
                already a softmax.
            name (str): Name of the metric.
            model_selection_operator (:py:obj:`typing.Callable`): The operation that will
                be used when `model_selection` is triggered to compare the metrics,
                used by the `update_state`.
                Any :py:obj:`typing.Callable` behaving like an :py:mod:`operator` is accepted.
                Defaults to :py:func:`operator.gt` (a higher Inception Score is better).

                .. note::
                    Model selection is done ONLY if an operator is specified here.

            logdir (str): Path to the log dir, defaults to a `log` folder in the current
                directory.

        """
        super().__init__(
            name=name,
            metric=tf.metrics.Mean(name),
            model_selection_operator=model_selection_operator,
            logdir=logdir,
        )

        self._inception_model = inception
        # The score needs class probabilities: add a softmax layer if not present.
        if "softmax" not in self._inception_model.layers[-1].name.lower():
            self._inception_model = tf.keras.Sequential(
                [self._inception_model, tf.keras.layers.Softmax()]
            )

    def update_state(self, context: GANContext) -> None:
        """
        Update the internal state of the metric, using the information from the context object.

        For each ``(real_xy, noise)`` element of ``context.dataset`` the generator
        produces a batch of images, which is rescaled, resized to 299x299, converted to
        RGB when single-channel, and scored with :py:meth:`inception_score`. The
        per-batch scores are averaged by the underlying Mean metric.

        Args:
            context (:py:class:`ashpy.contexts.GANContext`): An AshPy Context holding
                all the information the Metric needs.

        """
        updater = lambda value: lambda: self._metric.update_state(value)

        # Generate the images created with the AshPy Context's generator
        for real_xy, noise in context.dataset:
            _, real_y = real_xy

            g_inputs = noise
            if len(context.generator_model.inputs) == 2:
                g_inputs = [noise, real_y]

            fake = context.generator_model(
                g_inputs, training=context.log_eval_mode == LogEvalMode.TRAIN
            )

            # Rescale images between 0 and 1.
            # NOTE(review): assumes generator outputs lie in [-1, 1] — confirm.
            fake = (fake + 1.0) / 2.0

            # Resize images to 299x299
            fake = tf.image.resize(fake, (299, 299))

            try:
                fake = tf.image.grayscale_to_rgb(fake)
            except ValueError:
                # Images are already RGB
                pass

            # Calculate the inception score
            inception_score_per_batch = self.inception_score(fake)

            # Update the Mean metric created for this context
            self._distribute_strategy.experimental_run_v2(
                updater(inception_score_per_batch)
            )

    def inception_score(self, images: tf.Tensor) -> tf.Tensor:
        """
        Compute the Inception Score of a batch of images.

        The score is ``exp(mean_x KL(p(y|x) || p(y)))``, where ``p(y|x)`` are the
        Inception model's class probabilities and ``p(y)`` is their mean over the batch.

        Args:
            images (:py:class:`tf.Tensor`): Batch of generated images of 299x299 of size.

        Returns:
            :py:class:`tf.Tensor`: Scalar Inception Score for this batch.

        """
        tf.print("Computing inception score...")
        predictions: tf.Tensor = self._inception_model(images)
        marginal = tf.math.reduce_mean(predictions, axis=0, keepdims=True)
        # xlogy(x, y) == x * log(y), but is 0 where x == 0. This avoids 0 * log(0) = NaN
        # when a softmax probability underflows to exactly zero.
        kl_divergence = tf.math.xlogy(predictions, predictions) - tf.math.xlogy(
            predictions, marginal
        )
        kl_divergence = tf.math.reduce_mean(tf.math.reduce_sum(kl_divergence, axis=1))
        inception_score_per_batch = tf.math.exp(kl_divergence)
        return inception_score_per_batch

    @staticmethod
    def get_or_train_inception(
        dataset: tf.data.Dataset,
        name: str,
        num_classes: int,
        epochs: int,
        fine_tuning: bool = False,
        loss_fn: tf.keras.losses.Loss = None,
        optimizer: tf.keras.optimizers.Adam = None,
        logdir: str = os.path.join(os.getcwd(), "log"),
    ) -> tf.keras.Model:
        """
        Restore or train (and save) the Inception model.

        Args:
            dataset (:py:class:`tf.data.Dataset`): Dataset to re-train Inception Model on.
            name (str): Name of this new Inception Model, used for saving it.
            num_classes (int): Number of classes to use for classification.
            epochs (int): Epochs to train the Inception model for.
            fine_tuning (bool): Controls whether the model will be fine-tuned or used as is.
            loss_fn (:py:class:`tf.keras.losses.Loss`): Keras Loss for the model. When
                ``None`` (default), a fresh ``SparseCategoricalCrossentropy(from_logits=True)``
                is created on each call.
            optimizer (:py:class:`tf.keras.optimizers.Optimizer`): Keras optimizer for the
                model. When ``None`` (default), a fresh ``Adam(1e-5)`` is created on each call,
                so optimizer state is never shared between calls.
            logdir (str): Path to the log dir, defaults to a `log` folder in the current
                directory. The checkpoint is stored under ``logdir/inception/name``.

        Returns:
            :py:class:`tf.keras.Model`: The Inception Model.

        """
        # Build defaults per call: default-argument instances would be shared (and a
        # stateful optimizer reused) across every invocation.
        if loss_fn is None:
            loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        if optimizer is None:
            optimizer = tf.keras.optimizers.Adam(1e-5)

        # Show TF-Hub download progress only while building the model, and restore the
        # caller's environment afterwards even if the download fails.
        previous_progress = os.environ.get("TFHUB_DOWNLOAD_PROGRESS")
        os.environ["TFHUB_DOWNLOAD_PROGRESS"] = "1"
        try:
            model = tf.keras.Sequential(
                [
                    hub.KerasLayer(
                        # NOTE(review): the module handle was empty in this file; restored
                        # to the TF2 InceptionV3 feature-vector module — confirm version.
                        "https://tfhub.dev/google/tf2-preview/inception_v3/feature_vector/4",
                        output_shape=[2048],
                        trainable=fine_tuning,
                    ),
                    tf.keras.layers.Dense(512),
                    tf.keras.layers.LeakyReLU(alpha=0.05),
                    tf.keras.layers.Dense(num_classes),
                ]
            )
        finally:
            if previous_progress is None:
                os.environ.pop("TFHUB_DOWNLOAD_PROGRESS", None)
            else:
                os.environ["TFHUB_DOWNLOAD_PROGRESS"] = previous_progress

        step = tf.Variable(0, trainable=False, dtype=tf.int64)

        ckpt = tf.train.Checkpoint()
        ckpt.objects = []
        ckpt.objects.extend([model, step])
        manager = tf.train.CheckpointManager(
            ckpt, os.path.join(logdir, "inception", name), max_to_keep=1
        )

        if manager.latest_checkpoint:
            ckpt.restore(manager.latest_checkpoint)
            print(f"Restored checkpoint {manager.latest_checkpoint}.")
            return model

        print("Training the InceptionV3 model")

        # callback checkpoint (Keras-side, written to `logdir` itself)
        model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(logdir)
        model.compile(loss=loss_fn, optimizer=optimizer)
        model.fit(dataset, epochs=epochs, callbacks=[model_checkpoint_callback])

        # Save into the manager's directory so the next call takes the restore branch;
        # the Keras callback above writes elsewhere and is not seen by the manager.
        manager.save()
        return model
class EncodingAccuracy(ClassifierMetric):
    """
    Generator and Encoder accuracy performance.

    Measure the Generator and Encoder performance together, by classifying:
    `G(E(x)), y` using a pre-trained classifier (on the dataset of x).

    """

    def __init__(
        self,
        classifier: tf.keras.Model,
        name: str = "encoding_accuracy",
        model_selection_operator: Callable = None,
        logdir=os.path.join(os.getcwd(), "log"),
    ) -> None:
        """
        Measure the Generator and Encoder performance together.

        This is done by classifying: `G(E(x)), y` using a pre-trained classifier
        (on the dataset of x).

        Args:
            classifier (:py:class:`tf.keras.Model`): Keras Model to use as a Classifier to
                measure the accuracy. Generally assumed to be the Inception Model.
            name (str): Name of the metric.
            model_selection_operator (:py:obj:`typing.Callable`): The operation that will
                be used when `model_selection` is triggered to compare the metrics,
                used by the `update_state`.
                Any :py:obj:`typing.Callable` behaving like an :py:mod:`operator` is accepted.

                .. note::
                    Model selection is done ONLY if an operator is specified here.

            logdir (str): Path to the log dir, defaults to a `log` folder in the current
                directory.

        """
        super().__init__(
            metric=tf.metrics.Accuracy(name),
            model_selection_operator=model_selection_operator,
            logdir=logdir,
        )

        self._classifier = classifier

    def update_state(self, context: GANEncoderContext) -> None:
        """
        Update the internal state of the metric, using the information from the context object.

        Builds a lightweight inner context whose dataset yields ``(G(E(x)), y)`` pairs,
        and delegates to :py:class:`ClassifierMetric` to classify them with the
        pre-trained classifier (in TEST mode).

        Args:
            context (:py:class:`ashpy.contexts.GANEncoderContext`): An AshPy Context Object
                that carries all the information the Metric needs.

        """
        inner_context = types.SimpleNamespace()
        inner_context.classifier_model = self._classifier
        inner_context.log_eval_mode = LogEvalMode.TEST

        training = context.log_eval_mode == LogEvalMode.TRAIN

        # Map each (real_xy, noise) element to G(E(x)), y; the noise is unused.
        def _gen(real_xy, _):
            real_x, real_y = real_xy
            out = context.generator_model(
                context.encoder_model(real_x, training=training), training=training,
            )
            return out, real_y

        # Previously `_gen` was never applied and `dataset` was read before assignment
        # (UnboundLocalError); the reconstructed dataset is built here instead.
        inner_context.dataset = context.dataset.map(_gen)

        # Classify using the pre-trained classifier (self._classifier)
        # G(E(x)) and check the accuracy (with y)
        super().update_state(inner_context)