gan
Collection of GANs trainers.
Classes
AdversarialTrainer | Primitive Trainer for GANs subclassed from ashpy.trainers.BaseTrainer.
EncoderTrainer | Primitive Trainer for GANs using an Encoder sub-network.
-
class ashpy.trainers.gan.AdversarialTrainer(generator, discriminator, generator_optimizer, discriminator_optimizer, generator_loss, discriminator_loss, epochs, metrics=None, logdir='/home/docs/checkouts/readthedocs.org/user_builds/ashpy/checkouts/v0.1.3/docs/source/log', post_process_callback=None, log_eval_mode=<LogEvalMode.TEST: 0>, global_step=<tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>)
Bases: ashpy.trainers.base_trainer.BaseTrainer
Primitive Trainer for GANs subclassed from ashpy.trainers.BaseTrainer.
Examples
    import shutil
    import operator

    import tensorflow as tf

    from ashpy.models.gans import ConvGenerator, ConvDiscriminator
    from ashpy.metrics import InceptionScore
    from ashpy.losses.gan import DiscriminatorMinMax, GeneratorBCE
    from ashpy.trainers.gan import AdversarialTrainer

    generator = ConvGenerator(
        layer_spec_input_res=(7, 7),
        layer_spec_target_res=(28, 28),
        kernel_size=(5, 5),
        initial_filters=32,
        filters_cap=16,
        channels=1,
    )

    discriminator = ConvDiscriminator(
        layer_spec_input_res=(28, 28),
        layer_spec_target_res=(7, 7),
        kernel_size=(5, 5),
        initial_filters=16,
        filters_cap=32,
        output_shape=1,
    )

    # Losses
    generator_bce = GeneratorBCE()
    minmax = DiscriminatorMinMax()

    # Real data (zero tensors standing in for MNIST)
    batch_size = 2
    mnist_x, mnist_y = tf.zeros((100, 28, 28)), tf.zeros((100,))

    # Trainer
    epochs = 2
    logdir = "testlog/adversarial"
    metrics = [
        InceptionScore(
            # Fake inception model
            ConvDiscriminator(
                layer_spec_input_res=(299, 299),
                layer_spec_target_res=(7, 7),
                kernel_size=(5, 5),
                initial_filters=16,
                filters_cap=32,
                output_shape=10,
            ),
            model_selection_operator=operator.gt,
            logdir=logdir,
        )
    ]
    trainer = AdversarialTrainer(
        generator,
        discriminator,
        tf.optimizers.Adam(1e-4),
        tf.optimizers.Adam(1e-4),
        generator_bce,
        minmax,
        epochs,
        metrics,
        logdir,
    )

    # take only 2 samples to speed up tests
    real_data = (
        tf.data.Dataset.from_tensor_slices(
            (tf.expand_dims(mnist_x, -1), tf.expand_dims(mnist_y, -1))
        )
        .take(batch_size)
        .batch(batch_size)
        .prefetch(1)
    )

    # Add noise in the same dataset, just by mapping.
    # The return type of the dataset must be: tuple(tuple(a,b), noise)
    dataset = real_data.map(
        lambda x, y: ((x, y), tf.random.normal(shape=(batch_size, 100)))
    )

    trainer(dataset)
    shutil.rmtree(logdir)
    Initializing checkpoint.
    [1] Saved checkpoint: testlog/adversarial/ckpts/ckpt-1
    Epoch 1 completed.
    [2] Saved checkpoint: testlog/adversarial/ckpts/ckpt-2
    Epoch 2 completed.
-
__init__(generator, discriminator, generator_optimizer, discriminator_optimizer, generator_loss, discriminator_loss, epochs, metrics=None, logdir='/home/docs/checkouts/readthedocs.org/user_builds/ashpy/checkouts/v0.1.3/docs/source/log', post_process_callback=None, log_eval_mode=<LogEvalMode.TEST: 0>, global_step=<tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>)
Instantiate an AdversarialTrainer.
- Parameters
generator (tf.keras.Model) – A tf.keras.Model describing the Generator part of a GAN.
discriminator (tf.keras.Model) – A tf.keras.Model describing the Discriminator part of a GAN.
generator_optimizer (tf.optimizers.Optimizer) – A tf.optimizers optimizer to use for the Generator.
discriminator_optimizer (tf.optimizers.Optimizer) – A tf.optimizers optimizer to use for the Discriminator.
generator_loss (ashpy.losses.executor.Executor) – An ash Executor to compute the loss of the Generator.
discriminator_loss (ashpy.losses.executor.Executor) – An ash Executor to compute the loss of the Discriminator.
epochs (int) – Number of training epochs.
metrics (List) – List of tf.metrics to measure on training and validation data.
logdir – Checkpoint and log directory.
post_process_callback (callable) – The function to post-process the model output, if needed.
log_eval_mode – Models' mode to use when evaluating and logging.
global_step – tf.Variable that keeps track of the training steps.
- Returns
None
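As an illustration of the post_process_callback parameter, the following is a minimal sketch of a callable that could be passed to the trainer. It assumes, hypothetically, a generator whose last activation is tanh, so that its output lies in [-1, 1]; the name to_unit_range is illustrative and not part of ashpy. Because it uses only arithmetic operators, the same function applies to plain floats and to tensors.

```python
def to_unit_range(output):
    """Map values from [-1, 1] to [0, 1].

    Assumption: the generator ends with tanh, so its output is in [-1, 1].
    """
    return (output + 1.0) / 2.0


# Plain floats are used here only to show the mapping.
print(to_unit_range(-1.0))  # 0.0
print(to_unit_range(0.0))   # 0.5
print(to_unit_range(1.0))   # 1.0
```

Under that assumption it would be supplied as post_process_callback=to_unit_range when constructing the trainer.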
-
_measure_performance(dataset)
Measure performance on dataset.
- Parameters
dataset (tf.data.Dataset) – The dataset on which to measure performance.
-
call(dataset)
Perform the adversarial training.
- Parameters
dataset (tf.data.Dataset) – The adversarial training dataset.
-
train_step(real_xy, g_inputs)
Train step for the AdversarialTrainer.
- Parameters
real_xy – Input batch as extracted from the input dataset: a (features, label) pair.
g_inputs – Batch of generator input as generated from the input dataset.
- Returns
d_loss, g_loss, fake – Discriminator and generator loss values; fake is the generator output.
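To make the two returned loss values concrete, here is a pure-Python sketch of the textbook min-max (binary cross-entropy) GAN losses computed from discriminator logits. It illustrates the general formulation that the names DiscriminatorMinMax and GeneratorBCE suggest. It is not ashpy's implementation, and details such as scaling or reduction may differ. All helper names are hypothetical.

```python
import math


def bce_with_logits(logit, target):
    """Numerically stable binary cross-entropy for a single logit."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def mean(values):
    """Arithmetic mean of a non-empty list of floats."""
    return sum(values) / len(values)


def discriminator_loss(real_logits, fake_logits):
    """Real samples are labeled 1 and generated samples are labeled 0."""
    real_term = mean([bce_with_logits(z, 1.0) for z in real_logits])
    fake_term = mean([bce_with_logits(z, 0.0) for z in fake_logits])
    return real_term + fake_term


def generator_loss(fake_logits):
    """Non-saturating loss: the generator wants its samples labeled 1."""
    return mean([bce_with_logits(z, 1.0) for z in fake_logits])


# An undecided discriminator (logit 0, probability 0.5) yields log(2) per term.
d_loss = discriminator_loss([0.0, 0.0], [0.0, 0.0])
g_loss = generator_loss([0.0, 0.0])
print(d_loss, g_loss)
```

In this formulation the generator loss falls as the discriminator assigns higher logits to generated samples, which is the adversarial signal a train step alternates on.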
-
-
class ashpy.trainers.gan.EncoderTrainer(generator, discriminator, encoder, generator_optimizer, discriminator_optimizer, encoder_optimizer, generator_loss, discriminator_loss, encoder_loss, epochs, metrics, logdir='/home/docs/checkouts/readthedocs.org/user_builds/ashpy/checkouts/v0.1.3/docs/source/log', post_process_callback=None, log_eval_mode=<LogEvalMode.TEST: 0>, global_step=<tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>)
Bases: ashpy.trainers.gan.AdversarialTrainer
Primitive Trainer for GANs using an Encoder sub-network. The implementation is designed to be used with the BCE losses. To use another loss function, consider subclassing the trainer and overriding the train_step method.
Examples
    import shutil
    import operator

    import tensorflow as tf

    from ashpy.metrics import EncodingAccuracy
    from ashpy.losses.gan import DiscriminatorMinMax, GeneratorBCE, EncoderBCE
    from ashpy.trainers.gan import EncoderTrainer

    def real_gen():
        label = 0
        for _ in tf.range(100):
            yield ((10.0,), (label,))

    latent_dim = 100

    generator = tf.keras.Sequential([tf.keras.layers.Dense(1)])

    # Two-input discriminator: a data sample and a latent vector
    left_input = tf.keras.layers.Input(shape=(1,))
    left = tf.keras.layers.Dense(10, activation=tf.nn.elu)(left_input)

    right_input = tf.keras.layers.Input(shape=(latent_dim,))
    right = tf.keras.layers.Dense(10, activation=tf.nn.elu)(right_input)

    net = tf.keras.layers.Concatenate()([left, right])
    out = tf.keras.layers.Dense(1)(net)

    discriminator = tf.keras.Model(inputs=[left_input, right_input], outputs=[out])

    encoder = tf.keras.Sequential([tf.keras.layers.Dense(latent_dim)])

    # Losses
    generator_bce = GeneratorBCE()
    encoder_bce = EncoderBCE()
    minmax = DiscriminatorMinMax()

    epochs = 2

    # Fake pre-trained classifier
    num_classes = 1
    classifier = tf.keras.Sequential(
        [tf.keras.layers.Dense(10), tf.keras.layers.Dense(num_classes)]
    )

    logdir = "testlog/adversarial/encoder"

    metrics = [
        EncodingAccuracy(
            classifier,
            # model_selection_operator=operator.gt,
            logdir=logdir
        )
    ]

    trainer = EncoderTrainer(
        generator=generator,
        discriminator=discriminator,
        encoder=encoder,
        discriminator_optimizer=tf.optimizers.Adam(1e-4),
        generator_optimizer=tf.optimizers.Adam(1e-5),
        encoder_optimizer=tf.optimizers.Adam(1e-6),
        generator_loss=generator_bce,
        discriminator_loss=minmax,
        encoder_loss=encoder_bce,
        epochs=epochs,
        metrics=metrics,
        logdir=logdir,
    )

    batch_size = 10
    discriminator_input = tf.data.Dataset.from_generator(
        real_gen, (tf.float32, tf.int64), ((1), (1))
    ).batch(batch_size)

    dataset = discriminator_input.map(
        lambda x, y: ((x, y), tf.random.normal(shape=(batch_size, latent_dim)))
    )

    trainer(dataset)
    shutil.rmtree(logdir)
    Initializing checkpoint.
    [10] Saved checkpoint: testlog/adversarial/encoder/ckpts/ckpt-1
    Epoch 1 completed.
    [20] Saved checkpoint: testlog/adversarial/encoder/ckpts/ckpt-2
    Epoch 2 completed.
-
__init__(generator, discriminator, encoder, generator_optimizer, discriminator_optimizer, encoder_optimizer, generator_loss, discriminator_loss, encoder_loss, epochs, metrics, logdir='/home/docs/checkouts/readthedocs.org/user_builds/ashpy/checkouts/v0.1.3/docs/source/log', post_process_callback=None, log_eval_mode=<LogEvalMode.TEST: 0>, global_step=<tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>)
Instantiate an EncoderTrainer.
- Parameters
generator (tf.keras.Model) – A tf.keras.Model describing the Generator part of a GAN.
discriminator (tf.keras.Model) – A tf.keras.Model describing the Discriminator part of a GAN.
encoder (tf.keras.Model) – A tf.keras.Model describing the Encoder part of a GAN.
generator_optimizer (tf.optimizers.Optimizer) – A tf.optimizers optimizer to use for the Generator.
discriminator_optimizer (tf.optimizers.Optimizer) – A tf.optimizers optimizer to use for the Discriminator.
encoder_optimizer (tf.optimizers.Optimizer) – A tf.optimizers optimizer to use for the Encoder.
generator_loss (ashpy.losses.executor.Executor) – An ash Executor to compute the loss of the Generator.
discriminator_loss (ashpy.losses.executor.Executor) – An ash Executor to compute the loss of the Discriminator.
encoder_loss (ashpy.losses.executor.Executor) – An ash Executor to compute the loss of the Encoder.
epochs (int) – Number of training epochs.
metrics (List) – List of tf.metrics to measure on training and validation data.
logdir – Checkpoint and log directory.
post_process_callback (callable) – A function to post-process the output.
log_eval_mode – Models' mode to use when evaluating and logging.
global_step – tf.Variable that keeps track of the training steps.
-
_measure_performance(dataset)
Measure performance on dataset.
- Parameters
dataset (tf.data.Dataset) – The dataset on which to measure performance.
-
call(dataset)
Perform the adversarial training.
- Parameters
dataset (tf.data.Dataset) – The adversarial training dataset.
-
train_step(real_xy, g_inputs)
Adversarial training step.
- Parameters
real_xy – Input batch as extracted from the discriminator input dataset: a (features, label) pair.
g_inputs – Batch of noise as generated by the generator input dataset.
- Returns
d_loss, g_loss, e_loss – Discriminator, generator, and encoder loss values.
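The two-input discriminator in the EncoderTrainer example suggests a BiGAN-style setup, in which the discriminator scores (sample, latent) pairs rather than samples alone. The sketch below shows how such pairs could be formed from one dataset element. It is an assumption about the general technique, not a copy of ashpy's train_step, and the toy generator and encoder are hypothetical stand-ins for the Keras models.

```python
def make_discriminator_pairs(real_x, noise, generator, encoder):
    """Build the (sample, latent) pairs a BiGAN-style discriminator compares."""
    fake_x = generator(noise)      # G(z): generated sample
    encoded = encoder(real_x)      # E(x): latent code inferred from real data
    real_pair = (real_x, encoded)  # the discriminator should score this as real
    fake_pair = (fake_x, noise)    # the discriminator should score this as fake
    return real_pair, fake_pair


# Toy stand-ins operating on plain floats, for illustration only.
def toy_generator(z):
    return 2.0 * z


def toy_encoder(x):
    return x / 2.0


real_pair, fake_pair = make_discriminator_pairs(
    real_x=10.0, noise=3.0, generator=toy_generator, encoder=toy_encoder
)
print(real_pair)  # (10.0, 5.0)
print(fake_pair)  # (6.0, 3.0)
```

Under this reading, the three returned losses correspond to the three sub-networks that are updated from the same pair of discriminator scores.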
-