GANEncoderContext

Inheritance Diagram

Inheritance diagram of ashpy.contexts.gan.GANEncoderContext

class ashpy.contexts.gan.GANEncoderContext(dataset=None, generator_model=None, discriminator_model=None, encoder_model=None, generator_loss=None, discriminator_loss=None, encoder_loss=None, metrics=None, log_eval_mode=<LogEvalMode.TRAIN: 2>, global_step=<tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>, checkpoint=None)

Bases: ashpy.contexts.gan.GANContext

ashpy.contexts.gan.GANEncoderContext measures the specified metrics on the GAN.

Methods

__init__([dataset, generator_model, …]) Initialize the Context.

Attributes

current_batch Return the current batch.
dataset Retrieve the dataset.
discriminator_loss Retrieve the discriminator loss.
discriminator_model Retrieve the discriminator model.
encoder_inputs Retrieve the inputs of the encoder.
encoder_loss Retrieve the encoder loss.
encoder_model Retrieve the encoder model.
exception Return the exception.
fake_samples Retrieve the fake samples.
generator_inputs Retrieve the generator inputs.
generator_loss Retrieve the generator loss.
generator_model Retrieve the generator model.
generator_of_encoder Retrieve the images generated from the encoder output.
global_step Retrieve the global_step.
log_eval_mode Retrieve model(s) mode.
metrics Retrieve the metrics.
__init__(dataset=None, generator_model=None, discriminator_model=None, encoder_model=None, generator_loss=None, discriminator_loss=None, encoder_loss=None, metrics=None, log_eval_mode=<LogEvalMode.TRAIN: 2>, global_step=<tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>, checkpoint=None)

Initialize the Context.

Parameters:
  • dataset (tf.data.Dataset) – Dataset of tuples: element [0] comes from the true dataset, element [1] from the generator input dataset.
  • generator_model (tf.keras.Model) – The generator.
  • discriminator_model (tf.keras.Model) – The discriminator.
  • encoder_model (tf.keras.Model) – The encoder.
  • generator_loss (ashpy.losses.Executor()) – The generator loss.
  • discriminator_loss (ashpy.losses.Executor()) – The discriminator loss.
  • encoder_loss (ashpy.losses.Executor()) – The encoder loss.
  • metrics (list of [ashpy.metrics.metric.Metric]) – All the metrics to be used to evaluate the model.
  • log_eval_mode (ashpy.modes.LogEvalMode) – Models’ mode to use when evaluating and logging.
  • global_step (tf.Variable) – tf.Variable that keeps track of the training steps.
  • checkpoint (tf.train.Checkpoint) – Checkpoint used to keep track of the models' status.
Return type: None
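
The parameter layout above can be illustrated without TensorFlow. The following is a minimal sketch, not the ashpy implementation: `ToyEncoderContext` is a hypothetical stand-in that only mirrors a few of the documented parameter names, and plain Python lists and callables replace `tf.data.Dataset` and `tf.keras.Model`. It shows the tuple layout of the dataset ([0] true sample, [1] generator input) and the read-only accessor pattern.

```python
import random
from typing import Callable, List, Optional, Tuple

Sample = List[float]


class ToyEncoderContext:
    """Hypothetical stand-in, NOT ashpy.contexts.gan.GANEncoderContext.

    It only mirrors some documented parameter names so the data layout
    can be shown without TensorFlow.
    """

    def __init__(
        self,
        dataset: Optional[List[Tuple[Sample, Sample]]] = None,
        generator_model: Optional[Callable[[Sample], Sample]] = None,
        encoder_model: Optional[Callable[[Sample], Sample]] = None,
        metrics: Optional[list] = None,
    ) -> None:
        self._dataset = dataset
        self._generator_model = generator_model
        self._encoder_model = encoder_model
        # Assumption for this sketch: a missing metrics list becomes empty.
        self._metrics = metrics if metrics is not None else []

    @property
    def dataset(self):
        """Retrieve the dataset."""
        return self._dataset

    @property
    def encoder_model(self):
        """Retrieve the encoder model."""
        return self._encoder_model

    @property
    def metrics(self):
        """Retrieve the metrics."""
        return self._metrics


rng = random.Random(0)
# Element [0]: samples standing in for the true dataset.
real_samples = [[float(i), float(i + 1)] for i in range(4)]
# Element [1]: noise vectors standing in for the generator input dataset.
noise = [[rng.gauss(0.0, 1.0) for _ in range(3)] for _ in range(4)]
toy_dataset = list(zip(real_samples, noise))

context = ToyEncoderContext(
    dataset=toy_dataset,
    encoder_model=lambda x: [sum(x)],  # toy encoder, an assumption
)
real, z = context.dataset[0]
```

With the real library, the same keyword names would be passed to `GANEncoderContext`, with `tf.data.Dataset`, `tf.keras.Model`, and executor objects in place of the toy values.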

encoder_inputs

Retrieve the inputs of the encoder.

Return type: Tensor

encoder_loss

Retrieve the encoder loss.

Return type: Optional[Executor]

encoder_model

Retrieve the encoder model.

Return type: Model
Returns: tf.keras.Model.

generator_of_encoder

Retrieve the images generated from the encoder output.

Return type: Tensor
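
As a rough illustration of "images generated from the encoder output", the sketch below composes a toy encoder with a toy generator. It assumes the attribute corresponds to applying the generator to the encoder's output for a batch of real inputs; the functions `toy_encoder`, `toy_generator`, and `generator_of_encoder` are hypothetical and use plain Python lists instead of tensors.

```python
from typing import List


def toy_encoder(image: List[float]) -> List[float]:
    """Hypothetical encoder: compress an image to a 1-D latent (its mean)."""
    return [sum(image) / len(image)]


def toy_generator(latent: List[float], size: int) -> List[float]:
    """Hypothetical generator: expand the latent back to `size` values."""
    return [latent[0]] * size


def generator_of_encoder(image: List[float]) -> List[float]:
    """Apply the generator to the encoder output for one input image."""
    return toy_generator(toy_encoder(image), size=len(image))


encoder_inputs = [1.0, 2.0, 3.0]
reconstruction = generator_of_encoder(encoder_inputs)
```

Here `reconstruction` has the same length as `encoder_inputs`, which mirrors how the encoder inputs and the generated images share a shape in an encoder-equipped GAN.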