BaseTrainer

class ashpy.trainers.BaseTrainer(epochs, logdir='/home/docs/checkouts/readthedocs.org/user_builds/ashpy/checkouts/v0.1.3/docs/source/log', log_eval_mode=<LogEvalMode.TEST: 0>, global_step=<tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>, post_process_callback=None)

Bases: abc.ABC

BaseTrainer provides an interface for all trainers to inherit from.

Methods
__init__(epochs[, logdir, log_eval_mode, …]) – Primitive trainer interface.
call(dataset) – Execute the training process.
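Because BaseTrainer derives from abc.ABC, it follows the standard abstract-base-class pattern: a subclass must implement the abstract method before it can be instantiated. The following is a minimal sketch of that pattern, not ashpy's implementation. SketchTrainer and CountingTrainer are hypothetical names, only `epochs` is kept from the real constructor, and a plain list stands in for a tf.data.Dataset.

```python
import abc


class SketchTrainer(abc.ABC):
    """Hypothetical, stripped-down analogue of the BaseTrainer interface.

    Only `epochs` is kept; the real constructor also takes logdir,
    log_eval_mode, global_step and post_process_callback.
    """

    def __init__(self, epochs: int) -> None:
        self._epochs = epochs

    @abc.abstractmethod
    def call(self, dataset):
        """Execute the training process; concrete trainers must override this."""


class CountingTrainer(SketchTrainer):
    """Toy concrete trainer: 'trains' by counting the samples it visits."""

    def call(self, dataset):
        seen = 0
        for _ in range(self._epochs):
            for _sample in dataset:
                seen += 1
        return seen


# The abstract base cannot be instantiated directly (raises TypeError) ...
try:
    SketchTrainer(epochs=1)
except TypeError as err:
    print("abstract:", err)

# ... but a subclass that implements `call` can.
print(CountingTrainer(epochs=2).call([1, 2, 3]))  # prints 6
```

Declaring the training entry point abstract means a trainer that forgets to implement it fails loudly at construction time rather than at the first training call.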

__init__(epochs, logdir='/home/docs/checkouts/readthedocs.org/user_builds/ashpy/checkouts/v0.1.3/docs/source/log', log_eval_mode=<LogEvalMode.TEST: 0>, global_step=<tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>, post_process_callback=None)
Primitive trainer interface. Handles model saving and restoring.

_current_epoch()
Get the current epoch using the (restored) variables.
Returns: current_epoch (int)

_epoch_completed(epoch)
Handle the end of a training epoch.
Parameters:
- epoch (int) – the just-completed training epoch.

_log(name, out)
Log the out tensor to TensorBoard, using name as its summary name.
Parameters:
- name – summary name.
- out – the tensor to log.

_reduce(per_replica_tensor, reduce_op)
Reduce the given per-replica tensor across the replicas of the current distribution strategy, using the specified reduction op.
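To make "reducing across replicas" concrete, here is a pure-Python simulation. In TensorFlow the corresponding operation is tf.distribute.Strategy.reduce with a tf.distribute.ReduceOp such as SUM or MEAN; this sketch does not call TensorFlow, and the ReduceOp enum and reduce_across_replicas function below are hypothetical stand-ins. Each replica's tensor is modelled as a flat list of floats.

```python
from enum import Enum
from typing import List, Sequence


class ReduceOp(Enum):
    """Hypothetical stand-in for tf.distribute.ReduceOp."""
    SUM = "sum"
    MEAN = "mean"


def reduce_across_replicas(per_replica: Sequence[Sequence[float]],
                           op: ReduceOp) -> List[float]:
    """Combine one tensor (here: a flat list) per replica into a single result.

    Values are combined element-wise across replicas, so the result keeps the
    per-replica shape, as a cross-replica reduction with axis=None does.
    """
    if not per_replica:
        raise ValueError("need at least one replica value")
    summed = [sum(values) for values in zip(*per_replica)]
    if op is ReduceOp.SUM:
        return summed
    if op is ReduceOp.MEAN:
        return [s / len(per_replica) for s in summed]
    raise ValueError(f"unsupported op: {op}")


# Two replicas, each holding a per-replica loss vector of length 2.
replica_losses = [[1.0, 2.0], [3.0, 6.0]]
print(reduce_across_replicas(replica_losses, ReduceOp.SUM))   # [4.0, 8.0]
print(reduce_across_replicas(replica_losses, ReduceOp.MEAN))  # [2.0, 4.0]
```

SUM suits quantities that are already scaled per replica (for example losses pre-divided by the global batch size), while MEAN averages values that each replica computed independently.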

_update_global_batch_size(dataset, executors=None)
Given a dataset and the current distribution strategy, set the self._global_batch_size variable where needed.
Parameters:
- dataset – a dataset from which the batch size will be extracted.
- executors – a list of executors whose "global_batch_size" property is settable.
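Custom distributed training loops typically divide summed per-example losses by the global batch size (as tf.nn.compute_average_loss does), which is why executors may need to know it. The sketch below assumes, without confirmation from this page, that the global batch size is the per-replica batch size multiplied by the number of replicas in sync. HypotheticalExecutor and update_global_batch_size are illustrative names, and the batch size is passed in directly instead of being extracted from a dataset.

```python
from typing import Iterable, Optional


class HypotheticalExecutor:
    """Stand-in for an executor exposing a settable `global_batch_size`."""

    def __init__(self) -> None:
        self._global_batch_size: Optional[int] = None

    @property
    def global_batch_size(self) -> Optional[int]:
        return self._global_batch_size

    @global_batch_size.setter
    def global_batch_size(self, value: int) -> None:
        self._global_batch_size = value


def update_global_batch_size(
    per_replica_batch_size: int,
    num_replicas_in_sync: int,
    executors: Optional[Iterable[HypotheticalExecutor]] = None,
) -> int:
    """Compute the global batch size and propagate it to every executor.

    ASSUMPTION: global batch = per-replica batch size * number of replicas.
    """
    global_batch_size = per_replica_batch_size * num_replicas_in_sync
    for executor in executors or []:
        executor.global_batch_size = global_batch_size
    return global_batch_size


executors = [HypotheticalExecutor(), HypotheticalExecutor()]
print(update_global_batch_size(32, 4, executors))  # 128
print([e.global_batch_size for e in executors])    # [128, 128]
```

Using a property setter keeps the executors' interface uniform: the trainer only needs each executor to accept an assignment to global_batch_size.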

abstract call(dataset)
Execute the training process.
Iterate over the elements of a tf.data.Dataset. The dataset must contain everything needed to train the model.
Parameters:
- dataset (DatasetV2) – A tf.data.Dataset to loop on to train the model.
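A training loop built on this interface might resume from a restored epoch counter, iterate over the dataset once per epoch, and fire an end-of-epoch hook. The following is a minimal sketch under those assumptions, not ashpy's implementation. LoopSketch, train_step and restored_epoch are hypothetical, and a re-iterable list stands in for a tf.data.Dataset.

```python
from typing import Callable, Iterable, List


class LoopSketch:
    """Hypothetical trainer loop: resumes from a restored epoch counter,
    iterates over the dataset once per epoch and fires an end-of-epoch hook."""

    def __init__(self, epochs: int, train_step: Callable[[object], float],
                 restored_epoch: int = 0) -> None:
        self._epochs = epochs
        self._train_step = train_step
        self._epoch = restored_epoch  # stands in for a restored checkpoint variable
        self.global_step = 0
        self.completed: List[int] = []

    def _epoch_completed(self, epoch: int) -> None:
        # A real trainer might save a checkpoint or run evaluation here.
        self.completed.append(epoch)

    def call(self, dataset: Iterable[object]) -> None:
        # The dataset must be re-iterable (a list here, like a tf.data.Dataset).
        for epoch in range(self._epoch, self._epochs):
            for batch in dataset:
                self._train_step(batch)
                self.global_step += 1
            self._epoch = epoch + 1
            self._epoch_completed(epoch)


seen = []


def fake_train_step(batch):
    seen.append(batch)  # record the batch instead of doing real work
    return 0.0


trainer = LoopSketch(epochs=3, train_step=fake_train_step, restored_epoch=1)
trainer.call([10, 20])
print(trainer.completed)    # [1, 2]
print(trainer.global_step)  # 4
```

Starting the epoch range at the restored value lets an interrupted run continue where it stopped instead of restarting from epoch 0.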