# Source code for ashpy.contexts.gan

# Copyright 2019 Zuru Tech HK Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""GANContext measures the specified metrics on the GAN."""

from __future__ import annotations

from typing import TYPE_CHECKING, List, Optional

import tensorflow as tf

from ashpy.contexts.context import Context
from ashpy.modes import LogEvalMode

if TYPE_CHECKING:
    from ashpy.losses.executor import Executor
    from ashpy.metrics import Metric


class GANContext(Context):
    """:py:class:`ashpy.contexts.gan.GANContext` measures the specified metrics on the GAN."""

    def __init__(
        self,
        dataset: Optional[tf.data.Dataset] = None,
        generator_model: Optional[tf.keras.Model] = None,
        discriminator_model: Optional[tf.keras.Model] = None,
        generator_loss: Optional[Executor] = None,
        discriminator_loss: Optional[Executor] = None,
        metrics: Optional[List[Metric]] = None,
        log_eval_mode: LogEvalMode = LogEvalMode.TRAIN,
        global_step: Optional[tf.Variable] = None,
        checkpoint: Optional[tf.train.Checkpoint] = None,
    ) -> None:
        """
        Initialize the Context.

        Args:
            dataset (:py:class:`tf.data.Dataset`): Dataset of tuples. [0] true dataset,
                [1] generator input dataset.
            generator_model (:py:class:`tf.keras.Model`): The generator.
            discriminator_model (:py:class:`tf.keras.Model`): The discriminator.
            generator_loss (:py:class:`ashpy.losses.executor.Executor`): The generator loss.
            discriminator_loss (:py:class:`ashpy.losses.executor.Executor`): The
                discriminator loss.
            metrics (:obj:`list` of [:py:class:`ashpy.metrics.metric.Metric`]): All the
                metrics to be used to evaluate the model.
            log_eval_mode (:py:class:`ashpy.modes.LogEvalMode`): Models' mode to use
                when evaluating and logging.
            global_step (:py:class:`tf.Variable`): `tf.Variable` that keeps track of the
                training steps. When ``None`` (the default), a new non-trainable
                ``tf.int64`` variable named ``global_step`` and initialized to 0 is
                created for this context.
            checkpoint (:py:class:`tf.train.Checkpoint`): checkpoint to use to keep
                track of models status.

        """
        if global_step is None:
            # Create the counter per instance. A tf.Variable used as a default
            # argument is built once, at import time, and is then shared by
            # every context constructed without an explicit global_step.
            global_step = tf.Variable(
                0, name="global_step", trainable=False, dtype=tf.int64
            )
        super().__init__(metrics, dataset, log_eval_mode, global_step, checkpoint)
        self._generator_model = generator_model
        self._discriminator_model = discriminator_model
        self._generator_loss = generator_loss
        self._discriminator_loss = discriminator_loss
        # Populated later through the setters below; None until then.
        self._fake_samples = None
        self._generator_inputs = None

    @property
    def generator_model(self) -> Optional[tf.keras.Model]:
        """
        Retrieve the generator model.

        Returns:
            :py:class:`tf.keras.Model`.

        """
        return self._generator_model

    @property
    def discriminator_model(self) -> Optional[tf.keras.Model]:
        """
        Retrieve the discriminator model.

        Returns:
            :py:class:`tf.keras.Model`.

        """
        return self._discriminator_model

    @property
    def generator_loss(self) -> Optional[Executor]:
        """Retrieve the generator loss."""
        return self._generator_loss

    @property
    def discriminator_loss(self) -> Optional[Executor]:
        """Retrieve the discriminator loss."""
        return self._discriminator_loss

    @property
    def fake_samples(self) -> Optional[tf.Tensor]:
        """Retrieve the fake samples, i.e. output of the generator."""
        return self._fake_samples

    @fake_samples.setter
    def fake_samples(self, _fake_samples: Optional[tf.Tensor]):
        """Set the fake samples, i.e. output of the generator."""
        self._fake_samples = _fake_samples

    @property
    def generator_inputs(self) -> Optional[tf.Tensor]:
        """Retrieve the generator inputs."""
        return self._generator_inputs

    @generator_inputs.setter
    def generator_inputs(self, _generator_inputs: Optional[tf.Tensor]):
        """Set the generator inputs."""
        self._generator_inputs = _generator_inputs
