# Copyright 2019 Zuru Tech HK Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Collection of GANs trainers."""
from pathlib import Path
from typing import List, Optional, Tuple, Union
import ashpy.restorers
import tensorflow as tf
from ashpy.callbacks import Callback
from ashpy.contexts.gan import GANContext, GANEncoderContext
from ashpy.datasets import wrap
from ashpy.losses.executor import Executor
from ashpy.metrics import Metric
from ashpy.metrics.gan import DiscriminatorLoss, EncoderLoss, GeneratorLoss
from ashpy.modes import LogEvalMode
from ashpy.trainers.trainer import Trainer
__all__ = ["AdversarialTrainer", "EncoderTrainer"]


class AdversarialTrainer(Trainer):
r"""
    Primitive Trainer for GANs subclassed from :class:`ashpy.trainers.Trainer`.

    Examples:
        .. testcode::

            import operator
            import shutil

            import tensorflow as tf

            # The AshPy docs provide these names via their doctest setup; they
            # are imported explicitly here so the example is self-contained.
            from ashpy import losses, metrics, models, trainers
generator = models.gans.ConvGenerator(
layer_spec_input_res=(7, 7),
layer_spec_target_res=(28, 28),
kernel_size=(5, 5),
initial_filters=32,
filters_cap=16,
channels=1,
)
discriminator = models.gans.ConvDiscriminator(
layer_spec_input_res=(28, 28),
layer_spec_target_res=(7, 7),
kernel_size=(5, 5),
initial_filters=16,
filters_cap=32,
output_shape=1,
)
# Losses
generator_bce = losses.gan.GeneratorBCE()
minmax = losses.gan.DiscriminatorMinMax()
# Real data
batch_size = 2
            mnist_x, mnist_y = tf.zeros((100, 28, 28)), tf.zeros((100,))
# Trainer
epochs = 2
logdir = "testlog/adversarial"
metrics = [
metrics.gan.InceptionScore(
# Fake inception model
models.gans.ConvDiscriminator(
layer_spec_input_res=(299, 299),
layer_spec_target_res=(7, 7),
kernel_size=(5, 5),
initial_filters=16,
filters_cap=32,
output_shape=10,
),
model_selection_operator=operator.gt,
)
]
trainer = trainers.gan.AdversarialTrainer(
generator=generator,
discriminator=discriminator,
generator_optimizer=tf.optimizers.Adam(1e-4),
discriminator_optimizer=tf.optimizers.Adam(1e-4),
generator_loss=generator_bce,
discriminator_loss=minmax,
epochs=epochs,
metrics=metrics,
logdir=logdir,
)
# take only 2 samples to speed up tests
            real_data = (
                tf.data.Dataset.from_tensor_slices(
                    (tf.expand_dims(mnist_x, -1), tf.expand_dims(mnist_y, -1))
                )
                .take(batch_size)
                .batch(batch_size)
                .prefetch(1)
            )
            # Add noise to the same dataset, just by mapping.
            # The return type of the dataset must be: ((a, b), noise)
dataset = real_data.map(
lambda x, y: ((x, y), tf.random.normal(shape=(batch_size, 100)))
)
trainer(dataset)
shutil.rmtree(logdir)

        .. testoutput::

Initializing checkpoint.
Starting epoch 1.
[1] Saved checkpoint: testlog/adversarial/ckpts/ckpt-1
Epoch 1 completed.
Starting epoch 2.
[2] Saved checkpoint: testlog/adversarial/ckpts/ckpt-2
Epoch 2 completed.
Training finished after 2 epochs.
"""
ckpt_id_generator: str = "generator"
ckpt_id_discriminator: str = "discriminator"
ckpt_id_optimizer_generator: str = "optimizer_generator"
ckpt_id_optimizer_discriminator: str = "optimizer_discriminator"
    def __init__(
self,
generator: tf.keras.Model,
discriminator: tf.keras.Model,
generator_optimizer: tf.optimizers.Optimizer,
discriminator_optimizer: tf.optimizers.Optimizer,
generator_loss: Executor,
discriminator_loss: Executor,
epochs: int,
metrics: Optional[Union[Tuple[Metric], List[Metric]]] = None,
callbacks: Optional[List[Callback]] = None,
logdir: Union[Path, str] = Path().cwd() / "log",
log_eval_mode: LogEvalMode = LogEvalMode.TEST,
global_step: Optional[tf.Variable] = None,
):
r"""
        Instantiate a :py:class:`AdversarialTrainer`.

        Args:
generator (:py:class:`tf.keras.Model`): A :py:class:`tf.keras.Model`
describing the Generator part of a GAN.
discriminator (:py:class:`tf.keras.Model`): A :py:class:`tf.keras.Model`
describing the Discriminator part of a GAN.
generator_optimizer (:py:class:`tf.optimizers.Optimizer`): A :py:mod:`tf.optimizers`
to use for the Generator.
discriminator_optimizer (:py:class:`tf.optimizers.Optimizer`): A
:py:mod:`tf.optimizers` to use for the Discriminator.
            generator_loss (:py:class:`ashpy.losses.executor.Executor`): An AshPy Executor
                to compute the loss of the Generator.
            discriminator_loss (:py:class:`ashpy.losses.executor.Executor`): An AshPy
                Executor to compute the loss of the Discriminator.
epochs (int): number of training epochs.
            metrics (List): list of :py:class:`ashpy.metrics.Metric` to measure on
                training and validation data.
            callbacks (List): list of :py:class:`ashpy.callbacks.Callback` to call
                on training events.
logdir: checkpoint and log directory.
log_eval_mode: models' mode to use when evaluating and logging.
global_step (Optional[:py:class:`tf.Variable`]): tf.Variable that
keeps track of the training steps.

        Returns:
:py:obj:`None`
"""
super().__init__(
epochs=epochs,
logdir=logdir,
log_eval_mode=log_eval_mode,
global_step=global_step,
callbacks=callbacks,
example_dim=(2, 1),
)
self._generator = generator
self._discriminator = discriminator
self._generator_loss = generator_loss
self._generator_loss.reduction = tf.losses.Reduction.NONE
self._discriminator_loss = discriminator_loss
self._discriminator_loss.reduction = tf.losses.Reduction.NONE
losses_metrics = (
DiscriminatorLoss(name="ashpy/d_loss", logdir=logdir),
GeneratorLoss(name="ashpy/g_loss", logdir=logdir),
)
if metrics:
metrics = (*metrics, *losses_metrics)
else:
metrics = losses_metrics
super()._update_metrics(metrics)
super()._validate_metrics()
self._generator_optimizer = generator_optimizer
self._discriminator_optimizer = discriminator_optimizer
ckpt_dict = {
self.ckpt_id_optimizer_generator: self._generator_optimizer,
self.ckpt_id_optimizer_discriminator: self._discriminator_optimizer,
self.ckpt_id_generator: self._generator,
self.ckpt_id_discriminator: self._discriminator,
}
self._update_checkpoint(ckpt_dict)
# pylint: disable=unidiomatic-typecheck
if type(self) == AdversarialTrainer:
self._restore_or_init()
self._context = GANContext(
generator_model=self._generator,
discriminator_model=self._discriminator,
generator_loss=self._generator_loss,
discriminator_loss=self._discriminator_loss,
log_eval_mode=self._log_eval_mode,
global_step=self._global_step,
checkpoint=self._checkpoint,
metrics=self._metrics,
)
    def _build_and_restore_models(self, dataset: tf.data.Dataset):
restorer = ashpy.restorers.AdversarialRestorer(self._logdir)
(x, _), z = next(iter(dataset.take(1)))
        # Invoke the models on sample inputs to build their variables
        # before restoring them from the checkpoint.
self._generator(z)
self._discriminator(x)
restorer.restore_generator(self._generator)
restorer.restore_discriminator(self._discriminator)
self._deferred_restoration = False
    def train_step(self, real_xy, g_inputs):
"""
        Train step for the AdversarialTrainer.

        Args:
            real_xy: input batch as extracted from the input dataset,
                a (features, label) pair.
            g_inputs: batch of generator inputs as generated from the input dataset.

        Returns:
            d_loss, g_loss, fake: discriminator and generator loss values, and
                the generator output (the fake samples).
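
        If the generator accepts two inputs, the trainer feeds it
        ``[g_inputs, real_y]`` (the conditional-GAN case handled below). A
        minimal sketch of such a two-input generator, with hypothetical layer
        sizes (an assumption, not part of AshPy):

        .. code-block:: python

            import tensorflow as tf

            noise = tf.keras.layers.Input(shape=(100,))
            label = tf.keras.layers.Input(shape=(1,))
            hidden = tf.keras.layers.Concatenate()(
                [tf.keras.layers.Dense(64)(noise), tf.keras.layers.Dense(64)(label)]
            )
            # len(conditional_generator.inputs) == 2 triggers the conditional path.
            conditional_generator = tf.keras.Model(
                inputs=[noise, label], outputs=tf.keras.layers.Dense(28 * 28)(hidden)
            )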
"""
real_x, real_y = real_xy
if len(self._generator.inputs) == 2:
g_inputs = [g_inputs, real_y]
with tf.GradientTape(persistent=True) as tape:
fake = self._generator(g_inputs, training=True)
d_loss = self._discriminator_loss(
self._context, fake=fake, real=real_x, condition=real_y, training=True
)
g_loss = self._generator_loss(
self._context, fake=fake, real=real_x, condition=real_y, training=True
)
# check that we have some trainable_variables
assert self._generator.trainable_variables
assert self._discriminator.trainable_variables
# calculate the gradient
d_gradients = tape.gradient(d_loss, self._discriminator.trainable_variables)
g_gradients = tape.gradient(g_loss, self._generator.trainable_variables)
# delete the tape since it's persistent
del tape
# apply the gradient
self._discriminator_optimizer.apply_gradients(
zip(d_gradients, self._discriminator.trainable_variables)
)
self._generator_optimizer.apply_gradients(
zip(g_gradients, self._generator.trainable_variables)
)
return d_loss, g_loss, fake
    @tf.function
def _train_step(self, example):
"""Training step with the distribution strategy."""
ret = self._distribute_strategy.experimental_run_v2(
self.train_step, args=(example[0], example[1])
)
per_replica_losses = ret[:-1]
fake = ret[-1]
return (
self._reduce(per_replica_losses[0], tf.distribute.ReduceOp.SUM),
self._reduce(per_replica_losses[1], tf.distribute.ReduceOp.SUM),
fake,
)
    def call(
self,
dataset: tf.data.Dataset,
log_freq: int = 10,
measure_performance_freq: int = 10,
):
"""
        Perform the adversarial training.

        Args:
            dataset (:py:class:`tf.data.Dataset`): The adversarial training dataset.
            log_freq (int): Specifies how many steps to run before logging the losses,
                e.g. ``log_freq=10`` logs every 10 steps of training.
                Pass ``log_freq<=0`` if you don't want to log.
            measure_performance_freq (int): Specifies how many steps to run before
                measuring the performance, e.g. ``measure_performance_freq=10``
                measures performance every 10 steps of training.
                Pass ``measure_performance_freq<=0`` if you don't want to measure
                performance.
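
        The dataset is expected to yield ``((features, labels), generator_inputs)``
        tuples. A minimal sketch of such a pipeline (shapes are illustrative
        assumptions):

        .. code-block:: python

            import tensorflow as tf

            batch_size = 2
            real_data = tf.data.Dataset.from_tensor_slices(
                (tf.zeros((100, 28, 28, 1)), tf.zeros((100, 1)))
            ).batch(batch_size)
            # Pair every (features, labels) batch with a noise batch.
            dataset = real_data.map(
                lambda x, y: ((x, y), tf.random.normal(shape=(batch_size, 100)))
            )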
"""
if self._deferred_restoration:
self._build_and_restore_models(dataset=dataset)
current_epoch = self._current_epoch()
self._update_global_batch_size(
dataset, [self._discriminator_loss, self._generator_loss]
)
dataset = wrap(
dataset.unbatch().batch(self._global_batch_size, drop_remainder=True)
)
samples = next(iter(dataset.take(1)))
self._context.generator_inputs = samples[1]
with self._train_summary_writer.as_default():
# notify on train start
self._on_train_start()
for _ in tf.range(current_epoch, self._epochs):
distribute_dataset = self._distribute_strategy.experimental_distribute_dataset(
dataset
)
# notify on epoch start
self._on_epoch_start()
for example in distribute_dataset:
# notify on batch start
self._on_batch_start()
# perform training step
d_loss, g_loss, fake = self._train_step(example)
# store fake samples in the context
self._context.fake_samples = fake
self._global_step.assign_add(1)
# print statistics
if log_freq > 0 and tf.equal(
tf.math.mod(self._global_step, log_freq), 0
):
tf.print(
f"[{self._global_step.numpy()}] g_loss: {g_loss} - d_loss: {d_loss}"
)
# measure performance if needed
self._measure_performance_if_needed(
example, measure_performance_freq
)
# notify on batch end
self._on_batch_end()
# notify on epoch end
self._on_epoch_end()
# final callback
self._on_train_end()


class EncoderTrainer(AdversarialTrainer):
r"""
    Primitive Trainer for GANs using an Encoder sub-network.

    The implementation is designed to be used with the BCE losses. To use
    another loss function, consider subclassing the trainer and overriding
    the ``train_step`` method along the lines of the sketch below.
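
    A minimal subclassing sketch (hypothetical, not shipped with AshPy):

    .. code-block:: python

        class CustomEncoderTrainer(EncoderTrainer):
            def train_step(self, real_xy, g_inputs):
                # Swap in custom loss computation / optimization logic here;
                # this sketch simply delegates to the parent implementation.
                return super().train_step(real_xy, g_inputs)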
Examples:
        .. testcode::

            import operator
            import shutil
            from pathlib import Path

            import tensorflow as tf

            # The AshPy docs provide these names via their doctest setup; they
            # are imported explicitly here so the example is self-contained.
            from ashpy import losses, metrics, trainers
def real_gen():
label = 0
for _ in tf.range(100):
yield ((10.0,), (label,))
latent_dim = 100
generator = tf.keras.Sequential([tf.keras.layers.Dense(1)])
left_input = tf.keras.layers.Input(shape=(1,))
left = tf.keras.layers.Dense(10, activation=tf.nn.elu)(left_input)
right_input = tf.keras.layers.Input(shape=(latent_dim,))
right = tf.keras.layers.Dense(10, activation=tf.nn.elu)(right_input)
net = tf.keras.layers.Concatenate()([left, right])
out = tf.keras.layers.Dense(1)(net)
discriminator = tf.keras.Model(inputs=[left_input, right_input], outputs=[out])
encoder = tf.keras.Sequential([tf.keras.layers.Dense(latent_dim)])
# Losses
generator_bce = losses.gan.GeneratorBCE()
encoder_bce = losses.gan.EncoderBCE()
minmax = losses.gan.DiscriminatorMinMax()
epochs = 2
# Fake pre-trained classifier
num_classes = 1
classifier = tf.keras.Sequential(
[tf.keras.layers.Dense(10), tf.keras.layers.Dense(num_classes)]
)
logdir = Path("testlog") / "adversarial_encoder"
if logdir.exists():
shutil.rmtree(logdir)
metrics = [metrics.gan.EncodingAccuracy(classifier)]
trainer = trainers.gan.EncoderTrainer(
generator=generator,
discriminator=discriminator,
encoder=encoder,
discriminator_optimizer=tf.optimizers.Adam(1e-4),
generator_optimizer=tf.optimizers.Adam(1e-5),
encoder_optimizer=tf.optimizers.Adam(1e-6),
generator_loss=generator_bce,
discriminator_loss=minmax,
encoder_loss=encoder_bce,
epochs=epochs,
metrics=metrics,
logdir=logdir,
)
batch_size = 10
            discriminator_input = tf.data.Dataset.from_generator(
                real_gen, (tf.float32, tf.int64), ((1,), (1,))
            ).batch(batch_size)
dataset = discriminator_input.map(
lambda x, y: ((x, y), tf.random.normal(shape=(batch_size, latent_dim)))
)
trainer(dataset)
shutil.rmtree(logdir)

        .. testoutput::

Initializing checkpoint.
Starting epoch 1.
[10] Saved checkpoint: testlog/adversarial_encoder/ckpts/ckpt-1
Epoch 1 completed.
Starting epoch 2.
[20] Saved checkpoint: testlog/adversarial_encoder/ckpts/ckpt-2
Epoch 2 completed.
Training finished after 2 epochs.
"""
ckpt_id_encoder: str = "encoder"
ckpt_id_optimizer_encoder: str = "optimizer_encoder"
    def __init__(
self,
generator: tf.keras.Model,
discriminator: tf.keras.Model,
encoder: tf.keras.Model,
generator_optimizer: tf.optimizers.Optimizer,
discriminator_optimizer: tf.optimizers.Optimizer,
encoder_optimizer: tf.optimizers.Optimizer,
generator_loss: Executor,
discriminator_loss: Executor,
encoder_loss: Executor,
epochs: int,
metrics: Optional[List[Metric]] = None,
callbacks: Optional[List[Callback]] = None,
logdir: Union[Path, str] = Path().cwd() / "log",
log_eval_mode: LogEvalMode = LogEvalMode.TEST,
global_step: Optional[tf.Variable] = None,
):
r"""
        Instantiate a :py:class:`EncoderTrainer`.

        Args:
generator (:py:class:`tf.keras.Model`): A :py:class:`tf.keras.Model`
describing the Generator part of a GAN.
discriminator (:py:class:`tf.keras.Model`): A :py:class:`tf.keras.Model`
describing the Discriminator part of a GAN.
encoder (:py:class:`tf.keras.Model`): A :py:class:`tf.keras.Model` describing
the Encoder part of a GAN.
generator_optimizer (:py:class:`tf.optimizers.Optimizer`): A :py:mod:`tf.optimizers`
to use for the Generator.
discriminator_optimizer (:py:class:`tf.optimizers.Optimizer`): A :py:mod:`tf.optimizers`
to use for the Discriminator.
encoder_optimizer (:py:class:`tf.optimizers.Optimizer`): A :py:mod:`tf.optimizers`
to use for the Encoder.
            generator_loss (:py:class:`ashpy.losses.executor.Executor`): An AshPy Executor
                to compute the loss of the Generator.
            discriminator_loss (:py:class:`ashpy.losses.executor.Executor`): An AshPy
                Executor to compute the loss of the Discriminator.
            encoder_loss (:py:class:`ashpy.losses.executor.Executor`): An AshPy Executor
                to compute the loss of the Encoder.
epochs (int): number of training epochs.
            metrics (List): list of :py:class:`ashpy.metrics.Metric` to measure on
                training and validation data.
            callbacks (List): list of :py:class:`ashpy.callbacks.Callback` to call
                on training events.
logdir: checkpoint and log directory.
log_eval_mode (:py:class:`ashpy.modes.LogEvalMode`): models' mode to use
when evaluating and logging.
global_step (Optional[:py:class:`tf.Variable`]): tf.Variable that keeps
track of the training steps.
"""
        if not metrics:
            metrics = []
        # Copy the list so the caller's metrics are not mutated in place.
        metrics = [*metrics, EncoderLoss(name="ashpy/e_loss", logdir=logdir)]
super().__init__(
generator=generator,
discriminator=discriminator,
generator_optimizer=generator_optimizer,
discriminator_optimizer=discriminator_optimizer,
generator_loss=generator_loss,
discriminator_loss=discriminator_loss,
epochs=epochs,
metrics=metrics,
callbacks=callbacks,
logdir=logdir,
log_eval_mode=log_eval_mode,
global_step=global_step,
)
self._encoder = encoder
self._encoder_optimizer = encoder_optimizer
self._encoder_loss = encoder_loss
self._encoder_loss.reduction = tf.losses.Reduction.NONE
ckpt_dict = {
self.ckpt_id_encoder: self._encoder,
self.ckpt_id_optimizer_encoder: self._encoder_optimizer,
}
self._update_checkpoint(ckpt_dict)
self._restore_or_init()
self._context = GANEncoderContext(
generator_model=self._generator,
discriminator_model=self._discriminator,
encoder_model=self._encoder,
generator_loss=self._generator_loss,
discriminator_loss=self._discriminator_loss,
encoder_loss=self._encoder_loss,
log_eval_mode=self._log_eval_mode,
global_step=self._global_step,
checkpoint=self._checkpoint,
metrics=self._metrics,
)
    def _build_and_restore_models(self, dataset: tf.data.Dataset):
restorer = ashpy.restorers.AdversarialEncoderRestorer(self._logdir)
(x, _), _ = next(iter(dataset.take(1)))
        # Invoke the encoder on a sample input to build its variables
        # before restoring it from the checkpoint.
self._encoder(x)
restorer.restore_encoder(self._encoder)
super()._build_and_restore_models(dataset)
self._deferred_restoration = False
    def train_step(self, real_xy, g_inputs):
"""Adversarial training step.

        Args:
            real_xy: input batch as extracted from the discriminator input dataset,
                a (features, label) pair.
            g_inputs: batch of noise as generated by the generator input dataset.

        Returns:
            d_loss, g_loss, e_loss, fake, generator_of_encoder: discriminator,
                generator, and encoder loss values, plus the generator output and
                the generator applied to the encoder output.
"""
real_x, real_y = real_xy
if len(self._generator.inputs) == 2:
g_inputs = [g_inputs, real_y]
with tf.GradientTape(persistent=True) as tape:
fake = self._generator(g_inputs, training=True)
g_loss = self._generator_loss(
self._context, fake=fake, real=real_x, condition=real_y, training=True
)
d_loss = self._discriminator_loss(
self._context, fake=fake, real=real_x, condition=real_y, training=True
)
e_loss = self._encoder_loss(
self._context, fake=fake, real=real_x, condition=real_y, training=True
)
g_gradients = tape.gradient(g_loss, self._generator.trainable_variables)
d_gradients = tape.gradient(d_loss, self._discriminator.trainable_variables)
e_gradients = tape.gradient(e_loss, self._encoder.trainable_variables)
del tape
# Only for logging in special cases (out of tape)
generator_of_encoder = self._generator(
self._encoder(real_x, training=True), training=True
)
self._discriminator_optimizer.apply_gradients(
zip(d_gradients, self._discriminator.trainable_variables)
)
self._generator_optimizer.apply_gradients(
zip(g_gradients, self._generator.trainable_variables)
)
self._encoder_optimizer.apply_gradients(
zip(e_gradients, self._encoder.trainable_variables)
)
return d_loss, g_loss, e_loss, fake, generator_of_encoder
    @tf.function
def _train_step(self, example):
"""Perform the training step using the distribution strategy."""
ret = self._distribute_strategy.experimental_run_v2(
self.train_step, args=(example[0], example[1])
)
per_replica_losses = ret[:3]
fake = ret[3]
generator_of_encoder = ret[4]
return (
self._reduce(per_replica_losses[0], tf.distribute.ReduceOp.SUM),
self._reduce(per_replica_losses[1], tf.distribute.ReduceOp.SUM),
self._reduce(per_replica_losses[2], tf.distribute.ReduceOp.SUM),
fake,
generator_of_encoder,
)
    def call(
self,
dataset: tf.data.Dataset,
log_freq: int = 10,
measure_performance_freq: int = 10,
):
r"""
        Perform the adversarial training.

        Args:
            dataset (:py:class:`tf.data.Dataset`): The adversarial training dataset.
            log_freq (int): Specifies how many steps to run before logging the losses,
                e.g. ``log_freq=10`` logs every 10 steps of training.
                Pass ``log_freq<=0`` if you don't want to log.
            measure_performance_freq (int): Specifies how many steps to run before
                measuring the performance, e.g. ``measure_performance_freq=10``
                measures performance every 10 steps of training.
                Pass ``measure_performance_freq<=0`` if you don't want to measure
                performance.
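
        The dataset has the same ``((features, labels), noise)`` structure used
        by :class:`AdversarialTrainer`; the features batch is additionally fed
        to the encoder. A minimal sketch (shapes are illustrative assumptions):

        .. code-block:: python

            import tensorflow as tf

            batch_size, latent_dim = 10, 100
            pairs = tf.data.Dataset.from_tensor_slices(
                (tf.zeros((100, 1)), tf.zeros((100, 1), dtype=tf.int64))
            ).batch(batch_size)
            # Pair every (features, labels) batch with a latent noise batch.
            dataset = pairs.map(
                lambda x, y: ((x, y), tf.random.normal(shape=(batch_size, latent_dim)))
            )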
"""
if self._deferred_restoration:
self._build_and_restore_models(dataset=dataset)
current_epoch = self._current_epoch()
self._update_global_batch_size(
dataset,
[self._discriminator_loss, self._generator_loss, self._encoder_loss],
)
dataset = wrap(
dataset.unbatch().batch(self._global_batch_size, drop_remainder=True)
)
samples = next(iter(dataset.take(1)))
self._context.generator_inputs = samples[1]
self._context.encoder_inputs = samples[0][0]
with self._train_summary_writer.as_default():
# notify on train start event
self._on_train_start()
for _ in tf.range(current_epoch, self._epochs):
distribute_dataset = self._distribute_strategy.experimental_distribute_dataset(
dataset
)
# notify on epoch start event
self._on_epoch_start()
                for example in distribute_dataset:
                    # notify on batch start event
                    self._on_batch_start()
                    # perform training step
(
d_loss,
g_loss,
e_loss,
fake,
generator_of_encoder,
) = self._train_step(example)
# increase global step
self._global_step.assign_add(1)
# setup fake_samples
self._context.fake_samples = fake
self._context.generator_of_encoder = generator_of_encoder
# Log losses
if log_freq > 0 and tf.equal(
tf.math.mod(self._global_step, log_freq), 0
):
tf.print(
f"[{self._global_step.numpy()}] g_loss: {g_loss} - "
f"d_loss: {d_loss} - e_loss: {e_loss}"
)
# measure performance if needed
self._measure_performance_if_needed(
example, measure_performance_freq
)
# notify on batch end event
self._on_batch_end()
# notify on epoch end event
self._on_epoch_end()
# notify on training end event
self._on_train_end()