LogImageGANEncoderCallback
class ashpy.callbacks.gan.LogImageGANEncoderCallback(event=<Event.ON_EPOCH_END: 'ON_EPOCH_END'>, name='log_image_gan_encoder_callback', event_freq=1)
Bases: ashpy.callbacks.gan.LogImageGANCallback
Callback for logging GAN images to TensorBoard.
It logs the generator output evaluated on the encoder output, that is, G(E(x)).
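To make the composition G(E(x)) concrete, here is a minimal pure-Python sketch. It assumes toy scalar functions in place of the trainer's Keras encoder and generator models; `toy_encoder`, `toy_generator` and `reconstruct` are hypothetical names, not part of ashpy, and the sketch does not write anything to TensorBoard.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def toy_encoder(x: Vector) -> Vector:
    # Hypothetical encoder E: map a sample to a 2-D "latent" code
    return [sum(x), -sum(x)]


def toy_generator(z: Vector) -> Vector:
    # Hypothetical generator G: map a latent code back to sample space
    return [2.0 * z[0] + z[1]]


def reconstruct(
    encoder: Callable[[Vector], Vector],
    generator: Callable[[Vector], Vector],
    batch: Sequence[Vector],
) -> List[Vector]:
    # Evaluate the generator on the encoder output: G(E(x)) for each x
    return [generator(encoder(x)) for x in batch]


batch = [[1.0], [10.0]]
print(reconstruct(toy_encoder, toy_generator, batch))  # prints [[1.0], [10.0]]
```

These toy maps happen to invert each other, so G(E(x)) returns x unchanged; with trained models the logged images show how closely the encoder/generator pair maps inputs back to sample space.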
Examples
```python
import shutil

import tensorflow as tf
from ashpy import callbacks, losses, trainers


def real_gen():
    # Yield 100 identical (feature, label) pairs: ((10.0,), (0,))
    label = 0
    for _ in tf.range(100):
        yield ((10.0,), (label,))


latent_dim = 100

# Generator: maps a latent vector to a single value
generator = tf.keras.Sequential([tf.keras.layers.Dense(1)])

# Discriminator: scores (x, z) pairs of samples and latent vectors
left_input = tf.keras.layers.Input(shape=(1,))
left = tf.keras.layers.Dense(10, activation=tf.nn.elu)(left_input)
right_input = tf.keras.layers.Input(shape=(latent_dim,))
right = tf.keras.layers.Dense(10, activation=tf.nn.elu)(right_input)
net = tf.keras.layers.Concatenate()([left, right])
out = tf.keras.layers.Dense(1)(net)
discriminator = tf.keras.Model(inputs=[left_input, right_input], outputs=[out])

# Encoder: maps a sample to a latent vector
encoder = tf.keras.Sequential([tf.keras.layers.Dense(latent_dim)])

# Losses
generator_bce = losses.gan.GeneratorBCE()
encoder_bce = losses.gan.EncoderBCE()
minmax = losses.gan.DiscriminatorMinMax()

epochs = 2
# Use a distinct name so the `callbacks` module is not shadowed
callback_list = [callbacks.LogImageGANEncoderCallback()]
logdir = "testlog/callbacks_encoder"

trainer = trainers.gan.EncoderTrainer(
    generator=generator,
    discriminator=discriminator,
    encoder=encoder,
    discriminator_optimizer=tf.optimizers.Adam(1e-4),
    generator_optimizer=tf.optimizers.Adam(1e-5),
    encoder_optimizer=tf.optimizers.Adam(1e-6),
    generator_loss=generator_bce,
    discriminator_loss=minmax,
    encoder_loss=encoder_bce,
    epochs=epochs,
    callbacks=callback_list,
    logdir=logdir,
)

batch_size = 10
discriminator_input = tf.data.Dataset.from_generator(
    real_gen, (tf.float32, tf.int64), ((1,), (1,))
).batch(batch_size)

# Each element is ((x, y), z): labelled real samples plus latent noise
dataset = discriminator_input.map(
    lambda x, y: ((x, y), tf.random.normal(shape=(batch_size, latent_dim)))
)

trainer(dataset)

# Remove the logs and checkpoints written by the example
shutil.rmtree(logdir)
```
```
Initializing checkpoint.
Starting epoch 1.
[10] Saved checkpoint: testlog/callbacks_encoder/ckpts/ckpt-1
Epoch 1 completed.
Starting epoch 2.
[20] Saved checkpoint: testlog/callbacks_encoder/ckpts/ckpt-2
Epoch 2 completed.
Training finished after 2 epochs.
```
Methods
- __init__([event, name, event_freq]): Initialize the LogImageGANEncoderCallback.

Attributes
- name: Return the name of the callback.
- name_scope: Returns a tf.name_scope instance for this class.
- submodules: Sequence of all sub-modules.
- trainable_variables: Sequence of trainable variables owned by this module and its submodules.
- variables: Sequence of variables owned by this module and its submodules.
__init__(event=<Event.ON_EPOCH_END: 'ON_EPOCH_END'>, name='log_image_gan_encoder_callback', event_freq=1)
Initialize the LogImageGANEncoderCallback.
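Parameters:
- event (ashpy.callbacks.events.Event) – training event on which the callback is triggered (default: Event.ON_EPOCH_END).
- event_freq (int) – logging frequency: images are logged once every event_freq occurrences of event (default: 1).
- name (str) – name of the callback (default: 'log_image_gan_encoder_callback').
Return type: None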
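The event and event_freq parameters together decide when images are logged. The pure-Python sketch below shows one plausible counting scheme. `EveryNthEvent` and `should_log` are hypothetical names, events are modeled as plain strings rather than ashpy.callbacks.events.Event members, and triggering on the event_freq-th occurrence (rather than the first) is an assumption, not confirmed ashpy behavior.

```python
from typing import List


class EveryNthEvent:
    """Hypothetical counter: fire once every `event_freq` matching events."""

    def __init__(self, event: str, event_freq: int = 1) -> None:
        if event_freq < 1:
            raise ValueError("event_freq must be a positive integer")
        self.event = event
        self.event_freq = event_freq
        self._count = 0

    def should_log(self, event: str) -> bool:
        # Ignore events other than the one this callback listens to
        if event != self.event:
            return False
        self._count += 1
        # Assumption: trigger on the event_freq-th, 2*event_freq-th, ... occurrence
        return self._count % self.event_freq == 0


# Simulate four epochs, each emitting a batch-end and an epoch-end event
trigger = EveryNthEvent("ON_EPOCH_END", event_freq=2)
fired: List[int] = []
for epoch in range(1, 5):
    for ev in ("ON_BATCH_END", "ON_EPOCH_END"):
        if trigger.should_log(ev):
            fired.append(epoch)
print(fired)  # prints [2, 4]
```

With the default event_freq=1 the counter fires on every matching event, which corresponds to logging at the end of every epoch under the default event.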