# Copyright 2019 Zuru Tech HK Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
GAN callbacks.

LogImageGANCallback: Log the output of the generator evaluated on its inputs, G(z).
LogImageGANEncoderCallback: Log the output of the generator evaluated on the encoder output, G(E(x)).
"""

import tensorflow as tf

from ashpy import LogEvalMode
from ashpy.callbacks.counter_callback import CounterCallback
from ashpy.callbacks.events import Event
from ashpy.contexts import GANContext, GANEncoderContext
from ashpy.utils.utils import log


class LogImageGANCallback(CounterCallback):
    """
    Callback used for logging GAN-generated images to TensorBoard.

    Logs the Generator output, G(z).

    Examples:
        .. testcode::

            import operator
            import os
            import shutil

            import tensorflow as tf

            from ashpy import callbacks, losses, models, trainers

            generator = models.gans.ConvGenerator(
                layer_spec_input_res=(7, 7),
                layer_spec_target_res=(28, 28),
                kernel_size=(5, 5),
                initial_filters=32,
                filters_cap=16,
                channels=1,
            )

            discriminator = models.gans.ConvDiscriminator(
                layer_spec_input_res=(28, 28),
                layer_spec_target_res=(7, 7),
                kernel_size=(5, 5),
                initial_filters=16,
                filters_cap=32,
                output_shape=1,
            )

            # Losses
            generator_bce = losses.gan.GeneratorBCE()
            minmax = losses.gan.DiscriminatorMinMax()

            # Real data
            batch_size = 2
            mnist_x, mnist_y = tf.zeros((100, 28, 28)), tf.zeros((100,))

            # Trainer
            epochs = 2
            logdir = "testlog/callbacks"
            callbacks = [callbacks.LogImageGANCallback()]
            trainer = trainers.gan.AdversarialTrainer(
                generator=generator,
                discriminator=discriminator,
                generator_optimizer=tf.optimizers.Adam(1e-4),
                discriminator_optimizer=tf.optimizers.Adam(1e-4),
                generator_loss=generator_bce,
                discriminator_loss=minmax,
                epochs=epochs,
                callbacks=callbacks,
                logdir=logdir,
            )

            # Take only 2 samples to speed up tests.
            real_data = (
                tf.data.Dataset.from_tensor_slices(
                    (tf.expand_dims(mnist_x, -1), tf.expand_dims(mnist_y, -1))
                )
                .take(batch_size)
                .batch(batch_size)
                .prefetch(1)
            )

            # Add noise to the same dataset, just by mapping.
            # The return type of the dataset must be: tuple(tuple(a, b), noise)
            dataset = real_data.map(
                lambda x, y: ((x, y), tf.random.normal(shape=(batch_size, 100)))
            )

            trainer(dataset)

            shutil.rmtree(logdir)
            assert not os.path.exists(logdir)

            trainer._global_step.assign_add(500)

        .. testoutput::

            Initializing checkpoint.
            Starting epoch 1.
            [1] Saved checkpoint: testlog/callbacks/ckpts/ckpt-1
            Epoch 1 completed.
            Starting epoch 2.
            [2] Saved checkpoint: testlog/callbacks/ckpts/ckpt-2
            Epoch 2 completed.
            Training finished after 2 epochs.

    """

    def __init__(
        self,
        event: Event = Event.ON_EPOCH_END,
        name="LogImageGANCallback",
        event_freq: int = 1,
    ):
        """
        Initialize the LogImageGANCallback.

        Args:
            event (:py:class:`ashpy.callbacks.events.Event`): event to consider.
            event_freq (int): frequency of logging.
            name (str): name of the callback.

        """
        super().__init__(
            event=event, fn=self._log_fn, name=name, event_freq=event_freq
        )

    def _log_fn(self, context: GANContext) -> None:
        """
        Log the output of the generator to TensorBoard.

        Args:
            context (:py:class:`ashpy.contexts.gan.GANContext`): current context.

        """
        if context.log_eval_mode == LogEvalMode.TEST:
            out = context.generator_model(context.generator_inputs, training=False)
        elif context.log_eval_mode == LogEvalMode.TRAIN:
            out = context.fake_samples
        else:
            raise ValueError("Invalid LogEvalMode")

        # TensorBoard 2.0 does not support float images in [-1, 1], only in [0, 1].
        if out.dtype == tf.float32:
            # Assumption: images are in [-1, 1] (there is no cheap way to verify this).
            out = (out + 1.0) / 2

        log("generator", out, context.global_step)
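

# Usage sketch (not part of the ashpy API): a minimal, hypothetical subclass
# showing how ``_log_fn`` can be overridden to customize what gets logged.
# ``LogSigmoidImageGANCallback`` and the "generator_raw" summary tag are
# illustrative names; the sketch only relies on the context attributes already
# used above (``log_eval_mode``, ``generator_model``, ``generator_inputs``,
# ``fake_samples``, ``global_step``) and on the ``log`` helper imported at the
# top of this module. It assumes the generator already outputs images in
# [0, 1] (e.g. via a final sigmoid), so the [-1, 1] rescaling is skipped.
class LogSigmoidImageGANCallback(LogImageGANCallback):
    """Log G(z) without rescaling, assuming generator outputs are in [0, 1]."""

    def _log_fn(self, context: GANContext) -> None:
        """Log the raw generator output to TensorBoard."""
        if context.log_eval_mode == LogEvalMode.TEST:
            out = context.generator_model(context.generator_inputs, training=False)
        elif context.log_eval_mode == LogEvalMode.TRAIN:
            out = context.fake_samples
        else:
            raise ValueError("Invalid LogEvalMode")

        # No [-1, 1] -> [0, 1] rescaling: images are assumed to already be in [0, 1].
        log("generator_raw", out, context.global_step)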


class LogImageGANEncoderCallback(LogImageGANCallback):
    """
    Callback used for logging GAN-generated images to TensorBoard.

    Logs the Generator output evaluated on the encoder output, G(E(x)).

    Examples:
        .. testcode::

            import operator
            import shutil

            import tensorflow as tf

            from ashpy import callbacks, losses, trainers

            def real_gen():
                label = 0
                for _ in tf.range(100):
                    yield ((10.0,), (label,))

            latent_dim = 100

            generator = tf.keras.Sequential([tf.keras.layers.Dense(1)])

            left_input = tf.keras.layers.Input(shape=(1,))
            left = tf.keras.layers.Dense(10, activation=tf.nn.elu)(left_input)

            right_input = tf.keras.layers.Input(shape=(latent_dim,))
            right = tf.keras.layers.Dense(10, activation=tf.nn.elu)(right_input)

            net = tf.keras.layers.Concatenate()([left, right])
            out = tf.keras.layers.Dense(1)(net)

            discriminator = tf.keras.Model(
                inputs=[left_input, right_input], outputs=[out]
            )

            encoder = tf.keras.Sequential([tf.keras.layers.Dense(latent_dim)])

            # Losses
            generator_bce = losses.gan.GeneratorBCE()
            encoder_bce = losses.gan.EncoderBCE()
            minmax = losses.gan.DiscriminatorMinMax()

            epochs = 2
            callbacks = [callbacks.LogImageGANEncoderCallback()]
            logdir = "testlog/callbacks_encoder"

            trainer = trainers.gan.EncoderTrainer(
                generator=generator,
                discriminator=discriminator,
                encoder=encoder,
                discriminator_optimizer=tf.optimizers.Adam(1e-4),
                generator_optimizer=tf.optimizers.Adam(1e-5),
                encoder_optimizer=tf.optimizers.Adam(1e-6),
                generator_loss=generator_bce,
                discriminator_loss=minmax,
                encoder_loss=encoder_bce,
                epochs=epochs,
                callbacks=callbacks,
                logdir=logdir,
            )

            batch_size = 10
            discriminator_input = tf.data.Dataset.from_generator(
                real_gen, (tf.float32, tf.int64), ((1), (1))
            ).batch(batch_size)

            dataset = discriminator_input.map(
                lambda x, y: ((x, y), tf.random.normal(shape=(batch_size, latent_dim)))
            )

            trainer(dataset)

            shutil.rmtree(logdir)

        .. testoutput::

            Initializing checkpoint.
            Starting epoch 1.
            [10] Saved checkpoint: testlog/callbacks_encoder/ckpts/ckpt-1
            Epoch 1 completed.
            Starting epoch 2.
            [20] Saved checkpoint: testlog/callbacks_encoder/ckpts/ckpt-2
            Epoch 2 completed.
            Training finished after 2 epochs.

    """

    def __init__(
        self,
        event: Event = Event.ON_EPOCH_END,
        name="LogImageGANEncoderCallback",
        event_freq: int = 1,
    ):
        """
        Initialize the LogImageGANEncoderCallback.

        Args:
            event (:py:class:`ashpy.callbacks.events.Event`): event to consider.
            event_freq (int): frequency of logging.
            name (str): name of the callback.

        """
        super().__init__(event=event, name=name, event_freq=event_freq)

    def _log_fn(self, context: GANEncoderContext) -> None:
        """
        Log the output of the generator, evaluated on the encoder output, to TensorBoard.

        Logs G(E(x)).

        Args:
            context (:py:class:`ashpy.contexts.gan.GANEncoderContext`): current context.

        """
        if context.log_eval_mode == LogEvalMode.TEST:
            generator_of_encoder = context.generator_model(
                context.encoder_model(context.encoder_inputs, training=False),
                training=False,
            )
        elif context.log_eval_mode == LogEvalMode.TRAIN:
            generator_of_encoder = context.generator_of_encoder
        else:
            raise ValueError("Invalid LogEvalMode")

        # TensorBoard 2.0 does not support float images in [-1, 1], only in [0, 1].
        if generator_of_encoder.dtype == tf.float32:
            # Assumption: images are in [-1, 1] (there is no cheap way to verify this).
            generator_of_encoder = (generator_of_encoder + 1.0) / 2

        log("generator_of_encoder", generator_of_encoder, context.global_step)
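

# Usage sketch (not part of the ashpy API): a hypothetical subclass showing
# how the encoder context can also be used to log the real inputs next to
# their reconstructions G(E(x)). ``LogReconstructionPairCallback`` and the
# "encoder_inputs" summary tag are illustrative names; the sketch assumes
# ``context.encoder_inputs`` is a single image tensor, which may not hold for
# every pipeline.
class LogReconstructionPairCallback(LogImageGANEncoderCallback):
    """Log both the encoder inputs x and their reconstructions G(E(x))."""

    def _log_fn(self, context: GANEncoderContext) -> None:
        """Log G(E(x)) as the parent class does, then log x as well."""
        super()._log_fn(context)

        inputs = context.encoder_inputs
        if inputs.dtype == tf.float32:
            # Same assumption as above: images are in [-1, 1].
            inputs = (inputs + 1.0) / 2

        log("encoder_inputs", inputs, context.global_step)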