LogImageGANCallback
class ashpy.callbacks.gan.LogImageGANCallback(event=<Event.ON_EPOCH_END: 'ON_EPOCH_END'>, name='log_image_gan_callback', event_freq=1)
Bases: ashpy.callbacks.counter_callback.CounterCallback
Callback used for logging GAN images to TensorBoard.
It logs the generator output, G(z), where z is the input noise.
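TensorFlow's tf.summary.image clips floating-point pixel data to the unit range, so a generator whose last activation is tanh produces values in [-1, 1] that would be partly clipped if written as-is. The minimal sketch below shows the usual rescaling step in plain Python. That the generator ends in tanh is an assumption, `to_unit_range` is a hypothetical helper, and whether LogImageGANCallback performs such a rescaling internally is not documented on this page.

```python
from typing import List

# A grayscale image as a list of rows of pixel values (stand-in for a tensor).
Image = List[List[float]]


def to_unit_range(image: Image) -> Image:
    """Map pixels from [-1, 1] (assumed tanh output) to [0, 1], clipping outliers."""
    return [[min(1.0, max(0.0, (p + 1.0) / 2.0)) for p in row] for row in image]


# -1 maps to 0, 0 maps to 0.5, and the out-of-range 1.2 is clipped to 1.0.
print(to_unit_range([[-1.0, 0.0], [0.5, 1.2]]))  # [[0.0, 0.5], [0.75, 1.0]]
```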
Examples
import shutil
from pathlib import Path

import tensorflow as tf
from ashpy import callbacks, losses, models, trainers

generator = models.gans.ConvGenerator(
    layer_spec_input_res=(7, 7),
    layer_spec_target_res=(28, 28),
    kernel_size=(5, 5),
    initial_filters=32,
    filters_cap=16,
    channels=1,
)

discriminator = models.gans.ConvDiscriminator(
    layer_spec_input_res=(28, 28),
    layer_spec_target_res=(7, 7),
    kernel_size=(5, 5),
    initial_filters=16,
    filters_cap=32,
    output_shape=1,
)

# Losses
generator_bce = losses.gan.GeneratorBCE()
minmax = losses.gan.DiscriminatorMinMax()

# Real data (all-zero tensors shaped like MNIST)
batch_size = 2
mnist_x, mnist_y = tf.zeros((100, 28, 28)), tf.zeros((100,))

# Trainer
epochs = 2
logdir = Path("testlog/callbacks")
# Use a distinct name so the `callbacks` module is not shadowed
callbacks_list = [callbacks.LogImageGANCallback()]
trainer = trainers.gan.AdversarialTrainer(
    generator=generator,
    discriminator=discriminator,
    generator_optimizer=tf.optimizers.Adam(1e-4),
    discriminator_optimizer=tf.optimizers.Adam(1e-4),
    generator_loss=generator_bce,
    discriminator_loss=minmax,
    epochs=epochs,
    callbacks=callbacks_list,
    logdir=logdir,
)

# Take only 2 samples to speed up tests
real_data = (
    tf.data.Dataset.from_tensor_slices(
        (tf.expand_dims(mnist_x, -1), tf.expand_dims(mnist_y, -1))
    )
    .take(batch_size)
    .batch(batch_size)
    .prefetch(1)
)

# Add noise in the same dataset, just by mapping.
# The return type of the dataset must be: tuple(tuple(a, b), noise)
dataset = real_data.map(
    lambda x, y: ((x, y), tf.random.normal(shape=(batch_size, 100)))
)

trainer(dataset)

shutil.rmtree(logdir)
assert not logdir.exists()

trainer._global_step.assign_add(500)
Initializing checkpoint.
Starting epoch 1.
[1] Saved checkpoint: testlog/callbacks/ckpts/ckpt-1
Epoch 1 completed.
Starting epoch 2.
[2] Saved checkpoint: testlog/callbacks/ckpts/ckpt-2
Epoch 2 completed.
Training finished after 2 epochs.
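The comment in the example requires every dataset element to have the shape ((x, y), noise): a pair of real samples and labels, followed by a batch of latent vectors. The sketch below mirrors that nesting in plain Python so the structure can be inspected without TensorFlow. `make_gan_batches` is a hypothetical helper, lists stand in for tensors, and the latent size of 100 is taken from the example. One deliberate difference: it draws one latent vector per real sample, whereas the example always draws `batch_size` vectors (equivalent there, because `take(batch_size)` yields exactly one full batch).

```python
import random
from typing import Iterator, List, Sequence, Tuple

# ((x_batch, y_batch), noise_batch), the element layout the trainer expects.
Batch = Tuple[Tuple[list, list], List[List[float]]]


def make_gan_batches(images: Sequence, labels: Sequence, batch_size: int,
                     latent_dim: int, seed: int = 0) -> Iterator[Batch]:
    """Yield ((x, y), noise) tuples, mirroring real_data.map(...) in the example."""
    rng = random.Random(seed)
    for start in range(0, len(images), batch_size):
        x = list(images[start:start + batch_size])
        y = list(labels[start:start + batch_size])
        # One standard-normal latent vector per real sample in this batch.
        noise = [[rng.gauss(0.0, 1.0) for _ in range(latent_dim)] for _ in x]
        yield (x, y), noise


# Two all-zero 28x28 "images", as in the example's placeholder data.
images = [[[0.0] * 28 for _ in range(28)] for _ in range(2)]
labels = [0.0, 0.0]
batches = list(make_gan_batches(images, labels, batch_size=2, latent_dim=100))
(x, y), noise = batches[0]
print(len(batches), len(x), len(noise), len(noise[0]))  # 1 2 2 100
```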
Methods
__init__([event, name, event_freq]): Initialize the LogImageGANCallback.
Attributes
name: Return the name of the callback.
name_scope: Returns a tf.name_scope instance for this class.
submodules: Sequence of all sub-modules.
trainable_variables: Sequence of trainable variables owned by this module and its submodules.
variables: Sequence of variables owned by this module and its submodules.
__init__(event=<Event.ON_EPOCH_END: 'ON_EPOCH_END'>, name='log_image_gan_callback', event_freq=1)
Initialize the LogImageGANCallback.
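Parameters:
- event (ashpy.callbacks.events.Event) – event that triggers the callback (default: Event.ON_EPOCH_END).
- event_freq (int) – logging frequency: images are logged once every event_freq occurrences of event (default: 1).
- name (str) – name of the callback (default: 'log_image_gan_callback').
Return type: None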
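The interplay of event and event_freq can be pictured as a counter that ignores other events and fires every event_freq matching events. The plain-Python sketch below illustrates that idea; `EveryNEvents`, `on_event`, and `action` are hypothetical names, events are plain strings rather than ashpy.callbacks.events.Event members, and the exact bookkeeping inside CounterCallback may differ.

```python
from typing import Callable, List


class EveryNEvents:
    """Hypothetical stand-in for counter-based triggering: fire once per event_freq events."""

    def __init__(self, event: str, event_freq: int, action: Callable[[int], None]):
        if event_freq < 1:
            raise ValueError("event_freq must be >= 1")
        self.event = event
        self.event_freq = event_freq
        self.action = action
        self._counter = 0

    def on_event(self, event: str, step: int) -> None:
        if event != self.event:
            return  # ignore events this callback does not listen to
        self._counter += 1
        if self._counter == self.event_freq:
            self._counter = 0  # restart the count after firing
            self.action(step)


logged: List[int] = []
cb = EveryNEvents("ON_EPOCH_END", event_freq=2, action=logged.append)
for epoch in range(1, 6):
    cb.on_event("ON_BATCH_END", epoch)  # different event: ignored
    cb.on_event("ON_EPOCH_END", epoch)
print(logged)  # [2, 4]
```

With the default event_freq=1 the action runs at every matching event, which corresponds to logging after every epoch in the example above.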
_log_fn(context)
Log the output of the generator to TensorBoard.
Parameters: context (ashpy.contexts.gan.GANContext) – current context.
Return type: None
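Conceptually, _log_fn reads what it needs from the context, evaluates G(z), and hands the resulting images to a summary writer at the current step. The sketch below models that flow with plain-Python stand-ins so it runs without TensorFlow or ashpy. `FakeGANContext` and its fields `generator`, `noise`, and `global_step` are hypothetical and do not reproduce GANContext's real attributes; `ListWriter` replaces a TensorBoard writer, and using the callback name as the image tag is an assumption.

```python
from dataclasses import dataclass, field
from typing import Callable, List, Tuple


@dataclass
class FakeGANContext:
    """Hypothetical context holding only the fields this sketch needs."""
    generator: Callable[[List[float]], List[float]]
    noise: List[List[float]]
    global_step: int


@dataclass
class ListWriter:
    """Hypothetical summary writer that records (tag, step, images) entries."""
    entries: List[Tuple[str, int, list]] = field(default_factory=list)

    def image(self, tag: str, images: list, step: int) -> None:
        self.entries.append((tag, step, images))


def log_fn(context: FakeGANContext, writer: ListWriter, name: str) -> None:
    # Evaluate G(z) for every latent vector, then log the batch at the current step.
    fake_images = [context.generator(z) for z in context.noise]
    writer.image(name, fake_images, step=context.global_step)


writer = ListWriter()
ctx = FakeGANContext(generator=lambda z: [2 * v for v in z], noise=[[0.5, -0.5]], global_step=10)
log_fn(ctx, writer, "log_image_gan_callback")
print(writer.entries)  # [('log_image_gan_callback', 10, [[1.0, -1.0]])]
```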