BaseTrainer

Inheritance Diagram

Inheritance diagram of ashpy.trainers.base_trainer.BaseTrainer

class ashpy.trainers.base_trainer.BaseTrainer(epochs, logdir='/home/docs/checkouts/readthedocs.org/user_builds/ashpy/checkouts/latest/docs/source/log', log_eval_mode=<LogEvalMode.TEST: 0>, global_step=<tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>, post_process_callback=None)[source]

Bases: abc.ABC

BaseTrainer provides an interface for all trainers to inherit from.

Methods

__init__(epochs[, logdir, log_eval_mode, …])

Primitive trainer interface.

call(dataset)

Execute the training process.

__init__(epochs, logdir='/home/docs/checkouts/readthedocs.org/user_builds/ashpy/checkouts/latest/docs/source/log', log_eval_mode=<LogEvalMode.TEST: 0>, global_step=<tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>, post_process_callback=None)[source]

Primitive trainer interface. Handles model saving and restoring.

Parameters
  • epochs (int) – Number of training epochs.

  • logdir (str) – Checkpoint and log directory.

  • log_eval_mode – Models’ mode to use when evaluating and logging.

  • global_step – tf.Variable that keeps track of the training steps.

  • post_process_callback – The function used to post-process the model output, if needed.
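To make the constructor's shape concrete, here is a minimal, dependency-free stand-in that mirrors the documented parameters. `TrainerSketch` and the `LogEvalMode` stand-in are hypothetical names for illustration; only `TEST = 0` is confirmed by the signature above, and in ashpy `global_step` is a tf.Variable rather than a plain int.

```python
from enum import Enum

class LogEvalMode(Enum):
    """Stand-in for ashpy's LogEvalMode; only TEST = 0 appears in the signature."""
    TEST = 0
    TRAIN = 1  # Assumed value; not shown in the documented signature.

class TrainerSketch:
    """Illustrative stand-in mirroring BaseTrainer.__init__'s parameters."""

    def __init__(self, epochs, logdir="log", log_eval_mode=LogEvalMode.TEST,
                 global_step=0, post_process_callback=None):
        self._epochs = epochs                                # Number of training epochs.
        self._logdir = logdir                                # Checkpoint and log directory.
        self._log_eval_mode = log_eval_mode                  # Mode for evaluation/logging.
        self._global_step = global_step                      # Step counter (tf.Variable in ashpy).
        self._post_process_callback = post_process_callback  # Optional output post-processing.
```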

_current_epoch()[source]

Get the current epoch using the (restored) variables.

Returns

current_epoch (int)
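One plausible way to recover the epoch from restored variables, assuming a fixed number of steps per epoch, is integer division of the step counter. This is an illustrative sketch, not necessarily ashpy's actual implementation:

```python
def current_epoch(global_step: int, steps_per_epoch: int) -> int:
    # After a checkpoint restore, the step counter alone is enough to
    # recover the epoch index, given a fixed number of steps per epoch.
    return global_step // steps_per_epoch

# Steps 0-999 belong to epoch 0, steps 1000-1999 to epoch 1, and so on.
print(current_epoch(2500, 1000))  # 2
```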

_epoch_completed(epoch)[source]

Handle the end of the training epoch.

Parameters

epoch (int) – The training epoch that has just completed.

_log(name, out)[source]

Log the out tensor using name as its name in TensorBoard.

Parameters
  • name – Summary name.

  • out – The tensor to log.

_reduce(per_replica_tensor, reduce_op)[source]

Reduce the input per-replica tensor in a distributed fashion, using the specified reduce op.
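The reduction can be sketched in plain Python. In ashpy this is backed by the distribution strategy; "SUM" and "MEAN" mirror the reduce ops offered by tf.distribute.ReduceOp, and the function name here is hypothetical:

```python
def reduce_per_replica(per_replica_values, reduce_op):
    # Illustrative cross-replica reduction: SUM adds the values computed on
    # each replica, MEAN averages them across replicas.
    if reduce_op == "SUM":
        return sum(per_replica_values)
    if reduce_op == "MEAN":
        return sum(per_replica_values) / len(per_replica_values)
    raise ValueError(f"unsupported reduce_op: {reduce_op}")
```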

_restore_or_init()[source]

Restore or initialize the persistence layer (checkpoint).

_save()[source]

Save the current checkpointable object status.

_update_global_batch_size(dataset, executors=None)[source]

Given a dataset and the current distribution strategy, set the self._global_batch_size variable where needed.

Parameters
  • dataset – A dataset from which the batch size will be extracted.

  • executors – A list of executors with the settable property “global_batch_size”.
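The propagation of the global batch size can be sketched as follows. `DummyExecutor`, `update_global_batch_size`, and the per-replica arithmetic are illustrative assumptions: under a distribution strategy the global batch size is typically the per-replica batch size times the number of replicas in sync, but this is not ashpy's verbatim implementation.

```python
class DummyExecutor:
    """Hypothetical executor exposing a settable global_batch_size property."""

    def __init__(self):
        self._global_batch_size = None

    @property
    def global_batch_size(self):
        return self._global_batch_size

    @global_batch_size.setter
    def global_batch_size(self, value):
        self._global_batch_size = value

def update_global_batch_size(per_replica_batch_size, num_replicas_in_sync,
                             executors=None):
    # The global batch size is the per-replica batch size multiplied by the
    # number of replicas training in sync; push it to every executor.
    global_batch_size = per_replica_batch_size * num_replicas_in_sync
    for executor in executors or []:
        executor.global_batch_size = global_batch_size
    return global_batch_size
```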

abstract call(dataset)[source]

Execute the training process.

Iterate over the elements of a tf.data.Dataset. The dataset must contain everything needed to train the model.

Parameters

dataset (DatasetV2) – A tf.data.Dataset to loop on to train the model.
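A concrete trainer implements call as the epoch/batch loop described above. The sketch below uses a plain list in place of a tf.data.Dataset and a no-op training step, so all names and the checkpointing placement are illustrative, not the real ashpy API:

```python
class ToyTrainer:
    """Illustrative subclass pattern; not the real ashpy BaseTrainer API."""

    def __init__(self, epochs):
        self._epochs = epochs
        self._current_epoch = 0  # In ashpy this is recovered from restored variables.

    def call(self, dataset):
        # Iterate the dataset once per remaining epoch; each element must
        # carry everything needed for one training step.
        for epoch in range(self._current_epoch, self._epochs):
            for batch in dataset:
                pass  # The actual training step on `batch` would go here.
            self._current_epoch = epoch + 1  # Epoch-end handling (e.g. saving).
        return self._current_epoch
```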