GANContext
Inheritance diagram: ashpy.contexts.context.Context → ashpy.contexts.gan.GANContext

class ashpy.contexts.gan.GANContext(dataset=None, generator_model=None, discriminator_model=None, generator_loss=None, discriminator_loss=None, metrics=None, log_eval_mode=<LogEvalMode.TRAIN: 2>, global_step=<tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>, checkpoint=None)

Bases: ashpy.contexts.context.Context

ashpy.contexts.gan.GANContext measures the specified metrics on the GAN.

Methods
__init__([dataset, generator_model, …])   Initialize the Context.

Attributes

current_batch         Return the current batch.
dataset               Retrieve the dataset.
discriminator_loss    Retrieve the discriminator loss.
discriminator_model   Retrieve the discriminator model.
exception             Return the exception.
fake_samples          Retrieve the fake samples, i.e.
generator_inputs      Retrieve the generator inputs.
generator_loss        Retrieve the generator loss.
generator_model       Retrieve the generator model.
global_step           Retrieve the global_step.
log_eval_mode         Retrieve model(s) mode.
metrics               Retrieve the metrics.
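The attributes listed above are accessors over the objects handed to the constructor. The following is a minimal pure-Python sketch of that accessor pattern, not ashpy's implementation: `MiniGANContext` is a hypothetical stand-in, plain callables replace the `tf.keras.Model` instances and loss executors, and defining getters only (no setters) is a choice made for this sketch.

```python
from typing import Any, Callable, List, Optional


class MiniGANContext:
    """Hypothetical stand-in for GANContext (not ashpy's code).

    Stores the GAN components passed at construction time and exposes
    them through getter-only properties.
    """

    def __init__(
        self,
        dataset: Optional[Any] = None,
        generator_model: Optional[Callable] = None,
        discriminator_model: Optional[Callable] = None,
        generator_loss: Optional[Callable] = None,
        discriminator_loss: Optional[Callable] = None,
        metrics: Optional[List[Any]] = None,
    ) -> None:
        self._dataset = dataset
        self._generator_model = generator_model
        self._discriminator_model = discriminator_model
        self._generator_loss = generator_loss
        self._discriminator_loss = discriminator_loss
        self._metrics = metrics

    @property
    def dataset(self) -> Optional[Any]:
        """Retrieve the dataset."""
        return self._dataset

    @property
    def generator_model(self) -> Optional[Callable]:
        """Retrieve the generator model."""
        return self._generator_model

    @property
    def discriminator_model(self) -> Optional[Callable]:
        """Retrieve the discriminator model."""
        return self._discriminator_model

    @property
    def generator_loss(self) -> Optional[Callable]:
        """Retrieve the generator loss."""
        return self._generator_loss

    @property
    def discriminator_loss(self) -> Optional[Callable]:
        """Retrieve the discriminator loss."""
        return self._discriminator_loss

    @property
    def metrics(self) -> Optional[List[Any]]:
        """Retrieve the metrics."""
        return self._metrics
```

Usage: `MiniGANContext(generator_model=lambda z: 2 * z).generator_model(3)` evaluates to `6`. Because this sketch defines no setters, assigning `ctx.generator_model = ...` raises `AttributeError`.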
__init__(dataset=None, generator_model=None, discriminator_model=None, generator_loss=None, discriminator_loss=None, metrics=None, log_eval_mode=<LogEvalMode.TRAIN: 2>, global_step=<tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>, checkpoint=None)

Initialize the Context.
Parameters:
- dataset (tf.data.Dataset) – Dataset of tuples: [0] the true dataset, [1] the generator input dataset.
- generator_model (tf.keras.Model) – The generator.
- discriminator_model (tf.keras.Model) – The discriminator.
- generator_loss (ashpy.losses.Executor) – The generator loss.
- discriminator_loss (ashpy.losses.Executor) – The discriminator loss.
- metrics (list of [ashpy.metrics.metric.Metric]) – All the metrics to be used to evaluate the model.
- log_eval_mode (ashpy.modes.LogEvalMode) – Models’ mode to use when evaluating and logging.
- global_step (tf.Variable) – tf.Variable that keeps track of the training steps.
- checkpoint (tf.train.Checkpoint) – Checkpoint used to keep track of the models’ status.

Return type: None
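The `dataset` parameter pairs each real sample (index 0) with a generator input (index 1). The sketch below mimics that layout in pure Python, with the built-in `zip` standing in for `tf.data.Dataset.zip`; `make_gan_dataset`, `toy_generator`, `split_batch`, and the sample values are all hypothetical. It also shows how the element at index 1 becomes a fake sample once it is passed through the generator, which, in GAN terms, is what the `generator_inputs` and `fake_samples` attributes are presumably about.

```python
import random
from typing import Callable, Iterable, Iterator, List, Tuple

Sample = List[float]


def make_gan_dataset(
    real: Iterable[Sample], noise: Iterable[Sample]
) -> Iterator[Tuple[Sample, Sample]]:
    """Hypothetical helper: element[0] is a real sample, element[1] a generator input."""
    # zip plays the role that tf.data.Dataset.zip would play in TensorFlow.
    return zip(real, noise)


def toy_generator(z: Sample) -> Sample:
    """Hypothetical generator: any callable mapping a noise vector to a sample."""
    return [2.0 * v + 1.0 for v in z]


def split_batch(
    element: Tuple[Sample, Sample], generator: Callable[[Sample], Sample]
) -> Tuple[Sample, Sample, Sample]:
    """Unpack one dataset element into (real, generator_input, fake)."""
    real, generator_input = element
    fake = generator(generator_input)
    return real, generator_input, fake


if __name__ == "__main__":
    rng = random.Random(0)
    real_data = [[1.0, 2.0], [3.0, 4.0]]
    # One Gaussian noise vector per real sample, matching the tuple layout.
    noise_data = [[rng.gauss(0.0, 1.0) for _ in range(2)] for _ in real_data]
    for element in make_gan_dataset(real_data, noise_data):
        real, z, fake = split_batch(element, toy_generator)
        print(real, z, fake)
```

Keeping real data and generator inputs in one zipped dataset means each iteration yields both halves of a GAN training step together, so they cannot drift out of step.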

discriminator_model
Retrieve the discriminator model.
Return type: Model
Returns: tf.keras.Model.

generator_model
Retrieve the generator model.
Return type: Model
Returns: tf.keras.Model.
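Since the context exists so that metrics can be measured on the GAN, a metric typically pulls the models it needs from these properties. As a hedged illustration, the hypothetical `fooling_rate` below computes the fraction of generated samples that the discriminator scores above a threshold. It is not one of ashpy's metrics; `SimpleNamespace` stands in for a real GANContext, and plain functions stand in for the `tf.keras.Model` instances.

```python
from types import SimpleNamespace
from typing import Any, Sequence


def fooling_rate(context: Any, generator_inputs: Sequence[float], threshold: float = 0.5) -> float:
    """Hypothetical metric: fraction of generated samples the discriminator scores as real."""
    # Read the models through the context's properties, as a metric would.
    generator = context.generator_model
    discriminator = context.discriminator_model
    if not generator_inputs:
        raise ValueError("generator_inputs must be non-empty")
    fooled = 0
    for z in generator_inputs:
        fake = generator(z)
        # Assumption: the discriminator returns the probability that its input is real.
        if discriminator(fake) > threshold:
            fooled += 1
    return fooled / len(generator_inputs)


if __name__ == "__main__":
    ctx = SimpleNamespace(
        generator_model=lambda z: z * 10.0,
        discriminator_model=lambda x: 1.0 if x > 5.0 else 0.0,
    )
    print(fooling_rate(ctx, [0.1, 0.4, 0.6, 0.9]))  # prints 0.5
```

Passing the whole context, rather than individual models, lets every metric share one call signature while reading only the components it needs.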