GANContext

Inheritance Diagram

Inheritance diagram of ashpy.contexts.gan.GANContext

class ashpy.contexts.gan.GANContext(dataset=None, generator_model=None, discriminator_model=None, generator_loss=None, discriminator_loss=None, metrics=None, log_eval_mode=<LogEvalMode.TRAIN: 2>, global_step=<tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>, checkpoint=None)

Bases: ashpy.contexts.context.Context

ashpy.contexts.gan.GANContext measures the specified metrics on the GAN.
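To make "measures the specified metrics on the GAN" concrete, here is a library-agnostic sketch. It assumes, loosely following ashpy's design, that a context simply bundles the dataset, the models and the metrics, and that each metric reads whatever it needs from the context. `ToyGANContext`, `MeanDiscriminatorScoreOnFakes`, `measure_metrics` and the plain-callable "models" are hypothetical illustrations, not part of ashpy's API.

```python
from dataclasses import dataclass, field
from typing import Callable, List, Sequence, Tuple

# Hypothetical "models": plain callables standing in for tf.keras.Model objects.
ToyModel = Callable[[float], float]


class MeanDiscriminatorScoreOnFakes:
    """Hypothetical metric: mean discriminator output on generated samples."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def update_state(self, context: "ToyGANContext") -> None:
        # The metric pulls the dataset and both models from the context.
        for _real, generator_input in context.dataset:
            fake = context.generator_model(generator_input)
            self.total += context.discriminator_model(fake)
            self.count += 1

    def result(self) -> float:
        return self.total / self.count if self.count else 0.0


@dataclass
class ToyGANContext:
    """Hypothetical, simplified analogue of GANContext (not the ashpy class)."""

    dataset: Sequence[Tuple[float, float]]
    generator_model: ToyModel
    discriminator_model: ToyModel
    metrics: List[MeanDiscriminatorScoreOnFakes] = field(default_factory=list)

    def measure_metrics(self) -> None:
        # Every metric receives the whole context.
        for metric in self.metrics:
            metric.update_state(self)


ctx = ToyGANContext(
    dataset=[(1.0, 0.0), (2.0, 1.0)],  # (real sample, generator input) pairs
    generator_model=lambda z: 2 * z,
    discriminator_model=lambda x: x / 10,
    metrics=[MeanDiscriminatorScoreOnFakes()],
)
ctx.measure_metrics()
print(ctx.metrics[0].result())  # 0.1
```

Passing the whole context to each metric, rather than individual tensors, lets different metrics pull different pieces (models, dataset, losses) through one shared signature.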

Methods

__init__([dataset, generator_model, …]) Initialize the Context.

Attributes

current_batch Return the current batch.
dataset Retrieve the dataset.
discriminator_loss Retrieve the discriminator loss.
discriminator_model Retrieve the discriminator model.
exception Return the exception.
fake_samples Retrieve the fake samples, i.e. the output of the generator.
generator_inputs Retrieve the generator inputs.
generator_loss Retrieve the generator loss.
generator_model Retrieve the generator model.
global_step Retrieve the global_step.
log_eval_mode Retrieve the models' mode.
metrics Retrieve the metrics.
__init__(dataset=None, generator_model=None, discriminator_model=None, generator_loss=None, discriminator_loss=None, metrics=None, log_eval_mode=<LogEvalMode.TRAIN: 2>, global_step=<tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>, checkpoint=None)

Initialize the Context.

Parameters:
  • dataset (tf.data.Dataset) – Dataset of tuples: element [0] is the true (real) data, element [1] is the generator input.
  • generator_model (tf.keras.Model) – The generator.
  • discriminator_model (tf.keras.Model) – The discriminator.
  • generator_loss (ashpy.losses.Executor) – The generator loss.
  • discriminator_loss (ashpy.losses.Executor) – The discriminator loss.
  • metrics (list of [ashpy.metrics.metric.Metric]) – All the metrics to be used to evaluate the model.
  • log_eval_mode (ashpy.modes.LogEvalMode) – Models’ mode to use when evaluating and logging.
  • global_step (tf.Variable) – tf.Variable that keeps track of the training steps.
  • checkpoint (tf.train.Checkpoint) – Checkpoint used to keep track of the models' status.
Return type:

None
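The `dataset` parameter expects each element to be a tuple whose first item is real data and whose second item is the generator input. The dependency-free sketch below builds such pairs from plain Python lists; `make_gan_dataset` is a hypothetical helper, and Gaussian noise as generator input is an assumption. With TensorFlow, an equivalent tuple structure can usually be obtained with `tf.data.Dataset.zip((real_dataset, noise_dataset))`, though the exact shapes ashpy expects are not specified on this page.

```python
import random
from typing import List, Tuple


def make_gan_dataset(
    real_samples: List[float], latent_dim: int, seed: int = 0
) -> List[Tuple[float, List[float]]]:
    """Hypothetical helper: pair each real sample with one latent noise vector."""
    rng = random.Random(seed)
    return [
        # Element [0]: true data. Element [1]: generator input (Gaussian noise).
        (x, [rng.gauss(0.0, 1.0) for _ in range(latent_dim)])
        for x in real_samples
    ]


dataset = make_gan_dataset([0.5, -1.2, 3.0], latent_dim=4)
for real, generator_input in dataset:
    # A GAN step would score `real` with the discriminator and feed
    # `generator_input` to the generator to produce a fake sample.
    print(real, len(generator_input))
```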

discriminator_loss

Retrieve the discriminator loss.

Return type: Optional[Executor]
discriminator_model

Retrieve the discriminator model.

Return type: Model
Returns: tf.keras.Model.
fake_samples

Retrieve the fake samples, i.e. output of the generator.

Return type: Optional[Tensor]
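The `fake_samples` and `generator_inputs` properties are typed `Optional[Tensor]`, which suggests they may be `None` until a batch has been processed. The sketch below illustrates that relationship with plain Python lists: the fake samples are the generator's output on the generator inputs. `ToySampleHolder` and its `process` method are hypothetical; when and how ashpy actually populates these attributes is an assumption, not something this page documents.

```python
from typing import Callable, List, Optional


class ToySampleHolder:
    """Hypothetical sketch: fake samples are the generator applied to its inputs."""

    def __init__(self, generator_model: Callable[[List[float]], List[float]]) -> None:
        self.generator_model = generator_model
        self._generator_inputs: Optional[List[float]] = None
        self._fake_samples: Optional[List[float]] = None

    @property
    def generator_inputs(self) -> Optional[List[float]]:
        return self._generator_inputs

    @property
    def fake_samples(self) -> Optional[List[float]]:
        # None until a batch of generator inputs has been processed.
        return self._fake_samples

    def process(self, generator_inputs: List[float]) -> None:
        # Store the inputs and the generator's output for later inspection.
        self._generator_inputs = generator_inputs
        self._fake_samples = self.generator_model(generator_inputs)


holder = ToySampleHolder(lambda z: [2 * v for v in z])
print(holder.fake_samples)  # None
holder.process([0.5, 1.0])
print(holder.fake_samples)  # [1.0, 2.0]
```

Exposing these values as read-only properties keeps callers (for example, metrics) from overwriting the state the training loop relies on.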
generator_inputs

Retrieve the generator inputs.

Return type: Optional[Tensor]
generator_loss

Retrieve the generator loss.

Return type: Optional[Executor]
generator_model

Retrieve the generator model.

Return type: Model
Returns: tf.keras.Model.