GANEncoderContext
class ashpy.contexts.gan.GANEncoderContext(dataset=None, generator_model=None, discriminator_model=None, encoder_model=None, generator_loss=None, discriminator_loss=None, encoder_loss=None, metrics=None, log_eval_mode=<LogEvalMode.TRAIN: 2>, global_step=<tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>, checkpoint=None)
Bases: ashpy.contexts.gan.GANContext

ashpy.contexts.gan.GANEncoderContext measures the specified metrics on a GAN that also includes an encoder model.

Methods
__init__([dataset, generator_model, …]) – Initialize the Context.

Attributes
current_batch – Return the current batch.
dataset – Retrieve the dataset.
discriminator_loss – Retrieve the discriminator loss.
discriminator_model – Retrieve the discriminator model.
encoder_inputs – Retrieve the inputs of the encoder.
encoder_loss – Retrieve the encoder loss.
encoder_model – Retrieve the encoder model.
exception – Return the exception.
fake_samples – Retrieve the fake samples, i.e.
generator_inputs – Retrieve the generator inputs.
generator_loss – Retrieve the generator loss.
generator_model – Retrieve the generator model.
generator_of_encoder – Retrieve the images generated from the encoder output.
global_step – Retrieve the global_step.
log_eval_mode – Retrieve the models' mode.
metrics – Retrieve the metrics.
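Per the Bases line, GANEncoderContext builds on GANContext: it keeps the GAN components (dataset, generator, discriminator, their losses, metrics) and adds the encoder-specific ones (encoder_model, encoder_loss). The dependency-free sketch below illustrates only that inheritance pattern; MiniGANContext and MiniGANEncoderContext are hypothetical stand-ins, not the ashpy classes, and they reproduce just a few of the read-only attributes listed above.

```python
class MiniGANContext:
    """Hypothetical, simplified GAN context: stores components and exposes them read-only."""

    def __init__(self, dataset=None, generator_model=None, discriminator_model=None,
                 generator_loss=None, discriminator_loss=None, metrics=None):
        self._dataset = dataset
        self._generator_model = generator_model
        self._discriminator_model = discriminator_model
        self._generator_loss = generator_loss
        self._discriminator_loss = discriminator_loss
        # Store metrics as a list so callers can always iterate over them.
        self._metrics = list(metrics) if metrics is not None else []

    @property
    def dataset(self):
        """Retrieve the dataset."""
        return self._dataset

    @property
    def generator_model(self):
        """Retrieve the generator model."""
        return self._generator_model

    @property
    def discriminator_model(self):
        """Retrieve the discriminator model."""
        return self._discriminator_model

    @property
    def metrics(self):
        """Retrieve the metrics."""
        return self._metrics


class MiniGANEncoderContext(MiniGANContext):
    """Hypothetical subclass: the same GAN components plus an encoder and its loss."""

    def __init__(self, encoder_model=None, encoder_loss=None, **gan_kwargs):
        # Delegate the shared GAN components to the parent class.
        super().__init__(**gan_kwargs)
        self._encoder_model = encoder_model
        self._encoder_loss = encoder_loss

    @property
    def encoder_model(self):
        """Retrieve the encoder model."""
        return self._encoder_model

    @property
    def encoder_loss(self):
        """Retrieve the encoder loss."""
        return self._encoder_loss
```

In this sketch the attributes are properties without setters, so assigning to ctx.encoder_model raises AttributeError; the components can only be supplied through the constructor.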
__init__(dataset=None, generator_model=None, discriminator_model=None, encoder_model=None, generator_loss=None, discriminator_loss=None, encoder_loss=None, metrics=None, log_eval_mode=<LogEvalMode.TRAIN: 2>, global_step=<tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>, checkpoint=None)
Initialize the Context.
Parameters:
- dataset (tf.data.Dataset) – Dataset of tuples: [0] is the true dataset, [1] is the generator input dataset.
- generator_model (tf.keras.Model) – The generator.
- discriminator_model (tf.keras.Model) – The discriminator.
- encoder_model (tf.keras.Model) – The encoder.
- generator_loss (ashpy.losses.Executor()) – The generator loss.
- discriminator_loss (ashpy.losses.Executor()) – The discriminator loss.
- encoder_loss (ashpy.losses.Executor()) – The encoder loss.
- metrics (list of [ashpy.metrics.metric.Metric]) – All the metrics to be used to evaluate the model.
- log_eval_mode (ashpy.modes.LogEvalMode) – Models' mode to use when evaluating and logging.
- global_step (tf.Variable) – tf.Variable that keeps track of the training steps.
- checkpoint (tf.train.Checkpoint) – Checkpoint used to keep track of the models' status.
Return type: None
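The dataset parameter expects tuples whose element [0] comes from the true dataset and element [1] from the generator input dataset. The sketch below builds such pairs in plain Python, assuming the generator input is a Gaussian latent vector; make_gan_dataset and latent_dim are hypothetical names, not part of ashpy. With TensorFlow, a comparable structure is commonly obtained by zipping two datasets with tf.data.Dataset.zip.

```python
import random
from typing import List, Sequence, Tuple


def make_gan_dataset(real_samples: Sequence[Sequence[float]], latent_dim: int,
                     seed: int = 0) -> List[Tuple[Sequence[float], List[float]]]:
    """Hypothetical helper: pair every real sample with a freshly drawn latent vector.

    Element [0] of each tuple is the real sample, element [1] is the generator input.
    """
    rng = random.Random(seed)  # seeded so the generated noise is reproducible
    dataset = []
    for real in real_samples:
        # Assumption: the generator input is standard Gaussian noise of size latent_dim.
        noise = [rng.gauss(0.0, 1.0) for _ in range(latent_dim)]
        dataset.append((real, noise))
    return dataset
```

For example, make_gan_dataset([[0.1, 0.2], [0.3, 0.4]], latent_dim=4) returns two tuples, each holding a real sample and a 4-element noise vector, which a training loop can unpack as real_x, z.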

encoder_inputs
Retrieve the inputs of the encoder.
Return type: Tensor

encoder_model
Retrieve the encoder model.
Return type: Model
Returns: tf.keras.Model.

generator_of_encoder
Retrieve the images generated from the encoder output.
Return type: Tensor
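generator_of_encoder chains the two networks: the encoder maps its inputs to a latent code and the generator maps that code back to sample space, giving G(E(x)). The sketch below shows only this data flow, assuming the encoder inputs are real samples; toy_encoder, toy_generator and the vector sizes are hypothetical and stand in for the Keras models returned by encoder_model and generator_model.

```python
from typing import Callable, List

Vector = List[float]


def toy_encoder(x: Vector) -> Vector:
    """Hypothetical encoder: maps a 4-d 'image' to a 2-d code by averaging pairs."""
    return [(x[0] + x[1]) / 2.0, (x[2] + x[3]) / 2.0]


def toy_generator(z: Vector) -> Vector:
    """Hypothetical generator: maps a 2-d code back to a 4-d 'image'."""
    return [z[0], z[0], z[1], z[1]]


def generator_of_encoder(encoder: Callable[[Vector], Vector],
                         generator: Callable[[Vector], Vector],
                         encoder_inputs: List[Vector]) -> List[Vector]:
    """Return G(E(x)) for every x in the encoder inputs."""
    return [generator(encoder(x)) for x in encoder_inputs]
```

With these toy models, generator_of_encoder(toy_encoder, toy_generator, [[0.0, 2.0, 4.0, 6.0]]) encodes the input to [1.0, 5.0] and returns [[1.0, 1.0, 5.0, 5.0]]; comparing such outputs with the encoder inputs is one way to inspect how the two networks interact.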