LogImageGANCallback

Inheritance Diagram

[Figure: inheritance diagram of ashpy.callbacks.gan.LogImageGANCallback]

class ashpy.callbacks.gan.LogImageGANCallback(event=<Event.ON_EPOCH_END: 'ON_EPOCH_END'>, name='log_image_gan_callback', event_freq=1)

Bases: ashpy.callbacks.counter_callback.CounterCallback

Callback for logging GAN images to TensorBoard.

Logs the generator output G(z), i.e. the images the generator produces from the noise input z.
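Before G(z) can be shown in TensorBoard it has to look like an image batch: `tf.summary.image` expects a rank-4 tensor of shape [k, h, w, c], and floating-point pixel data is clipped to [0, 1]. The sketch below is a minimal NumPy illustration of that preparation step, assuming a tanh-activated generator whose outputs lie in [-1, 1]. The helper `to_tensorboard_images` is hypothetical; this page does not state whether LogImageGANCallback rescales its output this way.

```python
import numpy as np


def to_tensorboard_images(g_z):
    """Map a batch of generator outputs in [-1, 1] to float images in [0, 1].

    Hypothetical helper. Assumes a tanh-activated generator producing
    NHWC arrays, e.g. shape (k, 28, 28, 1) for MNIST-like data.
    """
    g_z = np.asarray(g_z, dtype=np.float32)
    if g_z.ndim != 4:
        raise ValueError(f"expected shape (k, h, w, c), got {g_z.shape}")
    # Linear rescale [-1, 1] -> [0, 1], then clip any numerical overshoot.
    return np.clip((g_z + 1.0) / 2.0, 0.0, 1.0)


# Stand-in for G(z): two fake 28x28 grayscale samples squashed into [-1, 1].
rng = np.random.default_rng(0)
fake = np.tanh(rng.normal(size=(2, 28, 28, 1)))
images = to_tensorboard_images(fake)
print(images.shape)  # (2, 28, 28, 1)
```

With TensorFlow, an array prepared like this could then be written with `tf.summary.image(name, images, step=step)` inside an active summary writer.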

Examples

import shutil
from pathlib import Path

import tensorflow as tf
from ashpy import callbacks, losses, models, trainers

generator = models.gans.ConvGenerator(
    layer_spec_input_res=(7, 7),
    layer_spec_target_res=(28, 28),
    kernel_size=(5, 5),
    initial_filters=32,
    filters_cap=16,
    channels=1,
)

discriminator = models.gans.ConvDiscriminator(
    layer_spec_input_res=(28, 28),
    layer_spec_target_res=(7, 7),
    kernel_size=(5, 5),
    initial_filters=16,
    filters_cap=32,
    output_shape=1,
)

# Losses
generator_bce = losses.gan.GeneratorBCE()
minmax = losses.gan.DiscriminatorMinMax()

# Real data (all-zero placeholder with MNIST shapes)
batch_size = 2
mnist_x, mnist_y = tf.zeros((100, 28, 28)), tf.zeros((100,))

# Trainer
epochs = 2
logdir = Path("testlog/callbacks")
# Use a distinct name so the `callbacks` module is not shadowed
callback_list = [callbacks.LogImageGANCallback()]
trainer = trainers.gan.AdversarialTrainer(
    generator=generator,
    discriminator=discriminator,
    generator_optimizer=tf.optimizers.Adam(1e-4),
    discriminator_optimizer=tf.optimizers.Adam(1e-4),
    generator_loss=generator_bce,
    discriminator_loss=minmax,
    epochs=epochs,
    callbacks=callback_list,
    logdir=logdir,
)

# Take only 2 samples to speed up tests
real_data = (
    tf.data.Dataset.from_tensor_slices(
        (tf.expand_dims(mnist_x, -1), tf.expand_dims(mnist_y, -1))
    )
    .take(batch_size)
    .batch(batch_size)
    .prefetch(1)
)

# Add noise to the same dataset by mapping.
# Each dataset element must have the structure ((x, y), noise).
dataset = real_data.map(
    lambda x, y: ((x, y), tf.random.normal(shape=(batch_size, 100)))
)

trainer(dataset)
# Remove the logs written during the example
shutil.rmtree(logdir)

assert not logdir.exists()

trainer._global_step.assign_add(500)

Output of trainer(dataset):

Initializing checkpoint.
Starting epoch 1.
[1] Saved checkpoint: testlog/callbacks/ckpts/ckpt-1
Epoch 1 completed.
Starting epoch 2.
[2] Saved checkpoint: testlog/callbacks/ckpts/ckpt-2
Epoch 2 completed.
Training finished after 2 epochs.
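The `map` in the example pairs each real batch with a noise batch, producing elements shaped ((x, y), noise). The sketch below uses NumPy stand-ins to show that nesting and checks that the three batch dimensions agree; `check_gan_element` is a hypothetical helper, not part of ashpy.

```python
import numpy as np

batch_size = 2
latent_dim = 100  # matches shape=(batch_size, 100) in the example

# Stand-ins for one batch: images (N, 28, 28, 1) and labels (N, 1),
# the shapes produced by expand_dims(..., -1) followed by batch().
x = np.zeros((batch_size, 28, 28, 1), dtype=np.float32)
y = np.zeros((batch_size, 1), dtype=np.float32)
noise = np.random.default_rng(0).normal(size=(batch_size, latent_dim))

element = ((x, y), noise)


def check_gan_element(element):
    """Unpack ((x, y), noise) and return the shared batch size.

    Raises ValueError if the batch dimensions disagree.
    """
    (x, y), z = element  # fails loudly if the nesting is wrong
    sizes = {x.shape[0], y.shape[0], z.shape[0]}
    if len(sizes) != 1:
        raise ValueError(f"batch sizes disagree: {sorted(sizes)}")
    return sizes.pop()


print(check_gan_element(element))  # 2
```

Because the example draws noise with a fixed `batch_size`, a final partial batch would not match. One option is to size the noise from the data instead, e.g. `tf.random.normal(shape=(tf.shape(x)[0], 100))`.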

Methods

__init__([event, name, event_freq]) Initialize the LogImageGANCallback.

Attributes

name Return the name of the callback.
name_scope Returns a tf.name_scope instance for this class.
submodules Sequence of all sub-modules.
trainable_variables Sequence of variables owned by this module and its submodules.
variables Sequence of variables owned by this module and its submodules.

__init__(event=<Event.ON_EPOCH_END: 'ON_EPOCH_END'>, name='log_image_gan_callback', event_freq=1)

Initialize the LogImageGANCallback.

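Parameters:
    event (Event) – event that triggers logging (default: Event.ON_EPOCH_END).
    name (str) – name of the callback.
    event_freq (int) – log once every event_freq occurrences of event.
Return type: None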
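The `event` and `event_freq` parameters, together with the CounterCallback base class, suggest that logging fires on every `event_freq`-th occurrence of `event`. The following is a minimal pure-Python sketch of that counting behavior under that assumption; `CounterCallbackSketch`, `on_event`, and the `ON_BATCH_END` member are hypothetical names for illustration, not ashpy's actual implementation.

```python
from enum import Enum


class Event(Enum):
    # Hypothetical stand-in for ashpy's Event enum; only two members shown.
    ON_EPOCH_END = "ON_EPOCH_END"
    ON_BATCH_END = "ON_BATCH_END"


class CounterCallbackSketch:
    """Calls _log_fn once every `event_freq` occurrences of `event`."""

    def __init__(self, event=Event.ON_EPOCH_END, name="counter_callback", event_freq=1):
        if event_freq < 1:
            raise ValueError("event_freq must be >= 1")
        self.event = event
        self.name = name
        self.event_freq = event_freq
        self._counter = 0
        self.fired = []  # records the contexts that triggered logging

    def on_event(self, event, context):
        # Ignore events this callback was not configured for.
        if event != self.event:
            return
        self._counter += 1
        if self._counter % self.event_freq == 0:
            self._log_fn(context)

    def _log_fn(self, context):
        # Subclasses (e.g. an image logger) would override this hook.
        self.fired.append(context)


cb = CounterCallbackSketch(event_freq=2)
for epoch in range(1, 6):
    cb.on_event(Event.ON_EPOCH_END, epoch)
    cb.on_event(Event.ON_BATCH_END, epoch)  # not the configured event: ignored
print(cb.fired)  # [2, 4]
```

With the default `event_freq=1`, every epoch end would trigger `_log_fn`.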

_log_fn(context)

Log the output of the generator to TensorBoard.

Parameters: context (ashpy.contexts.gan.GANContext) – current context.
Return type: None