gan
GANContext measures the specified metrics on the GAN.
Classes
class ashpy.contexts.gan.GANContext(dataset=None, generator_model=None, discriminator_model=None, generator_loss=None, discriminator_loss=None, metrics=None, log_eval_mode=<LogEvalMode.TRAIN: 2>, global_step=<tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>, checkpoint=None)
Bases: ashpy.contexts.context.Context
ashpy.contexts.gan.GANContext measures the specified metrics on the GAN.

__init__(dataset=None, generator_model=None, discriminator_model=None, generator_loss=None, discriminator_loss=None, metrics=None, log_eval_mode=<LogEvalMode.TRAIN: 2>, global_step=<tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>, checkpoint=None)
Initialize the Context.
- Parameters
  - dataset (tf.data.Dataset) – Dataset of tuples. [0] true dataset, [1] generator input dataset.
  - generator_model (tf.keras.Model) – The generator.
  - discriminator_model (tf.keras.Model) – The discriminator.
  - generator_loss (ashpy.losses.Executor()) – The generator loss.
  - discriminator_loss (ashpy.losses.Executor()) – The discriminator loss.
  - metrics (list of [ashpy.metrics.metric.Metric]) – All the metrics to be used to evaluate the model.
  - log_eval_mode (ashpy.modes.LogEvalMode) – Models' mode to use when evaluating and logging.
  - global_step (tf.Variable) – tf.Variable that keeps track of the training steps.
  - checkpoint (tf.train.Checkpoint) – Checkpoint used to keep track of the models' status.
- Return type
  None
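The dataset parameter is described as a dataset of tuples, with the true dataset at index 0 and the generator input dataset at index 1. A minimal sketch of building such a dataset with plain tf.data follows. The sample count, image shape, latent size, and the use of random tensors as stand-in data are all assumptions made for illustration; the exact element structure ashpy expects (for example, whether labels are included) is not stated here.

```python
import tensorflow as tf

# Hypothetical sizes, chosen only for illustration.
NUM_SAMPLES, LATENT_DIM, BATCH_SIZE = 8, 4, 2

# [0] the "true" dataset: random 28x28x1 tensors standing in for real images.
real = tf.data.Dataset.from_tensor_slices(
    tf.random.uniform((NUM_SAMPLES, 28, 28, 1))
)

# [1] the generator input dataset: latent noise vectors.
noise = tf.data.Dataset.from_tensor_slices(
    tf.random.normal((NUM_SAMPLES, LATENT_DIM))
)

# Zip the two so that each element is a (real, generator_input) tuple,
# then batch the pairs together.
dataset = tf.data.Dataset.zip((real, noise)).batch(BATCH_SIZE)

# Pull one batch to inspect the tuple structure.
real_batch, noise_batch = next(iter(dataset))
print(real_batch.shape, noise_batch.shape)
```

A dataset built this way would then be passed as the dataset argument, alongside the models, losses, and metrics listed above.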

property discriminator_model
Retrieve the discriminator model.
- Return type
  Model

property fake_samples
Retrieve the fake samples, i.e. the output of the generator.
- Return type
  Optional[Tensor]

property generator_model
Retrieve the generator model.
- Return type
  Model

class ashpy.contexts.gan.GANEncoderContext(dataset=None, generator_model=None, discriminator_model=None, encoder_model=None, generator_loss=None, discriminator_loss=None, encoder_loss=None, metrics=None, log_eval_mode=<LogEvalMode.TRAIN: 2>, global_step=<tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>, checkpoint=None)
Bases: ashpy.contexts.gan.GANContext
ashpy.contexts.gan.GANEncoderContext measures the specified metrics on the GAN.

__init__(dataset=None, generator_model=None, discriminator_model=None, encoder_model=None, generator_loss=None, discriminator_loss=None, encoder_loss=None, metrics=None, log_eval_mode=<LogEvalMode.TRAIN: 2>, global_step=<tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>, checkpoint=None)
Initialize the Context.
- Parameters
  - dataset (tf.data.Dataset) – Dataset of tuples. [0] true dataset, [1] generator input dataset.
  - generator_model (tf.keras.Model) – The generator.
  - discriminator_model (tf.keras.Model) – The discriminator.
  - encoder_model (tf.keras.Model) – The encoder.
  - generator_loss (ashpy.losses.Executor()) – The generator loss.
  - discriminator_loss (ashpy.losses.Executor()) – The discriminator loss.
  - encoder_loss (ashpy.losses.Executor()) – The encoder loss.
  - metrics (list of [ashpy.metrics.metric.Metric]) – All the metrics to be used to evaluate the model.
  - log_eval_mode (ashpy.modes.LogEvalMode) – Models' mode to use when evaluating and logging.
  - global_step (tf.Variable) – tf.Variable that keeps track of the training steps.
  - checkpoint (tf.train.Checkpoint) – Checkpoint used to keep track of the models' status.
- Return type
  None

property encoder_inputs
Retrieve the inputs of the encoder.
- Return type
  Tensor

property encoder_model
Retrieve the encoder model.
- Return type
  Model

property generator_of_encoder
Retrieve the images generated from the encoder output.
- Return type
  Tensor
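
The relationship between encoder_inputs and generator_of_encoder (the generator applied to the encoder output) can be illustrated with plain tf.keras models. This is only a sketch under assumed conditions: the tiny Dense models, the flattened 16-value "images", and the variable names are hypothetical stand-ins, and the code does not show how ashpy computes these properties internally.

```python
import tensorflow as tf

# Hypothetical sizes, chosen only for illustration.
LATENT_DIM, IMAGE_DIM = 4, 16

# Tiny stand-in models; real encoders and generators would be much larger.
encoder = tf.keras.Sequential([tf.keras.layers.Dense(LATENT_DIM)])
generator = tf.keras.Sequential(
    [tf.keras.layers.Dense(IMAGE_DIM, activation="tanh")]
)

# A batch of two flattened "images" fed to the encoder.
encoder_inputs = tf.random.uniform((2, IMAGE_DIM))

# Encode the inputs into latent vectors, then decode them with the generator.
latent = encoder(encoder_inputs, training=False)
generator_of_encoder = generator(latent, training=False)

print(latent.shape, generator_of_encoder.shape)
```

The generated tensor has the same shape as the encoder input, which is what makes reconstruction-style metrics between the two possible.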