gan

Collection of GAN trainers.

Classes

AdversarialTrainer

Primitive Trainer for GANs subclassed from ashpy.trainers.BaseTrainer.

EncoderTrainer

Primitive Trainer for GANs using an Encoder sub-network.

class ashpy.trainers.gan.AdversarialTrainer(generator, discriminator, generator_optimizer, discriminator_optimizer, generator_loss, discriminator_loss, epochs, metrics=None, logdir='/home/docs/checkouts/readthedocs.org/user_builds/ashpy/checkouts/latest/docs/source/log', post_process_callback=None, log_eval_mode=<LogEvalMode.TEST: 0>, global_step=<tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>)[source]

Bases: ashpy.trainers.base_trainer.BaseTrainer

Primitive Trainer for GANs subclassed from ashpy.trainers.BaseTrainer.

Examples

import shutil
import operator

import tensorflow as tf

from ashpy.models.gans import Generator, Discriminator
from ashpy.metrics import InceptionScore
from ashpy.losses.gan import DiscriminatorMinMax, GeneratorBCE
from ashpy.trainers.gan import AdversarialTrainer

generator = Generator(
    layer_spec_input_res=(7, 7),
    layer_spec_target_res=(28, 28),
    kernel_size=(5, 5),
    initial_filters=32,
    filters_cap=16,
    channels=1,
)

discriminator = Discriminator(
    layer_spec_input_res=(28, 28),
    layer_spec_target_res=(7, 7),
    kernel_size=(5, 5),
    initial_filters=16,
    filters_cap=32,
    output_shape=1,
)

# Losses
generator_bce = GeneratorBCE()
minmax = DiscriminatorMinMax()

# Real data (zero tensors standing in for MNIST)
batch_size = 2
mnist_x, mnist_y = tf.zeros((100, 28, 28)), tf.zeros((100,))

# Trainer
epochs = 2
logdir = "testlog/adversarial"
metrics = [
    InceptionScore(
        # Fake inception model
        Discriminator(
            layer_spec_input_res=(299, 299),
            layer_spec_target_res=(7, 7),
            kernel_size=(5, 5),
            initial_filters=16,
            filters_cap=32,
            output_shape=10,
        ),
        #model_selection_operator=operator.gt,
        logdir=logdir,
    )
]
trainer = AdversarialTrainer(
    generator,
    discriminator,
    tf.optimizers.Adam(1e-4),
    tf.optimizers.Adam(1e-4),
    generator_bce,
    minmax,
    epochs,
    metrics,
    logdir,
)

# Dataset
noise_dataset = tf.data.Dataset.from_tensors(0).repeat().map(
    lambda _: tf.random.normal(shape=(100,), dtype=tf.float32, mean=0.0, stddev=1)
).batch(batch_size).prefetch(1)

# take only 2 samples to speed up tests
real_data = (
    tf.data.Dataset.from_tensor_slices(
        (tf.expand_dims(mnist_x, -1), tf.expand_dims(mnist_y, -1))
    )
    .take(batch_size)
    .batch(batch_size)
    .prefetch(1)
)

# Add noise in the same dataset, just by mapping.
# The return type of the dataset must be: tuple(tuple(a,b), noise)
dataset = real_data.map(
    lambda x, y: ((x, y), tf.random.normal(shape=(batch_size, 100)))
)

trainer(dataset)
shutil.rmtree(logdir)

Initializing checkpoint.
[1] Saved checkpoint: testlog/adversarial/ckpts/ckpt-1
Epoch 1 completed.
[2] Saved checkpoint: testlog/adversarial/ckpts/ckpt-2
Epoch 2 completed.
__init__(generator, discriminator, generator_optimizer, discriminator_optimizer, generator_loss, discriminator_loss, epochs, metrics=None, logdir='/home/docs/checkouts/readthedocs.org/user_builds/ashpy/checkouts/latest/docs/source/log', post_process_callback=None, log_eval_mode=<LogEvalMode.TEST: 0>, global_step=<tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>)[source]

Instantiate an AdversarialTrainer.

Parameters
  • generator (tf.keras.Model) – A tf.keras.Model describing the Generator part of a GAN.

  • discriminator (tf.keras.Model) – A tf.keras.Model describing the Discriminator part of a GAN.

  • generator_optimizer (tf.optimizers.Optimizer) – A tf.optimizers.Optimizer to use for the Generator.

  • discriminator_optimizer (tf.optimizers.Optimizer) – A tf.optimizers.Optimizer to use for the Discriminator.

  • generator_loss (ashpy.losses.executor.Executor) – An AshPy Executor to compute the loss of the Generator.

  • discriminator_loss (ashpy.losses.executor.Executor) – An AshPy Executor to compute the loss of the Discriminator.

  • epochs (int) – Number of training epochs.

  • metrics (List) – List of tf.metrics to measure on training and validation data.

  • logdir – Checkpoint and log directory.

  • post_process_callback (callable) – Function used to post-process the model output, if needed.

  • log_eval_mode – Models’ mode to use when evaluating and logging.

  • global_step – tf.Variable that keeps track of the training steps.

Returns

None
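
As an illustration of post_process_callback, the sketch below rescales generator output from [-1, 1] to [0, 1]. It assumes a generator whose last activation is tanh, which this page does not state. rescale_to_unit_range is a hypothetical name, and the point at which the trainer invokes the callback is not shown here.

```python
def rescale_to_unit_range(output):
    """Map values from [-1, 1] to [0, 1].

    Hypothetical callback. It uses only arithmetic operators, so it works
    on floats and on array-like objects that overload + and /.
    """
    return (output + 1.0) / 2.0


# Quick check on plain floats.
print(rescale_to_unit_range(-1.0))
print(rescale_to_unit_range(0.0))
print(rescale_to_unit_range(1.0))
```

Under these assumptions the function would be passed as post_process_callback=rescale_to_unit_range when building the trainer.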

_measure_performance(dataset)[source]

Measure performance on dataset.

Parameters

dataset (tf.data.Dataset) – The dataset on which to measure performance.

_train_step[source]

Training step with the distribution strategy.

call(dataset)[source]

Perform the adversarial training.

Parameters

dataset (tf.data.Dataset) – The adversarial training dataset.
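
The class example states that each dataset element must have the structure ((features, label), noise). The plain-Python sketch below mimics that nesting with a generator, using lists as stand-ins for tensors. make_elements and its arguments are hypothetical and are not part of ashpy or tf.data.

```python
import random


def make_elements(num_batches, batch_size, latent_dim, seed=0):
    """Yield ((features, labels), noise) tuples, one per batch.

    Hypothetical helper: lists stand in for the tensors a real
    tf.data.Dataset would produce.
    """
    rng = random.Random(seed)
    for _ in range(num_batches):
        features = [[0.0] for _ in range(batch_size)]
        labels = [0 for _ in range(batch_size)]
        noise = [
            [rng.gauss(0.0, 1.0) for _ in range(latent_dim)]
            for _ in range(batch_size)
        ]
        yield (features, labels), noise


# Unpack one element the same way train_step receives real_xy and g_inputs.
for (x, y), z in make_elements(num_batches=1, batch_size=2, latent_dim=100):
    print(len(x), len(y), len(z), len(z[0]))
```

The outer tuple separates the discriminator's real batch from the generator's input, which matches the two arguments of train_step.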

train_step(real_xy, g_inputs)[source]

Train step for the AdversarialTrainer.

Parameters
  • real_xy – Input batch extracted from the input dataset, as a (features, label) pair.

  • g_inputs – Batch of generator inputs taken from the input dataset.

Returns

d_loss, g_loss, fake – discriminator and generator loss values; fake is the generator output.
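
For orientation, the sketch below computes textbook forms of the two loss values for a single pair of discriminator logits: the min-max discriminator loss and the non-saturating binary cross-entropy generator loss. It assumes the discriminator outputs raw logits. The scaling and reduction applied by DiscriminatorMinMax and GeneratorBCE may differ, and all function names here are hypothetical.

```python
import math


def bce_from_logit(logit, label):
    """Numerically stable binary cross-entropy for one logit and a 0/1 label."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def discriminator_loss(real_logit, fake_logit):
    """Real samples should be classified as 1, generated samples as 0."""
    return bce_from_logit(real_logit, 1.0) + bce_from_logit(fake_logit, 0.0)


def generator_loss(fake_logit):
    """The generator is rewarded when its samples are classified as 1."""
    return bce_from_logit(fake_logit, 1.0)


d_loss = discriminator_loss(real_logit=0.0, fake_logit=0.0)
g_loss = generator_loss(fake_logit=0.0)
print(round(d_loss, 4), round(g_loss, 4))
```

With an undecided discriminator (both logits at 0), each cross-entropy term equals log 2, so the generator loss is log 2 and the discriminator loss is twice that.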

class ashpy.trainers.gan.EncoderTrainer(generator, discriminator, encoder, generator_optimizer, discriminator_optimizer, encoder_optimizer, generator_loss, discriminator_loss, encoder_loss, epochs, metrics, logdir='/home/docs/checkouts/readthedocs.org/user_builds/ashpy/checkouts/latest/docs/source/log', post_process_callback=None, log_eval_mode=<LogEvalMode.TEST: 0>, global_step=<tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>)[source]

Bases: ashpy.trainers.gan.AdversarialTrainer

Primitive Trainer for GANs using an Encoder sub-network. The implementation is designed for use with the BCE losses. To use another loss function, consider subclassing the trainer and overriding the train_step method.

Examples

import shutil
import operator

import tensorflow as tf

from ashpy.metrics import EncodingAccuracy
from ashpy.losses.gan import DiscriminatorMinMax, GeneratorBCE, EncoderBCE
from ashpy.trainers.gan import EncoderTrainer

def real_gen():
    label = 0
    for _ in tf.range(100):
        yield ((10.0,), (label,))

latent_dim = 100

generator = tf.keras.Sequential([tf.keras.layers.Dense(1)])

left_input = tf.keras.layers.Input(shape=(1,))
left = tf.keras.layers.Dense(10, activation=tf.nn.elu)(left_input)

right_input = tf.keras.layers.Input(shape=(latent_dim,))
right = tf.keras.layers.Dense(10, activation=tf.nn.elu)(right_input)

net = tf.keras.layers.Concatenate()([left, right])
out = tf.keras.layers.Dense(1)(net)

discriminator = tf.keras.Model(inputs=[left_input, right_input], outputs=[out])

encoder = tf.keras.Sequential([tf.keras.layers.Dense(latent_dim)])

# Losses
generator_bce = GeneratorBCE()
encoder_bce = EncoderBCE()
minmax = DiscriminatorMinMax()

epochs = 2

# Fake pre-trained classifier
num_classes = 1
classifier = tf.keras.Sequential(
    [tf.keras.layers.Dense(10), tf.keras.layers.Dense(num_classes)]
)

logdir = "testlog/adversarial/encoder"

metrics = [
    EncodingAccuracy(
        classifier,
        # model_selection_operator=operator.gt,
        logdir=logdir
    )
]

trainer = EncoderTrainer(
    generator=generator,
    discriminator=discriminator,
    encoder=encoder,
    discriminator_optimizer=tf.optimizers.Adam(1e-4),
    generator_optimizer=tf.optimizers.Adam(1e-5),
    encoder_optimizer=tf.optimizers.Adam(1e-6),
    generator_loss=generator_bce,
    discriminator_loss=minmax,
    encoder_loss=encoder_bce,
    epochs=epochs,
    metrics=metrics,
    logdir=logdir,
)

batch_size = 10
discriminator_input = tf.data.Dataset.from_generator(
    real_gen, (tf.float32, tf.int64), ((1,), (1,))
).batch(batch_size)

dataset = discriminator_input.map(
    lambda x, y: ((x, y), tf.random.normal(shape=(batch_size, latent_dim)))
)

trainer(dataset)

shutil.rmtree(logdir)

Initializing checkpoint.
[10] Saved checkpoint: testlog/adversarial/encoder/ckpts/ckpt-1
Epoch 1 completed.
[20] Saved checkpoint: testlog/adversarial/encoder/ckpts/ckpt-2
Epoch 2 completed.
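
The bracketed numbers in the two example logs on this page are consistent with a step counter that advances once per batch: 2 samples in batches of 2 give 1 step per epoch, and 100 samples in batches of 10 give 10. The sketch below reproduces that arithmetic. Reading the bracket as the global_step value is an inference from the logs, and checkpoint_steps is a hypothetical helper.

```python
import math


def checkpoint_steps(num_samples, batch_size, epochs):
    """Return the step count reached at the end of each epoch.

    Hypothetical helper: assumes one step per batch and one checkpoint
    per epoch, as the example logs suggest.
    """
    steps_per_epoch = math.ceil(num_samples / batch_size)
    return [steps_per_epoch * epoch for epoch in range(1, epochs + 1)]


print(checkpoint_steps(num_samples=2, batch_size=2, epochs=2))     # [1, 2]
print(checkpoint_steps(num_samples=100, batch_size=10, epochs=2))  # [10, 20]
```
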
__init__(generator, discriminator, encoder, generator_optimizer, discriminator_optimizer, encoder_optimizer, generator_loss, discriminator_loss, encoder_loss, epochs, metrics, logdir='/home/docs/checkouts/readthedocs.org/user_builds/ashpy/checkouts/latest/docs/source/log', post_process_callback=None, log_eval_mode=<LogEvalMode.TEST: 0>, global_step=<tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>)[source]

Instantiate an EncoderTrainer.

Parameters
  • generator (tf.keras.Model) – A tf.keras.Model describing the Generator part of a GAN.

  • discriminator (tf.keras.Model) – A tf.keras.Model describing the Discriminator part of a GAN.

  • encoder (tf.keras.Model) – A tf.keras.Model describing the Encoder part of a GAN.

  • generator_optimizer (tf.optimizers.Optimizer) – A tf.optimizers.Optimizer to use for the Generator.

  • discriminator_optimizer (tf.optimizers.Optimizer) – A tf.optimizers.Optimizer to use for the Discriminator.

  • encoder_optimizer (tf.optimizers.Optimizer) – A tf.optimizers.Optimizer to use for the Encoder.

  • generator_loss (ashpy.losses.executor.Executor) – An AshPy Executor to compute the loss of the Generator.

  • discriminator_loss (ashpy.losses.executor.Executor) – An AshPy Executor to compute the loss of the Discriminator.

  • encoder_loss (ashpy.losses.executor.Executor) – An AshPy Executor to compute the loss of the Encoder.

  • epochs (int) – Number of training epochs.

  • metrics (List) – List of tf.metrics to measure on training and validation data.

  • logdir – Checkpoint and log directory.

  • post_process_callback (callable) – Function used to post-process the output.

  • log_eval_mode – Models’ mode to use when evaluating and logging.

  • global_step – tf.Variable that keeps track of the training steps.

_measure_performance(dataset)[source]

Measure performance on dataset.

Parameters

dataset (tf.data.Dataset) – The dataset on which to measure performance.

_train_step[source]

The training step that uses the distribution strategy.

call(dataset)[source]

Perform the adversarial training.

Parameters

dataset (tf.data.Dataset) – The adversarial training dataset.

train_step(real_xy, g_inputs)[source]

Adversarial training step.

Parameters
  • real_xy – Input batch extracted from the discriminator input dataset, as a (features, label) pair.

  • g_inputs – Batch of noise produced by the generator input dataset.

Returns

d_loss, g_loss, e_loss – discriminator, generator, encoder loss values.
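
The two-input discriminator in the EncoderTrainer example takes a data tensor and a latent tensor. One common way to form such pairs, known from BiGAN-style training, is to pair real data with its encoding and generated data with the noise that produced it. The sketch below illustrates that pairing with plain lists. It is an assumption about how the inputs are combined, not a transcription of the trainer's code, and the stand-in generator and encoder are hypothetical.

```python
LATENT_DIM = 100


def generator(z):
    """Stand-in generator: latent vector -> 1-dimensional sample."""
    return [sum(z) / len(z)]


def encoder(x):
    """Stand-in encoder: 1-dimensional sample -> latent vector."""
    return [x[0]] * LATENT_DIM


def discriminator_pairs(real_x, z):
    """Build the (data, latent) pairs a two-input discriminator would see."""
    real_pair = (real_x, encoder(real_x))  # real data with its encoding
    fake_pair = (generator(z), z)          # generated data with its noise
    return real_pair, fake_pair


real_pair, fake_pair = discriminator_pairs(real_x=[10.0], z=[0.0] * LATENT_DIM)
print(len(real_pair[0]), len(real_pair[1]))
print(len(fake_pair[0]), len(fake_pair[1]))
```

Both pairs have the same shapes as the left_input and right_input layers in the example, which is why a single discriminator can score either one.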