AdversarialTrainer

class ashpy.trainers.AdversarialTrainer(generator, discriminator, generator_optimizer, discriminator_optimizer, generator_loss, discriminator_loss, epochs, metrics=None, callbacks=None, logdir=PosixPath('/home/docs/checkouts/readthedocs.org/user_builds/ashpy/checkouts/stable/docs/source/log'), log_eval_mode=<LogEvalMode.TEST: 1>, global_step=None)

Bases: ashpy.trainers.trainer.Trainer

Primitive Trainer for GANs subclassed from ashpy.trainers.Trainer.

Examples
    # Imports assumed by this example: the original snippet relies on
    # tf, models, losses, metrics and trainers already being in scope.
    import operator
    import shutil

    import tensorflow as tf
    from ashpy import losses, metrics, models, trainers

    generator = models.gans.ConvGenerator(
        layer_spec_input_res=(7, 7),
        layer_spec_target_res=(28, 28),
        kernel_size=(5, 5),
        initial_filters=32,
        filters_cap=16,
        channels=1,
    )

    discriminator = models.gans.ConvDiscriminator(
        layer_spec_input_res=(28, 28),
        layer_spec_target_res=(7, 7),
        kernel_size=(5, 5),
        initial_filters=16,
        filters_cap=32,
        output_shape=1,
    )

    # Losses
    generator_bce = losses.gan.GeneratorBCE()
    minmax = losses.gan.DiscriminatorMinMax()

    # Real data (all-zero placeholders with MNIST-like shapes)
    batch_size = 2
    mnist_x, mnist_y = tf.zeros((100, 28, 28)), tf.zeros((100,))

    # Trainer
    epochs = 2
    logdir = "testlog/adversarial"
    # Note: this rebinds the name `metrics` from the module to a list.
    metrics = [
        metrics.gan.InceptionScore(
            # Fake inception model
            models.gans.ConvDiscriminator(
                layer_spec_input_res=(299, 299),
                layer_spec_target_res=(7, 7),
                kernel_size=(5, 5),
                initial_filters=16,
                filters_cap=32,
                output_shape=10,
            ),
            model_selection_operator=operator.gt,
        )
    ]
    trainer = trainers.gan.AdversarialTrainer(
        generator=generator,
        discriminator=discriminator,
        generator_optimizer=tf.optimizers.Adam(1e-4),
        discriminator_optimizer=tf.optimizers.Adam(1e-4),
        generator_loss=generator_bce,
        discriminator_loss=minmax,
        epochs=epochs,
        metrics=metrics,
        logdir=logdir,
    )

    # Take only 2 samples to speed up tests
    real_data = (
        tf.data.Dataset.from_tensor_slices(
            (tf.expand_dims(mnist_x, -1), tf.expand_dims(mnist_y, -1))
        )
        .take(batch_size)
        .batch(batch_size)
        .prefetch(1)
    )

    # Add noise in the same dataset, just by mapping.
    # The return type of the dataset must be: tuple(tuple(a, b), noise)
    dataset = real_data.map(
        lambda x, y: ((x, y), tf.random.normal(shape=(batch_size, 100)))
    )

    trainer(dataset)
    shutil.rmtree(logdir)
    Initializing checkpoint.
    Starting epoch 1.
    [1] Saved checkpoint: testlog/adversarial/ckpts/ckpt-1
    Epoch 1 completed.
    Starting epoch 2.
    [2] Saved checkpoint: testlog/adversarial/ckpts/ckpt-2
    Epoch 2 completed.
    Training finished after 2 epochs.
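The example states that each dataset element must have the structure tuple(tuple(a, b), noise). The following is a minimal sketch, using only TensorFlow, that builds such a dataset and inspects one element. The name latent_dim is introduced here for illustration; its value of 100 mirrors the noise shape used in the example.

```python
import tensorflow as tf

batch_size = 2
latent_dim = 100  # illustrative name; 100 matches the noise size in the example

# Placeholder "real" data: blank 28x28 images with scalar labels.
images = tf.zeros((100, 28, 28))
labels = tf.zeros((100,))

real_data = (
    tf.data.Dataset.from_tensor_slices(
        (tf.expand_dims(images, -1), tf.expand_dims(labels, -1))
    )
    .take(batch_size)
    .batch(batch_size)
)

# Pair every real batch with a batch of generator noise.
dataset = real_data.map(
    lambda x, y: ((x, y), tf.random.normal(shape=(batch_size, latent_dim)))
)

# Each element unpacks as ((features, labels), noise).
(real_x, real_y), noise = next(iter(dataset))
shapes = (tuple(real_x.shape), tuple(real_y.shape), tuple(noise.shape))
print(shapes)
```

The first component of each element is the (features, label) pair and the second is the generator input, which corresponds to the real_xy and g_inputs arguments of train_step.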
Methods

__init__(generator, discriminator, …[, …]): Instantiate an AdversarialTrainer.
call(dataset[, log_freq, …]): Perform the adversarial training.
train_step(real_xy, g_inputs): Train step for the AdversarialTrainer.

Attributes

ckpt_id_callbacks
ckpt_id_discriminator
ckpt_id_generator
ckpt_id_global_step
ckpt_id_optimizer_discriminator
ckpt_id_optimizer_generator
ckpt_id_steps_per_epoch
context: Return the training context.
__init__(generator, discriminator, generator_optimizer, discriminator_optimizer, generator_loss, discriminator_loss, epochs, metrics=None, callbacks=None, logdir=PosixPath('/home/docs/checkouts/readthedocs.org/user_builds/ashpy/checkouts/stable/docs/source/log'), log_eval_mode=<LogEvalMode.TEST: 1>, global_step=None)

Instantiate an AdversarialTrainer.

Parameters:
- generator (tf.keras.Model) – A tf.keras.Model describing the Generator part of a GAN.
- discriminator (tf.keras.Model) – A tf.keras.Model describing the Discriminator part of a GAN.
- generator_optimizer (tf.optimizers.Optimizer) – A tf.optimizers optimizer to use for the Generator.
- discriminator_optimizer (tf.optimizers.Optimizer) – A tf.optimizers optimizer to use for the Discriminator.
- generator_loss (ashpy.losses.executor.Executor) – An ash Executor to compute the loss of the Generator.
- discriminator_loss (ashpy.losses.executor.Executor) – An ash Executor to compute the loss of the Discriminator.
- epochs (int) – Number of training epochs.
- metrics (Union[Tuple[Metric], List[Metric], None]) – List of ashpy.metrics.Metric to measure on training and validation data.
- callbacks (List) – List of ashpy.callbacks.Callback to call on events during training.
- logdir (Union[Path, str]) – Checkpoint and log directory.
- log_eval_mode (LogEvalMode) – Models' mode to use when evaluating and logging.
- global_step (Optional[tf.Variable]) – tf.Variable that keeps track of the training steps.

Returns: None
_build_and_restore_models(dataset)

Build and restore a subclassed model by first calling it on some data.
call(dataset, log_freq=10, measure_performance_freq=10)

Perform the adversarial training.

Parameters:
- dataset (tf.data.Dataset) – The adversarial training dataset.
- log_freq (int) – Specifies how many steps to run before logging the losses, e.g. log_freq=10 logs every 10 steps of training. Pass log_freq<=0 if you don't want to log.
- measure_performance_freq (int) – Specifies how many steps to run before measuring the performance, e.g. measure_performance_freq=10 measures performance every 10 steps of training. Pass measure_performance_freq<=0 if you don't want to measure performance.
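The frequency semantics described for log_freq and measure_performance_freq can be illustrated with a small pure-Python sketch. The helper should_trigger is hypothetical and is not part of ashpy; it only models the documented behaviour (act every freq steps, never when freq<=0) and makes the additional assumption that steps are counted from 1.

```python
def should_trigger(step: int, freq: int) -> bool:
    """Hypothetical helper: True when `step` is a positive multiple of `freq`.

    A non-positive `freq` disables the action entirely, mirroring the
    documented "pass <=0 if you don't want to log/measure" behaviour.
    """
    return freq > 0 and step > 0 and step % freq == 0


# With freq=10, the action fires on steps 10, 20 and 30.
logged_steps = [s for s in range(1, 31) if should_trigger(s, 10)]

# With freq=0 (or any negative value), it never fires.
disabled_steps = [s for s in range(1, 31) if should_trigger(s, 0)]
```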
train_step(real_xy, g_inputs)

Train step for the AdversarialTrainer.

Parameters:
- real_xy – Input batch as extracted from the input dataset: a (features, label) pair.
- g_inputs – Batch of generator_input as generated from the input dataset.

Returns: d_loss, g_loss, fake – The discriminator and generator loss values; fake is the generator output.
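To make the shape of such a step concrete, here is a generic GAN update written in plain TensorFlow. It is a sketch under stated assumptions and not ashpy's implementation: the tiny Dense models, the optimizers, and the binary cross-entropy losses are stand-ins for the trainer's own models and Executor objects, and only the signature and return values (d_loss, g_loss, fake) follow the documentation above.

```python
import tensorflow as tf

# Stand-in models and optimizers (assumptions, not the ashpy objects).
generator = tf.keras.Sequential([tf.keras.layers.Dense(4)])
discriminator = tf.keras.Sequential([tf.keras.layers.Dense(1)])
g_opt = tf.optimizers.Adam(1e-4)
d_opt = tf.optimizers.Adam(1e-4)
bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)


def train_step(real_xy, g_inputs):
    """One generic adversarial update; returns (d_loss, g_loss, fake)."""
    real_x, _ = real_xy  # labels are unused in this unconditional sketch
    with tf.GradientTape(persistent=True) as tape:
        fake = generator(g_inputs, training=True)
        d_real = discriminator(real_x, training=True)
        d_fake = discriminator(fake, training=True)
        # Discriminator: classify real samples as 1 and generated ones as 0.
        d_loss = bce(tf.ones_like(d_real), d_real) + bce(
            tf.zeros_like(d_fake), d_fake
        )
        # Generator: push the discriminator to label generated samples as 1.
        g_loss = bce(tf.ones_like(d_fake), d_fake)
    d_grads = tape.gradient(d_loss, discriminator.trainable_variables)
    g_grads = tape.gradient(g_loss, generator.trainable_variables)
    del tape  # release the persistent tape
    d_opt.apply_gradients(zip(d_grads, discriminator.trainable_variables))
    g_opt.apply_gradients(zip(g_grads, generator.trainable_variables))
    return d_loss, g_loss, fake


# Usage with placeholder data: 2 real samples with 4 features, 2 noise vectors.
real_xy = (tf.zeros((2, 4)), tf.zeros((2, 1)))
g_inputs = tf.random.normal((2, 100))
d_loss, g_loss, fake = train_step(real_xy, g_inputs)
```

Both losses are scalars and fake has the generator's output shape, matching the three documented return values.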