class GANEncoderContext(GANContext):
    """:py:class:`ashpy.contexts.gan.GANEncoderContext` measures the specified metrics on the GAN."""

    def __init__(
        self,
        dataset: Optional[tf.data.Dataset] = None,
        generator_model: Optional[tf.keras.Model] = None,
        discriminator_model: Optional[tf.keras.Model] = None,
        encoder_model: Optional[tf.keras.Model] = None,
        generator_loss: Optional[Executor] = None,
        discriminator_loss: Optional[Executor] = None,
        encoder_loss: Optional[Executor] = None,
        metrics: Optional[List[Metric]] = None,
        log_eval_mode: LogEvalMode = LogEvalMode.TRAIN,
        global_step: Optional[tf.Variable] = None,
        checkpoint: Optional[tf.train.Checkpoint] = None,
    ) -> None:
        r"""
        Initialize the Context.

        Args:
            dataset (:py:class:`tf.data.Dataset`): Dataset of tuples. [0] true dataset,
                [1] generator input dataset.
            generator_model (:py:class:`tf.keras.Model`): The generator.
            discriminator_model (:py:class:`tf.keras.Model`): The discriminator.
            encoder_model (:py:class:`tf.keras.Model`): The encoder.
            generator_loss (:py:class:`ashpy.losses.executor.Executor`): The generator loss.
            discriminator_loss (:py:class:`ashpy.losses.executor.Executor`): The
                discriminator loss.
            encoder_loss (:py:class:`ashpy.losses.executor.Executor`): The encoder loss.
            metrics (:obj:`list` of [:py:class:`ashpy.metrics.metric.Metric`]): All the
                metrics to be used to evaluate the model.
            log_eval_mode (:py:class:`ashpy.modes.LogEvalMode`): Models' mode to use
                when evaluating and logging.
            global_step (:py:class:`tf.Variable`): `tf.Variable` that keeps track of the
                training steps. When ``None`` (the default), a new non-trainable
                ``tf.int64`` variable named ``global_step`` and initialized to 0 is
                created for this context.
            checkpoint (:py:class:`tf.train.Checkpoint`): checkpoint to use to keep
                track of models status.

        """
        if global_step is None:
            # Create the counter per instance instead of relying on a
            # tf.Variable default argument, which would be built once at import
            # time and shared by every context created without a global_step.
            global_step = tf.Variable(
                0, name="global_step", trainable=False, dtype=tf.int64
            )
        super().__init__(
            dataset=dataset,
            generator_model=generator_model,
            discriminator_model=discriminator_model,
            generator_loss=generator_loss,
            discriminator_loss=discriminator_loss,
            metrics=metrics,
            log_eval_mode=log_eval_mode,
            global_step=global_step,
            checkpoint=checkpoint,
        )
        self._encoder_model = encoder_model
        self._encoder_loss = encoder_loss
        # Populated later through the setters below; None until then.
        self._generator_of_encoder = None
        self._encoder_inputs = None

    @property
    def encoder_model(self) -> Optional[tf.keras.Model]:
        """
        Retrieve the encoder model.

        Returns:
            :py:class:`tf.keras.Model`.

        """
        return self._encoder_model

    @property
    def encoder_loss(self) -> Optional[Executor]:
        """Retrieve the encoder loss."""
        return self._encoder_loss

    @property
    def generator_of_encoder(self) -> Optional[tf.Tensor]:
        """Retrieve the images generated from the encoder output."""
        return self._generator_of_encoder

    @generator_of_encoder.setter
    def generator_of_encoder(self, _generator_of_encoder: Optional[tf.Tensor]):
        """Set the images generated from the encoder output."""
        self._generator_of_encoder = _generator_of_encoder

    @property
    def encoder_inputs(self) -> Optional[tf.Tensor]:
        """Retrieve the inputs of the encoder."""
        return self._encoder_inputs

    @encoder_inputs.setter
    def encoder_inputs(self, _encoder_inputs: Optional[tf.Tensor]):
        """Setter for the inputs of the encoder."""
        self._encoder_inputs = _encoder_inputs