LogImageGANEncoderCallback

Inheritance Diagram

Inheritance diagram of ashpy.callbacks.gan.LogImageGANEncoderCallback

class ashpy.callbacks.gan.LogImageGANEncoderCallback(event=<Event.ON_EPOCH_END: 'ON_EPOCH_END'>, name='LogImageGANEncoderCallback', event_freq=1)

Bases: ashpy.callbacks.gan.LogImageGANCallback

Callback used for logging GAN images to TensorBoard.

Logs the generator output evaluated on the encoder output, i.e. G(E(x)).
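
As a rough illustration of what gets logged, the sketch below composes a toy encoder E and a toy generator G written in plain Python. The helpers `make_linear`, `encoder`, `generator` and `reconstruct` are hypothetical stand-ins, not ashpy APIs; in the actual callback both models are Keras models taken from the training context.

```python
from typing import Callable, List

Vector = List[float]


def make_linear(weights: List[List[float]]) -> Callable[[Vector], Vector]:
    """Return a function computing the matrix-vector product weights @ v."""
    def apply(v: Vector) -> Vector:
        return [sum(w * x for w, x in zip(row, v)) for row in weights]
    return apply


# Hypothetical encoder E: maps a 2-D "image" to a 1-D latent code.
encoder = make_linear([[0.5, 0.5]])
# Hypothetical generator G: maps the 1-D latent code back to 2-D.
generator = make_linear([[1.0], [1.0]])


def reconstruct(x: Vector) -> Vector:
    """Compute G(E(x)), the quantity this callback logs."""
    return generator(encoder(x))


print(reconstruct([2.0, 4.0]))  # E(x) = [3.0], so G(E(x)) = [3.0, 3.0]
```

A perfectly consistent encoder/generator pair would return x unchanged, so one common use of these images is to compare G(E(x)) with x visually.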

Examples

import shutil

import tensorflow as tf

from ashpy import callbacks, losses, trainers


def real_gen():
    """Yield 100 (sample, label) pairs: the value 10.0 with label 0."""
    label = 0
    for _ in range(100):
        yield ((10.0,), (label,))


latent_dim = 100

# Generator: maps a latent vector to a single value.
generator = tf.keras.Sequential([tf.keras.layers.Dense(1)])

# Discriminator: scores (x, z) pairs, with one branch for the sample x
# and one branch for the latent code z.
left_input = tf.keras.layers.Input(shape=(1,))
left = tf.keras.layers.Dense(10, activation=tf.nn.elu)(left_input)

right_input = tf.keras.layers.Input(shape=(latent_dim,))
right = tf.keras.layers.Dense(10, activation=tf.nn.elu)(right_input)

net = tf.keras.layers.Concatenate()([left, right])
out = tf.keras.layers.Dense(1)(net)

discriminator = tf.keras.Model(inputs=[left_input, right_input], outputs=[out])

# Encoder: maps a sample x to a latent vector of size latent_dim.
encoder = tf.keras.Sequential([tf.keras.layers.Dense(latent_dim)])

# Losses
generator_bce = losses.gan.GeneratorBCE()
encoder_bce = losses.gan.EncoderBCE()
minmax = losses.gan.DiscriminatorMinMax()

epochs = 2

# Named callbacks_list so that it does not shadow the ashpy callbacks module.
callbacks_list = [callbacks.LogImageGANEncoderCallback()]

# Directory for TensorBoard summaries and checkpoints.
logdir = "testlog/callbacks_encoder"

trainer = trainers.gan.EncoderTrainer(
    generator=generator,
    discriminator=discriminator,
    encoder=encoder,
    discriminator_optimizer=tf.optimizers.Adam(1e-4),
    generator_optimizer=tf.optimizers.Adam(1e-5),
    encoder_optimizer=tf.optimizers.Adam(1e-6),
    generator_loss=generator_bce,
    discriminator_loss=minmax,
    encoder_loss=encoder_bce,
    epochs=epochs,
    callbacks=callbacks_list,
    logdir=logdir,
)

batch_size = 10
discriminator_input = tf.data.Dataset.from_generator(
    real_gen,
    output_types=(tf.float32, tf.int64),
    output_shapes=((1,), (1,)),
).batch(batch_size)

# Each dataset element is ((x, label), z): a batch of real samples with
# their labels, plus a batch of latent noise vectors for the generator.
dataset = discriminator_input.map(
    lambda x, y: ((x, y), tf.random.normal(shape=(batch_size, latent_dim)))
)

trainer(dataset)

# Remove the summaries and checkpoints written by the example.
shutil.rmtree(logdir)
Output:

Initializing checkpoint.
Starting epoch 1.
[10] Saved checkpoint: testlog/callbacks_encoder/ckpts/ckpt-1
Epoch 1 completed.
Starting epoch 2.
[20] Saved checkpoint: testlog/callbacks_encoder/ckpts/ckpt-2
Epoch 2 completed.
Training finished after 2 epochs.
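
The step numbers in the log can be checked by hand: real_gen yields 100 samples and the dataset is batched by 10, giving 10 steps per epoch. The sketch below assumes the bracketed number in each checkpoint line is the global step counter; the helper names `steps_per_epoch` and `epoch_end_steps` are hypothetical.

```python
import math
from typing import List


def steps_per_epoch(num_samples: int, batch_size: int) -> int:
    """Number of batches produced by Dataset.batch() with drop_remainder=False."""
    return math.ceil(num_samples / batch_size)


def epoch_end_steps(num_samples: int, batch_size: int, epochs: int) -> List[int]:
    """Global step reached at the end of each epoch (epochs counted from 1)."""
    per_epoch = steps_per_epoch(num_samples, batch_size)
    return [per_epoch * epoch for epoch in range(1, epochs + 1)]


# Values taken from the example above: 100 samples, batch_size=10, epochs=2.
print(epoch_end_steps(100, 10, 2))  # [10, 20]
```

With the default event=ON_EPOCH_END and event_freq=1, the callback should therefore log once at the end of each of the two epochs.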

Methods

__init__([event, name, event_freq]) – Initialize the LogImageGANEncoderCallback.

Attributes

name – Returns the name of this module as passed or determined in the constructor.
name_scope – Returns a tf.name_scope instance for this class.
submodules – Sequence of all sub-modules.
trainable_variables – Sequence of trainable variables owned by this module and its submodules.
variables – Sequence of variables owned by this module and its submodules.
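
The event and event_freq constructor arguments determine when logging happens. The sketch below assumes, without confirmation from this page, that the callback counts occurrences of its event and logs whenever the count is a multiple of event_freq. `CounterSketch` and the string event names are hypothetical and do not reproduce ashpy's implementation.

```python
from typing import List


class CounterSketch:
    """Hypothetical stand-in for a callback that reacts to one event type.

    Assumption: the callback counts occurrences of its event and runs its
    logging function whenever the count is a multiple of event_freq.
    """

    def __init__(self, event: str = "ON_EPOCH_END", event_freq: int = 1) -> None:
        if event_freq <= 0:
            raise ValueError("event_freq must be a positive integer")
        self.event = event
        self.event_freq = event_freq
        self.counter = 0
        self.fired_at: List[int] = []

    def on_event(self, event: str) -> None:
        """Handle an incoming event; ignore events other than self.event."""
        if event != self.event:
            return
        self.counter += 1
        if self.counter % self.event_freq == 0:
            # The real callback would log G(E(x)) to TensorBoard here.
            self.fired_at.append(self.counter)


cb = CounterSketch(event="ON_EPOCH_END", event_freq=2)
for _ in range(5):
    cb.on_event("ON_BATCH_END")  # ignored: not the configured event
    cb.on_event("ON_EPOCH_END")
print(cb.fired_at)  # [2, 4]
```
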
__init__(event=<Event.ON_EPOCH_END: 'ON_EPOCH_END'>, name='LogImageGANEncoderCallback', event_freq=1)

Initialize the LogImageGANEncoderCallback.

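Parameters:
event (Event) – event that triggers logging (default: Event.ON_EPOCH_END).
name (str) – name of the callback.
event_freq (int) – frequency of logging with respect to event; with the default of 1, the callback logs on every occurrence of event.
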
_log_fn(context)

Log the output of the generator to TensorBoard, i.e. G(E(x)).

Parameters: context (ashpy.contexts.gan.GanEncoderContext) – current context.
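
The Bases line and this _log_fn entry suggest that the subclass changes only what is logged. The sketch below mimics that pattern with hypothetical toy classes: a base callback whose _log_fn records G(z), and a subclass that overrides _log_fn to record G(E(x)). `ToyContext`, `ToyImageCallback` and `ToyImageEncoderCallback` are illustrative names, not ashpy classes, and a Python list stands in for TensorBoard.

```python
from dataclasses import dataclass, field
from typing import Callable, List, Tuple

Vector = List[float]


@dataclass
class ToyContext:
    """Hypothetical stand-in for a GAN encoder context (not ashpy's class)."""
    generator: Callable[[Vector], Vector]
    encoder: Callable[[Vector], Vector]
    noise: Vector  # latent sample z
    real: Vector   # real input x


@dataclass
class ToyImageCallback:
    """Hypothetical base callback: its _log_fn records G(z)."""
    logged: List[Tuple[str, Vector]] = field(default_factory=list)

    def _log_fn(self, context: ToyContext) -> None:
        # Stand-in for writing an image summary to TensorBoard.
        self.logged.append(("generator", context.generator(context.noise)))


class ToyImageEncoderCallback(ToyImageCallback):
    """Hypothetical subclass: overrides only _log_fn to record G(E(x))."""

    def _log_fn(self, context: ToyContext) -> None:
        reconstruction = context.generator(context.encoder(context.real))
        self.logged.append(("generator_encoder", reconstruction))


ctx = ToyContext(
    generator=lambda z: [2.0 * v for v in z],
    encoder=lambda x: [x[0] + x[1]],
    noise=[1.0],
    real=[1.0, 2.0],
)
base, enc = ToyImageCallback(), ToyImageEncoderCallback()
base._log_fn(ctx)
enc._log_fn(ctx)
print(base.logged)  # [('generator', [2.0])]
print(enc.logged)   # [('generator_encoder', [6.0])]
```

Keeping the triggering logic in the base class and overriding only _log_fn means the encoder variant inherits event and event_freq handling unchanged.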