LogImageGANEncoderCallback
class ashpy.callbacks.gan.LogImageGANEncoderCallback(event=<Event.ON_EPOCH_END: 'ON_EPOCH_END'>, name='LogImageGANEncoderCallback', event_freq=1)
Bases: ashpy.callbacks.gan.LogImageGANCallback
Callback used for logging GAN images to TensorBoard.
Logs the Generator output evaluated on the Encoder output, i.e. G(E(x)).
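To make the logged quantity G(E(x)) concrete, here is a minimal NumPy sketch of the composition. It is not ashpy's implementation: the linear encoder and generator, the tanh output range of [-1, 1], and the rescaling to [0, 1] before image logging are assumptions chosen for illustration, and `reconstructions_for_logging` is a hypothetical helper.

```python
import numpy as np

# Hypothetical stand-ins for trained models (assumption): any callables
# mapping images -> latent codes (encoder) and codes -> images (generator).
rng = np.random.default_rng(0)
latent_dim, height, width = 8, 4, 4
W_enc = rng.normal(size=(height * width, latent_dim))
W_gen = rng.normal(size=(latent_dim, height * width))


def encoder(x):
    """E: (batch, H, W) images -> (batch, latent_dim) codes."""
    return x.reshape(len(x), -1) @ W_enc


def generator(z):
    """G: (batch, latent_dim) codes -> (batch, H, W) images in (-1, 1)."""
    return np.tanh(z @ W_gen).reshape(len(z), height, width)


def reconstructions_for_logging(x):
    """Hypothetical helper: G(E(x)) rescaled from [-1, 1] to [0, 1]."""
    g_of_e = generator(encoder(x))
    return (g_of_e + 1.0) / 2.0


x = rng.uniform(-1.0, 1.0, size=(3, height, width))
images = reconstructions_for_logging(x)
print(images.shape)  # (3, 4, 4)
```

The key point is the order of application: real samples x are encoded first, and only then decoded by the generator, so the logged images show how well the encoder/generator pair reconstructs the inputs.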
Examples
import shutil

import tensorflow as tf
from ashpy import callbacks, losses, trainers


# Toy real data: 100 samples with value 10.0, all with label 0.
def real_gen():
    label = 0
    for _ in tf.range(100):
        yield ((10.0,), (label,))


latent_dim = 100

generator = tf.keras.Sequential([tf.keras.layers.Dense(1)])

# The discriminator scores (sample, latent code) pairs:
# the sample goes in on the left, the latent code on the right.
left_input = tf.keras.layers.Input(shape=(1,))
left = tf.keras.layers.Dense(10, activation=tf.nn.elu)(left_input)

right_input = tf.keras.layers.Input(shape=(latent_dim,))
right = tf.keras.layers.Dense(10, activation=tf.nn.elu)(right_input)

net = tf.keras.layers.Concatenate()([left, right])
out = tf.keras.layers.Dense(1)(net)

discriminator = tf.keras.Model(inputs=[left_input, right_input], outputs=[out])

encoder = tf.keras.Sequential([tf.keras.layers.Dense(latent_dim)])

# Losses
generator_bce = losses.gan.GeneratorBCE()
encoder_bce = losses.gan.EncoderBCE()
minmax = losses.gan.DiscriminatorMinMax()

epochs = 2
callbacks = [callbacks.LogImageGANEncoderCallback()]
logdir = "testlog/callbacks_encoder"

trainer = trainers.gan.EncoderTrainer(
    generator=generator,
    discriminator=discriminator,
    encoder=encoder,
    discriminator_optimizer=tf.optimizers.Adam(1e-4),
    generator_optimizer=tf.optimizers.Adam(1e-5),
    encoder_optimizer=tf.optimizers.Adam(1e-6),
    generator_loss=generator_bce,
    discriminator_loss=minmax,
    encoder_loss=encoder_bce,
    epochs=epochs,
    callbacks=callbacks,
    logdir=logdir,
)

batch_size = 10
discriminator_input = tf.data.Dataset.from_generator(
    real_gen, (tf.float32, tf.int64), ((1,), (1,))
).batch(batch_size)

# Each dataset element is ((x, label), z), with z drawn from a standard normal.
dataset = discriminator_input.map(
    lambda x, y: ((x, y), tf.random.normal(shape=(batch_size, latent_dim)))
)

trainer(dataset)

# Clean up the logs and checkpoints written during the example.
shutil.rmtree(logdir)
Initializing checkpoint.
Starting epoch 1.
[10] Saved checkpoint: testlog/callbacks_encoder/ckpts/ckpt-1
Epoch 1 completed.
Starting epoch 2.
[20] Saved checkpoint: testlog/callbacks_encoder/ckpts/ckpt-2
Epoch 2 completed.
Training finished after 2 epochs.
Methods
__init__([event, name, event_freq]) – Initialize the LogImageGANEncoderCallback.
Attributes
name – Returns the name of this module as passed or determined in the constructor.
name_scope – Returns a tf.name_scope instance for this class.
submodules – Sequence of all submodules.
trainable_variables – Sequence of trainable variables owned by this module and its submodules.
variables – Sequence of variables owned by this module and its submodules.
__init__(event=<Event.ON_EPOCH_END: 'ON_EPOCH_END'>, name='LogImageGANEncoderCallback', event_freq=1)
Initialize the LogImageGANEncoderCallback.
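Parameters
event (ashpy.callbacks.events.Event) – event that triggers the callback.
name (str) – name of the callback.
event_freq (int) – logging frequency, counted in occurrences of event (e.g. with the default ON_EPOCH_END and event_freq=1, images are logged at the end of every epoch).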
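One plausible reading of event_freq is "log on every event_freq-th occurrence of event". The pure-Python sketch below illustrates that counting logic only. FrequencyGate and the local Event enum are hypothetical stand-ins, not ashpy classes, and whether ashpy fires on the first or on the event_freq-th occurrence is an assumption here.

```python
from enum import Enum


class Event(Enum):
    # Hypothetical subset of events, mirroring names like Event.ON_EPOCH_END.
    ON_EPOCH_END = "ON_EPOCH_END"
    ON_BATCH_END = "ON_BATCH_END"


class FrequencyGate:
    """Fires on every `event_freq`-th occurrence of `event` (illustrative only)."""

    def __init__(self, event=Event.ON_EPOCH_END, event_freq=1):
        if event_freq < 1:
            raise ValueError("event_freq must be a positive integer")
        self.event = event
        self.event_freq = event_freq
        self._counter = 0

    def on_event(self, event):
        """Return True when this occurrence of `event` should trigger logging."""
        if event != self.event:
            # Other events are ignored and do not advance the counter.
            return False
        self._counter += 1
        return self._counter % self.event_freq == 0


gate = FrequencyGate(event=Event.ON_EPOCH_END, event_freq=2)
fired = [gate.on_event(Event.ON_EPOCH_END) for _ in range(5)]
print(fired)  # [False, True, False, True, False]
```

With event_freq=1 every matching event fires, which matches the default behaviour of logging once per epoch.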