gan

GAN callbacks.

LogImageGANCallback: logs the output of the generator evaluated on its input noise.
LogImageGANEncoderCallback: logs the output of the generator evaluated on the encoder output.

Classes

LogImageGANCallback

Callback used for logging GAN images to Tensorboard; logs the generator output G(z).

LogImageGANEncoderCallback

Callback used for logging GAN images to Tensorboard; logs the generator output evaluated on the encoder output, G(E(x)).
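A minimal sketch of choosing between the two, importing them from the module documented here (ashpy.callbacks.gan); the variable names are illustrative only:

from ashpy.callbacks.gan import LogImageGANCallback, LogImageGANEncoderCallback

# Plain adversarial training: log the generator evaluated on noise, G(z).
gan_callbacks = [LogImageGANCallback()]

# Training with an encoder: log the generator evaluated on the encoder output, G(E(x)).
encoder_callbacks = [LogImageGANEncoderCallback()]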

class ashpy.callbacks.gan.LogImageGANCallback(event=<Event.ON_EPOCH_END: 'ON_EPOCH_END'>, name='LogImageGANCallback', event_freq=1)[source]

Bases: ashpy.callbacks.counter_callback.CounterCallback

Callback used for logging GAN images to Tensorboard.

Logs the Generator output, G(z).

Examples

import os
import shutil

# Assumed bindings for the bare names used below (tf, models, losses, trainers, callbacks).
import tensorflow as tf
from ashpy import callbacks, losses, models, trainers

generator = models.gans.ConvGenerator(
    layer_spec_input_res=(7, 7),
    layer_spec_target_res=(28, 28),
    kernel_size=(5, 5),
    initial_filters=32,
    filters_cap=16,
    channels=1,
)

discriminator = models.gans.ConvDiscriminator(
    layer_spec_input_res=(28, 28),
    layer_spec_target_res=(7, 7),
    kernel_size=(5, 5),
    initial_filters=16,
    filters_cap=32,
    output_shape=1,
)

# Losses
generator_bce = losses.gan.GeneratorBCE()
minmax = losses.gan.DiscriminatorMinMax()

# Real data (zero tensors standing in for MNIST images and labels)
batch_size = 2
mnist_x, mnist_y = tf.zeros((100, 28, 28)), tf.zeros((100,))

# Trainer
epochs = 2
logdir = "testlog/callbacks"
callbacks = [callbacks.LogImageGANCallback()]
trainer = trainers.gan.AdversarialTrainer(
    generator=generator,
    discriminator=discriminator,
    generator_optimizer=tf.optimizers.Adam(1e-4),
    discriminator_optimizer=tf.optimizers.Adam(1e-4),
    generator_loss=generator_bce,
    discriminator_loss=minmax,
    epochs=epochs,
    callbacks=callbacks,
    logdir=logdir,
)

# take only 2 samples to speed up tests
real_data = (
    tf.data.Dataset.from_tensor_slices(
        (tf.expand_dims(mnist_x, -1), tf.expand_dims(mnist_y, -1))
    )
    .take(batch_size)
    .batch(batch_size)
    .prefetch(1)
)

# Add noise to the same dataset by mapping.
# Each element of the resulting dataset must be a tuple: ((x, y), noise)
dataset = real_data.map(
    lambda x, y: ((x, y), tf.random.normal(shape=(batch_size, 100)))
)

trainer(dataset)
shutil.rmtree(logdir)

assert not os.path.exists(logdir)

# Advance the trainer's (private) global step counter by 500.
trainer._global_step.assign_add(500)

Expected output:

Initializing checkpoint.
Starting epoch 1.
[1] Saved checkpoint: testlog/callbacks/ckpts/ckpt-1
Epoch 1 completed.
Starting epoch 2.
[2] Saved checkpoint: testlog/callbacks/ckpts/ckpt-2
Epoch 2 completed.
Training finished after 2 epochs.
__init__(event=<Event.ON_EPOCH_END: 'ON_EPOCH_END'>, name='LogImageGANCallback', event_freq=1)[source]

Initialize the LogImageGANCallback.

Parameters

event (ashpy.callbacks.events.Event) – event on which the callback is triggered.

name (str) – name of the callback.

event_freq (int) – number of event occurrences between two consecutive triggers.
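For example, to keep the default ON_EPOCH_END event but log only every second epoch, pass a custom event_freq (a minimal sketch; the import path is the module documented here, and the callback name is illustrative):

from ashpy.callbacks.gan import LogImageGANCallback

# Fires on every second occurrence of the default ON_EPOCH_END event.
log_images = LogImageGANCallback(name="log_images", event_freq=2)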
_log_fn(context)[source]

Log output of the generator to Tensorboard.

Parameters

context (ashpy.contexts.gan.GANContext) – current context.

Return type

None
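For intuition, the logged images are simply the generator evaluated on a batch of noise; a hedged sketch reusing the generator and the (batch_size, 100) noise shape from the example above (the callback's internal batch handling and evaluation mode are not shown):

import tensorflow as tf

z = tf.random.normal(shape=(2, 100))      # latent vectors shaped like the example's noise
generated = generator(z, training=False)  # G(z), shape (2, 28, 28, 1)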

class ashpy.callbacks.gan.LogImageGANEncoderCallback(event=<Event.ON_EPOCH_END: 'ON_EPOCH_END'>, name='LogImageGANEncoderCallback', event_freq=1)[source]

Bases: ashpy.callbacks.gan.LogImageGANCallback

Callback used for logging GAN images to Tensorboard.

Logs the Generator output evaluated on the encoder output, G(E(x)).

Examples

import shutil

# Assumed bindings for the bare names used below (tf, losses, trainers, callbacks).
import tensorflow as tf
from ashpy import callbacks, losses, trainers

def real_gen():
    label = 0
    for _ in tf.range(100):
        yield ((10.0,), (label,))

latent_dim = 100

generator = tf.keras.Sequential([tf.keras.layers.Dense(1)])

left_input = tf.keras.layers.Input(shape=(1,))
left = tf.keras.layers.Dense(10, activation=tf.nn.elu)(left_input)

right_input = tf.keras.layers.Input(shape=(latent_dim,))
right = tf.keras.layers.Dense(10, activation=tf.nn.elu)(right_input)

net = tf.keras.layers.Concatenate()([left, right])
out = tf.keras.layers.Dense(1)(net)

discriminator = tf.keras.Model(inputs=[left_input, right_input], outputs=[out])

encoder = tf.keras.Sequential([tf.keras.layers.Dense(latent_dim)])

# Losses
generator_bce = losses.gan.GeneratorBCE()
encoder_bce = losses.gan.EncoderBCE()
minmax = losses.gan.DiscriminatorMinMax()

epochs = 2

callbacks = [callbacks.LogImageGANEncoderCallback()]

logdir = "testlog/callbacks_encoder"

trainer = trainers.gan.EncoderTrainer(
    generator=generator,
    discriminator=discriminator,
    encoder=encoder,
    discriminator_optimizer=tf.optimizers.Adam(1e-4),
    generator_optimizer=tf.optimizers.Adam(1e-5),
    encoder_optimizer=tf.optimizers.Adam(1e-6),
    generator_loss=generator_bce,
    discriminator_loss=minmax,
    encoder_loss=encoder_bce,
    epochs=epochs,
    callbacks=callbacks,
    logdir=logdir,
)

batch_size = 10
discriminator_input = tf.data.Dataset.from_generator(
    real_gen, (tf.float32, tf.int64), ((1,), (1,))
).batch(batch_size)

dataset = discriminator_input.map(
    lambda x, y: ((x, y), tf.random.normal(shape=(batch_size, latent_dim)))
)

trainer(dataset)

shutil.rmtree(logdir)

Expected output:

Initializing checkpoint.
Starting epoch 1.
[10] Saved checkpoint: testlog/callbacks_encoder/ckpts/ckpt-1
Epoch 1 completed.
Starting epoch 2.
[20] Saved checkpoint: testlog/callbacks_encoder/ckpts/ckpt-2
Epoch 2 completed.
Training finished after 2 epochs.
__init__(event=<Event.ON_EPOCH_END: 'ON_EPOCH_END'>, name='LogImageGANEncoderCallback', event_freq=1)[source]

Initialize the LogImageGANEncoderCallback.

Parameters

event (ashpy.callbacks.events.Event) – event on which the callback is triggered.

name (str) – name of the callback.

event_freq (int) – number of event occurrences between two consecutive triggers.
_log_fn(context)[source]

Log output of the generator to Tensorboard.

Logs G(E(x)).

Parameters

context (ashpy.contexts.gan.GanEncoderContext) – current context.
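As with LogImageGANCallback, the logged quantity can be reproduced by hand; a hedged sketch reusing the generator and encoder from the example above on a hypothetical batch shaped like the output of real_gen (the callback's own batch handling and evaluation mode are not shown):

import tensorflow as tf

x_batch = tf.fill((10, 1), 10.0)                             # hypothetical batch, shape (batch_size, 1)
reconstructed = generator(encoder(x_batch), training=False)  # G(E(x)), shape (10, 1)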