EncoderTrainer

Inheritance Diagram

Inheritance diagram of ashpy.trainers.gan.EncoderTrainer

class ashpy.trainers.gan.EncoderTrainer(generator, discriminator, encoder, generator_optimizer, discriminator_optimizer, encoder_optimizer, generator_loss, discriminator_loss, encoder_loss, epochs, metrics=None, callbacks=None, logdir='/home/docs/checkouts/readthedocs.org/user_builds/ashpy/checkouts/master/docs/source/log', log_eval_mode=<LogEvalMode.TEST: 1>, global_step=None)

Bases: ashpy.trainers.gan.AdversarialTrainer

Primitive Trainer for GANs using an Encoder sub-network.

The implementation is designed to be used with BCE losses. To use a different loss function, consider subclassing the trainer and overriding the train_step method.
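As a rough illustration of what "BCE losses" means here, the sketch below assumes a BiGAN-style objective: the discriminator outputs a logit for real pairs (x, E(x)) and for fake pairs (G(z), z), and is trained to label the first 1 and the second 0, while the generator and the encoder each try to flip the label of their own pairs. The helper names (`bce_with_logits`, `mean_bce`, `encoder_gan_losses`) are hypothetical, and the exact reductions used by `losses.gan.GeneratorBCE`, `losses.gan.EncoderBCE` and `losses.gan.DiscriminatorMinMax` may differ from this plain-Python version.

```python
import math
from typing import List, Tuple


def bce_with_logits(label: float, logit: float) -> float:
    """Numerically stable binary cross-entropy for a single logit."""
    # max(l, 0) - l * y + log(1 + exp(-|l|)) == -[y*log(s(l)) + (1 - y)*log(1 - s(l))]
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def mean_bce(label: float, logits: List[float]) -> float:
    """Average the per-sample BCE over a batch of logits sharing one target label."""
    return sum(bce_with_logits(label, l) for l in logits) / len(logits)


def encoder_gan_losses(
    d_real: List[float], d_fake: List[float]
) -> Tuple[float, float, float]:
    """Hypothetical BiGAN-style losses (assumption, not the ashpy source).

    d_real: discriminator logits on real pairs (x, E(x)).
    d_fake: discriminator logits on fake pairs (G(z), z).
    """
    # Discriminator: real pairs -> 1, fake pairs -> 0.
    d_loss = mean_bce(1.0, d_real) + mean_bce(0.0, d_fake)
    # Generator: wants its fake pairs to be classified as real.
    g_loss = mean_bce(1.0, d_fake)
    # Encoder: wants its real pairs to be classified as fake.
    e_loss = mean_bce(0.0, d_real)
    return d_loss, g_loss, e_loss


print(encoder_gan_losses([2.0, 1.5], [-1.0, 0.5]))
```

The stable logit form is used instead of applying a sigmoid first, which avoids `log(0)` when the discriminator is very confident.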

Examples

import operator
import os
import shutil

import tensorflow as tf
from ashpy import losses, metrics, trainers


def real_gen():
    # Toy "real" data: 100 samples, all equal to 10.0, each with label 0.
    label = 0
    for _ in range(100):
        yield ((10.0,), (label,))


latent_dim = 100

generator = tf.keras.Sequential([tf.keras.layers.Dense(1)])

# The discriminator scores (sample, latent code) pairs: the left branch
# takes a sample, the right branch takes a latent vector.
left_input = tf.keras.layers.Input(shape=(1,))
left = tf.keras.layers.Dense(10, activation=tf.nn.elu)(left_input)

right_input = tf.keras.layers.Input(shape=(latent_dim,))
right = tf.keras.layers.Dense(10, activation=tf.nn.elu)(right_input)

net = tf.keras.layers.Concatenate()([left, right])
out = tf.keras.layers.Dense(1)(net)

discriminator = tf.keras.Model(inputs=[left_input, right_input], outputs=[out])

# The encoder maps a sample back to the latent space.
encoder = tf.keras.Sequential([tf.keras.layers.Dense(latent_dim)])

# Losses
generator_bce = losses.gan.GeneratorBCE()
encoder_bce = losses.gan.EncoderBCE()
minmax = losses.gan.DiscriminatorMinMax()

epochs = 2

# Fake pre-trained classifier
num_classes = 1
classifier = tf.keras.Sequential(
    [tf.keras.layers.Dense(10), tf.keras.layers.Dense(num_classes)]
)

logdir = "testlog/adversarial_encoder"

if os.path.exists(logdir):
    shutil.rmtree(logdir)

# Named trainer_metrics so that the ashpy `metrics` module is not shadowed.
trainer_metrics = [
    metrics.gan.EncodingAccuracy(
        classifier,
        # model_selection_operator=operator.gt,
        logdir=logdir,
    )
]

trainer = trainers.gan.EncoderTrainer(
    generator=generator,
    discriminator=discriminator,
    encoder=encoder,
    discriminator_optimizer=tf.optimizers.Adam(1e-4),
    generator_optimizer=tf.optimizers.Adam(1e-5),
    encoder_optimizer=tf.optimizers.Adam(1e-6),
    generator_loss=generator_bce,
    discriminator_loss=minmax,
    encoder_loss=encoder_bce,
    epochs=epochs,
    metrics=trainer_metrics,
    logdir=logdir,
)

batch_size = 10
discriminator_input = tf.data.Dataset.from_generator(
    real_gen, (tf.float32, tf.int64), ((1,), (1,))
).batch(batch_size)

# Each element is ((x, y), z): a real (features, label) batch and a noise batch.
dataset = discriminator_input.map(
    lambda x, y: ((x, y), tf.random.normal(shape=(batch_size, latent_dim)))
)

trainer(dataset)

shutil.rmtree(logdir)

Output:
Initializing checkpoint.
Starting epoch 1.
[10] Saved checkpoint: testlog/adversarial_encoder/ckpts/ckpt-1
Epoch 1 completed.
Starting epoch 2.
[20] Saved checkpoint: testlog/adversarial_encoder/ckpts/ckpt-2
Epoch 2 completed.
Training finished after 2 epochs.
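The step numbers in the checkpoint lines follow from the dataset size: `real_gen` yields 100 samples and the dataset is batched with `batch_size = 10`, so each epoch runs 10 training steps. The arithmetic below reproduces the `[10]` and `[20]` markers, assuming (as the log suggests) that one checkpoint is saved at the end of each epoch and that the global step counts batches. The function `checkpoint_steps` is hypothetical.

```python
import math
from typing import List


def checkpoint_steps(num_samples: int, batch_size: int, epochs: int) -> List[int]:
    """Global step at the end of each epoch, counting one step per batch."""
    # A final partial batch still counts as a step (Dataset.batch keeps it by default).
    steps_per_epoch = math.ceil(num_samples / batch_size)
    return [steps_per_epoch * (epoch + 1) for epoch in range(epochs)]


print(checkpoint_steps(num_samples=100, batch_size=10, epochs=2))  # [10, 20]
```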

Methods

  • __init__(generator, discriminator, encoder, …) – Instantiate an EncoderTrainer.

  • call(dataset[, log_freq, …]) – Perform the adversarial training.

  • train_step(real_xy, g_inputs) – Adversarial training step.

Attributes

  • context – Return the training context.

__init__(generator, discriminator, encoder, generator_optimizer, discriminator_optimizer, encoder_optimizer, generator_loss, discriminator_loss, encoder_loss, epochs, metrics=None, callbacks=None, logdir='/home/docs/checkouts/readthedocs.org/user_builds/ashpy/checkouts/master/docs/source/log', log_eval_mode=<LogEvalMode.TEST: 1>, global_step=None)

Instantiate an EncoderTrainer.

Parameters
_train_step

Perform the training step using the distribution strategy.

call(dataset, log_freq=10, measure_performance_freq=10)

Perform the adversarial training.

Parameters
  • dataset (tf.data.Dataset) – The adversarial training dataset.

  • log_freq (int) – Specifies how many steps to run before logging the losses, e.g. log_freq=10 logs every 10 steps of training. Pass log_freq<=0 if you don’t want to log.

  • measure_performance_freq (int) – Specifies how many steps to run before measuring the performance, e.g. measure_performance_freq=10 measures performance every 10 steps of training. Pass measure_performance_freq<=0 if you don’t want to measure performance.
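Both frequencies describe the same kind of step-based schedule. A minimal sketch of that rule, assuming an event fires whenever the 1-based step is a multiple of the frequency and never fires when the frequency is `<= 0` (the helper `fires_at` is hypothetical and not part of ashpy):

```python
from typing import List


def fires_at(freq: int, total_steps: int) -> List[int]:
    """Steps (1-based) at which a periodic event such as logging would run."""
    # freq <= 0 disables the event entirely, mirroring the documented behaviour.
    if freq <= 0:
        return []
    return [step for step in range(1, total_steps + 1) if step % freq == 0]


print(fires_at(10, 35))  # [10, 20, 30]
print(fires_at(0, 35))   # []
```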

train_step(real_xy, g_inputs)

Adversarial training step.

Parameters
  • real_xy – Input batch extracted from the discriminator input dataset, as a (features, label) pair.

  • g_inputs – Batch of noise produced by the generator input dataset.

Returns

d_loss, g_loss, e_loss – the discriminator, generator, and encoder loss values.
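To make the data flow of train_step concrete, the sketch below shows how one dataset element `((x, y), z)` is split into `real_xy` and `g_inputs`, and which pairs the discriminator scores. It assumes a BiGAN-style pairing (real pairs `(x, E(x))`, fake pairs `(G(z), z)`) and uses plain Python callables on scalars in place of Keras models; `toy_train_step` and the toy networks are hypothetical, run a forward pass only, and apply no gradient updates.

```python
import math
from typing import Callable, Tuple


def bce_with_logits(label: float, logit: float) -> float:
    # Numerically stable binary cross-entropy on a single logit.
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def toy_train_step(
    real_xy: Tuple[float, int],
    g_inputs: float,
    generator: Callable[[float], float],
    encoder: Callable[[float], float],
    discriminator: Callable[[float, float], float],
) -> Tuple[float, float, float]:
    """Forward pass only: returns (d_loss, g_loss, e_loss)."""
    x, _label = real_xy  # the label is ignored in this sketch
    d_real = discriminator(x, encoder(x))                  # real pair (x, E(x))
    d_fake = discriminator(generator(g_inputs), g_inputs)  # fake pair (G(z), z)
    d_loss = bce_with_logits(1.0, d_real) + bce_with_logits(0.0, d_fake)
    g_loss = bce_with_logits(1.0, d_fake)
    e_loss = bce_with_logits(0.0, d_real)
    return d_loss, g_loss, e_loss


element = ((10.0, 0), 0.5)  # ((x, y), z), mirroring the dataset in the example
real_xy, g_inputs = element
step_losses = toy_train_step(
    real_xy,
    g_inputs,
    generator=lambda z: 2.0 * z,
    encoder=lambda x: 0.1 * x,
    discriminator=lambda x, e: x - e,
)
print(step_losses)
```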