
Primitive Trainer Interface.




class ashpy.trainers.base_trainer.BaseTrainer(epochs, logdir='/home/docs/checkouts/', log_eval_mode=<LogEvalMode.TEST: 0>, global_step=<tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>, post_process_callback=None)

Bases: abc.ABC

BaseTrainer provides an interface for all trainers to inherit from.

__init__(epochs, logdir='/home/docs/checkouts/', log_eval_mode=<LogEvalMode.TEST: 0>, global_step=<tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>, post_process_callback=None)

Primitive trainer interface. Handles model saving and restoring.

Parameters:

  • epochs (int) – Number of training epochs.

  • logdir (str) – Checkpoint and log directory.

  • log_eval_mode – The mode the models use when evaluating and logging.

  • global_step – The tf.Variable that keeps track of the training steps.

  • post_process_callback – The function used to post-process the model output, if needed.
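The role of post_process_callback can be illustrated with a small pure-Python sketch. Everything here is hypothetical: rescale_to_unit and apply_post_process are illustrative names, model outputs are modeled as plain lists of floats, and this page does not say where BaseTrainer applies the callback. The sketch only shows the usual pattern of an optional callback that defaults to None and leaves the output unchanged when it is absent.

```python
from typing import Callable, List, Optional, Sequence

# Hypothetical type alias: a post-processing function maps model outputs to new values.
PostProcess = Callable[[Sequence[float]], List[float]]


def rescale_to_unit(values: Sequence[float]) -> List[float]:
    """Hypothetical callback: map values from [-1, 1] (e.g. a tanh output) to [0, 1]."""
    return [(v + 1.0) / 2.0 for v in values]


def apply_post_process(
    values: Sequence[float], post_process_callback: Optional[PostProcess] = None
) -> List[float]:
    """Apply the callback when one is given; otherwise pass the output through unchanged."""
    if post_process_callback is None:
        return list(values)
    return post_process_callback(values)


print(apply_post_process([-1.0, 0.0, 1.0]))                   # [-1.0, 0.0, 1.0]
print(apply_post_process([-1.0, 0.0, 1.0], rescale_to_unit))  # [0.0, 0.5, 1.0]
```

Defaulting to None keeps the common case (no post-processing) free of any extra function call or identity wrapper.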


Get the current epoch using the (restored) variables.


Returns:

current_epoch (int)
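One plausible way to derive the epoch from a restored step counter is integer division, sketched below. It assumes a fixed number of steps per epoch and a global_step that counts the steps completed so far; current_epoch and steps_per_epoch are hypothetical names, and ashpy's actual computation may differ.

```python
def current_epoch(global_step: int, steps_per_epoch: int) -> int:
    """Hypothetical helper: number of fully completed epochs for a step counter.

    Assumes every epoch has the same number of steps and that global_step
    counts the training steps completed so far (for example, after a restore).
    """
    if steps_per_epoch <= 0:
        raise ValueError("steps_per_epoch must be positive")
    # Integer division discards the steps of a partially completed epoch.
    return global_step // steps_per_epoch


print(current_epoch(0, 100))    # 0
print(current_epoch(250, 100))  # 2
```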


Handle the end of the training epoch.


Parameters:

epoch (int) – The training epoch that has just been completed.

_log(name, out)

Log the out tensor to TensorBoard, using name as its summary name.

Parameters:

  • name – The summary name.

  • out – The tensor to log.

_reduce(per_replica_tensor, reduce_op)

Reduce the given per-replica tensor across replicas in a distributed fashion, using the specified reduce_op.
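In TensorFlow, per-replica values are typically combined with tf.distribute.Strategy.reduce and an op from tf.distribute.ReduceOp (SUM or MEAN). The sketch below is a pure-Python stand-in for that idea, not ashpy's implementation: each replica's contribution is modeled as a single float, and ReduceOp and reduce_per_replica are hypothetical local definitions.

```python
from enum import Enum
from typing import Sequence


class ReduceOp(Enum):
    """Hypothetical stand-in mirroring the two members of tf.distribute.ReduceOp."""
    SUM = "sum"
    MEAN = "mean"


def reduce_per_replica(per_replica_values: Sequence[float], reduce_op: ReduceOp) -> float:
    """Combine one value per replica into a single value."""
    if not per_replica_values:
        raise ValueError("expected at least one replica value")
    total = sum(per_replica_values)
    if reduce_op is ReduceOp.SUM:
        return total
    if reduce_op is ReduceOp.MEAN:
        return total / len(per_replica_values)
    raise ValueError(f"unsupported reduce_op: {reduce_op}")


# Two replicas, each reporting a partial loss.
print(reduce_per_replica([0.5, 1.5], ReduceOp.SUM))   # 2.0
print(reduce_per_replica([0.5, 1.5], ReduceOp.MEAN))  # 1.0
```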


Restore or initialize the persistence layer (checkpoint).


Save the current status of the checkpointable objects.
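The restore-or-initialize and save steps follow a common pattern, sketched here with JSON files instead of TensorFlow checkpoints (tf.train.Checkpoint). The file name, the state dictionary, and the restore_or_init and save helpers are hypothetical; the sketch only shows the control flow: reuse saved state when it exists, otherwise start fresh, and write the state back so a later run can resume.

```python
import json
import os
import tempfile
from typing import Any, Dict

CHECKPOINT_NAME = "checkpoint.json"  # hypothetical file name


def restore_or_init(logdir: str) -> Dict[str, Any]:
    """Restore state from logdir if a checkpoint exists, else return a fresh state."""
    path = os.path.join(logdir, CHECKPOINT_NAME)
    if os.path.exists(path):
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    return {"global_step": 0}


def save(logdir: str, state: Dict[str, Any]) -> None:
    """Persist the current state, creating logdir if needed."""
    os.makedirs(logdir, exist_ok=True)
    path = os.path.join(logdir, CHECKPOINT_NAME)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(state, f)


with tempfile.TemporaryDirectory() as logdir:
    state = restore_or_init(logdir)  # no checkpoint yet: fresh state
    state["global_step"] += 10
    save(logdir, state)
    print(restore_or_init(logdir))   # {'global_step': 10}
```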

_update_global_batch_size(dataset, executors=None)

Given a dataset and the current distribution strategy, set the self._global_batch_size variable where needed.

Parameters:

  • dataset – A dataset from which the batch size will be extracted.

  • executors – A list of executors with a settable “global_batch_size” property.
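A minimal sketch of what updating the global batch size can involve, assuming the dataset's batch size is per replica, so the global batch size is that value times the number of replicas in sync. This assumption is not confirmed by this page. The Executor class and the update_global_batch_size function are hypothetical; they only illustrate propagating the computed value to objects that expose a settable global_batch_size property.

```python
from typing import List, Optional


class Executor:
    """Hypothetical executor exposing a settable global_batch_size property."""

    def __init__(self) -> None:
        self._global_batch_size: Optional[int] = None

    @property
    def global_batch_size(self) -> Optional[int]:
        return self._global_batch_size

    @global_batch_size.setter
    def global_batch_size(self, value: int) -> None:
        if value <= 0:
            raise ValueError("global_batch_size must be positive")
        self._global_batch_size = value


def update_global_batch_size(
    per_replica_batch_size: int,
    num_replicas: int,
    executors: Optional[List[Executor]] = None,
) -> int:
    """Compute the global batch size (assumed per-replica size x replicas) and propagate it."""
    global_batch_size = per_replica_batch_size * num_replicas
    for executor in executors or []:
        executor.global_batch_size = global_batch_size
    return global_batch_size


loss = Executor()
print(update_global_batch_size(32, 4, [loss]))  # 128
print(loss.global_batch_size)                   # 128
```

A common reason to propagate this value is loss scaling: in distributed training, summed per-example losses are divided by the global batch size, which is what tf.nn.compute_average_loss does when given global_batch_size.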

abstract call(dataset)

Execute the training process.

Iterate over the elements of a tf.data.Dataset. The dataset must contain everything needed to train the model.


Parameters:

dataset (DatasetV2) – A tf.data.Dataset to loop on to train the model.
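Because call is abstract, concrete trainers must implement it. The sketch below mirrors that contract with Python's abc module. Trainer and RecordingTrainer are hypothetical toy classes, the dataset is a plain iterable rather than a tf.data.Dataset, and "training" only records each element it sees. The last lines show that the abstract base cannot be instantiated directly.

```python
from abc import ABC, abstractmethod
from typing import Iterable, List


class Trainer(ABC):
    """Hypothetical minimal base class mirroring the abstract call(dataset) contract."""

    def __init__(self, epochs: int) -> None:
        self._epochs = epochs

    @abstractmethod
    def call(self, dataset: Iterable[float]) -> None:
        """Run the training process over dataset."""


class RecordingTrainer(Trainer):
    """Toy subclass: 'training' just records every element of every epoch."""

    def __init__(self, epochs: int) -> None:
        super().__init__(epochs)
        self.seen: List[float] = []

    def call(self, dataset: Iterable[float]) -> None:
        data = list(dataset)  # materialize so each epoch sees the same elements
        for _ in range(self._epochs):
            for element in data:
                self.seen.append(element)


trainer = RecordingTrainer(epochs=2)
trainer.call([1.0, 2.0])
print(trainer.seen)  # [1.0, 2.0, 1.0, 2.0]

try:
    Trainer(epochs=1)  # abstract method call is not implemented
except TypeError as err:
    print("cannot instantiate:", type(err).__name__)  # cannot instantiate: TypeError
```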