gan

GAN callbacks.

LogImageGANCallback: Log the output of the generator evaluated on its inputs, G(z).
LogImageGANEncoderCallback: Log the output of the generator evaluated on the encoder output, G(E(x)).

Classes

LogImageGANCallback Callback used for logging GAN images, G(z), to TensorBoard.
LogImageGANEncoderCallback Callback used for logging GAN images, G(E(x)), to TensorBoard.
class ashpy.callbacks.gan.LogImageGANCallback(event=<Event.ON_EPOCH_END: 'ON_EPOCH_END'>, name='log_image_gan_callback', event_freq=1)[source]

Bases: ashpy.callbacks.counter_callback.CounterCallback

Callback used for logging GAN images to TensorBoard.

Logs the generator output, G(z).

Examples

import shutil
import operator
from pathlib import Path

# Generator: grows from a (7, 7) to a (28, 28) resolution with one channel.
generator = models.gans.ConvGenerator(
    layer_spec_input_res=(7, 7),
    layer_spec_target_res=(28, 28),
    kernel_size=(5, 5),
    initial_filters=32,
    filters_cap=16,
    channels=1,
)

# Discriminator: shrinks (28, 28) inputs down to (7, 7) and emits one value.
discriminator = models.gans.ConvDiscriminator(
    layer_spec_input_res=(28, 28),
    layer_spec_target_res=(7, 7),
    kernel_size=(5, 5),
    initial_filters=16,
    filters_cap=32,
    output_shape=1,
)

# Losses
generator_bce = losses.gan.GeneratorBCE()
minmax = losses.gan.DiscriminatorMinMax()

# Real data: all-zero placeholders shaped like MNIST (100 images of 28x28).
batch_size = 2
mnist_x, mnist_y = tf.zeros((100, 28, 28)), tf.zeros((100,))

# Trainer. The callback list is deliberately not named ``callbacks`` so it
# does not shadow the ``callbacks`` module used to build it.
epochs = 2
logdir = Path("testlog/callbacks")
trainer_callbacks = [callbacks.LogImageGANCallback()]
trainer = trainers.gan.AdversarialTrainer(
    generator=generator,
    discriminator=discriminator,
    generator_optimizer=tf.optimizers.Adam(1e-4),
    discriminator_optimizer=tf.optimizers.Adam(1e-4),
    generator_loss=generator_bce,
    discriminator_loss=minmax,
    epochs=epochs,
    callbacks=trainer_callbacks,
    logdir=logdir,
)

# Take only 2 samples to speed up tests.
real_data = (
    tf.data.Dataset.from_tensor_slices(
        (tf.expand_dims(mnist_x, -1), tf.expand_dims(mnist_y, -1))
    )
    .take(batch_size)
    .batch(batch_size)
    .prefetch(1)
)

# Add noise in the same dataset, just by mapping.
# The return type of the dataset must be: tuple(tuple(a, b), noise)
dataset = real_data.map(
    lambda x, y: ((x, y), tf.random.normal(shape=(batch_size, 100)))
)

trainer(dataset)

# Remove everything the run wrote under logdir (logs and checkpoints).
shutil.rmtree(logdir)

assert not logdir.exists()
Initializing checkpoint.
Starting epoch 1.
[1] Saved checkpoint: testlog/callbacks/ckpts/ckpt-1
Epoch 1 completed.
Starting epoch 2.
[2] Saved checkpoint: testlog/callbacks/ckpts/ckpt-2
Epoch 2 completed.
Training finished after 2 epochs.
__init__(event=<Event.ON_EPOCH_END: 'ON_EPOCH_END'>, name='log_image_gan_callback', event_freq=1)[source]

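Initialize the LogImageGANCallback.

Parameters:
event (Event) – event that triggers the callback (default: Event.ON_EPOCH_END).
name (str) – name of the callback (default: 'log_image_gan_callback').
event_freq (int) – frequency with which event triggers the callback (default: 1).
Return type:

None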

_log_fn(context)[source]

Log the output of the generator to TensorBoard.

Parameters:context (ashpy.contexts.gan.GANContext) – current context.
Return type:None
class ashpy.callbacks.gan.LogImageGANEncoderCallback(event=<Event.ON_EPOCH_END: 'ON_EPOCH_END'>, name='log_image_gan_encoder_callback', event_freq=1)[source]

Bases: ashpy.callbacks.gan.LogImageGANCallback

Callback used for logging GAN images to TensorBoard.

Logs the generator output evaluated on the encoder output, G(E(x)).

Examples

import shutil
import operator

def real_gen():
    """Yield 100 identical "real" samples of the form ``((10.0,), (0,))``.

    Each sample is a ``(features, label)`` pair of 1-element tuples. The loop
    counter is unused, so a plain Python ``range`` is enough; there is no need
    to build TensorFlow tensors (``tf.range``) just to iterate.
    """
    label = 0
    for _ in range(100):
        yield ((10.0,), (label,))

latent_dim = 100

generator = tf.keras.Sequential([tf.keras.layers.Dense(1)])

left_input = tf.keras.layers.Input(shape=(1,))
left = tf.keras.layers.Dense(10, activation=tf.nn.elu)(left_input)

right_input = tf.keras.layers.Input(shape=(latent_dim,))
right = tf.keras.layers.Dense(10, activation=tf.nn.elu)(right_input)

net = tf.keras.layers.Concatenate()([left, right])
out = tf.keras.layers.Dense(1)(net)

discriminator = tf.keras.Model(inputs=[left_input, right_input], outputs=[out])

encoder = tf.keras.Sequential([tf.keras.layers.Dense(latent_dim)])

# Losses
generator_bce = losses.gan.GeneratorBCE()
encoder_bce = losses.gan.EncoderBCE()
minmax = losses.gan.DiscriminatorMinMax()

epochs = 2

callbacks = [callbacks.LogImageGANEncoderCallback()]

logdir = "testlog/callbacks_encoder"

trainer = trainers.gan.EncoderTrainer(
    generator=generator,
    discriminator=discriminator,
    encoder=encoder,
    discriminator_optimizer=tf.optimizers.Adam(1e-4),
    generator_optimizer=tf.optimizers.Adam(1e-5),
    encoder_optimizer=tf.optimizers.Adam(1e-6),
    generator_loss=generator_bce,
    discriminator_loss=minmax,
    encoder_loss=encoder_bce,
    epochs=epochs,
    callbacks=callbacks,
    logdir=logdir,
)

batch_size = 10
discriminator_input = tf.data.Dataset.from_generator(
    real_gen, (tf.float32, tf.int64), ((1), (1))
).batch(batch_size)

dataset = discriminator_input.map(
    lambda x, y: ((x, y), tf.random.normal(shape=(batch_size, latent_dim)))
)

trainer(dataset)

shutil.rmtree(logdir)
Initializing checkpoint.
Starting epoch 1.
[10] Saved checkpoint: testlog/callbacks_encoder/ckpts/ckpt-1
Epoch 1 completed.
Starting epoch 2.
[20] Saved checkpoint: testlog/callbacks_encoder/ckpts/ckpt-2
Epoch 2 completed.
Training finished after 2 epochs.
__init__(event=<Event.ON_EPOCH_END: 'ON_EPOCH_END'>, name='log_image_gan_encoder_callback', event_freq=1)[source]

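Initialize the LogImageGANEncoderCallback.

Parameters:
event (Event) – event that triggers the callback (default: Event.ON_EPOCH_END).
name (str) – name of the callback (default: 'log_image_gan_encoder_callback').
event_freq (int) – frequency with which event triggers the callback (default: 1).
Return type:

None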

_log_fn(context)[source]

Log the output of the generator to TensorBoard.

Logs G(E(x)).

Parameters:context (ashpy.contexts.gan.GANEncoderContext) – current context.