base_trainer
Primitive Trainer Interface.
Classes
class ashpy.trainers.base_trainer.BaseTrainer(epochs, logdir='/home/docs/checkouts/readthedocs.org/user_builds/ashpy/checkouts/v0.1.3/docs/source/log', log_eval_mode=<LogEvalMode.TEST: 0>, global_step=<tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>, post_process_callback=None)

Bases: abc.ABC

BaseTrainer provides an interface for all trainers to inherit from.
__init__(epochs, logdir='/home/docs/checkouts/readthedocs.org/user_builds/ashpy/checkouts/v0.1.3/docs/source/log', log_eval_mode=<LogEvalMode.TEST: 0>, global_step=<tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>, post_process_callback=None)

Primitive trainer interface. Handles model saving and restoring.

Parameters
    epochs (int) – number of training epochs.
    logdir (str) – directory where checkpoints and logs are written.
    log_eval_mode (LogEvalMode) – models' mode to use when evaluating and logging.
    global_step (tf.Variable) – int64 step counter, updated during training and restored from checkpoints.
    post_process_callback – optional callable used to post-process output; defaults to None.
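Since BaseTrainer is abstract, it is always used through a subclass. A minimal sketch of defining and constructing one with the documented arguments; the NoOpTrainer name and its empty loop are illustrative, not part of ashpy:

    import tensorflow as tf

    from ashpy.trainers.base_trainer import BaseTrainer


    class NoOpTrainer(BaseTrainer):
        """Hypothetical subclass: only the abstract call() must be overridden."""

        def call(self, dataset):
            # Placeholder loop; a real trainer would run training steps here.
            for _ in dataset:
                pass


    trainer = NoOpTrainer(
        epochs=10,
        logdir="./log",
        global_step=tf.Variable(0, name="global_step", dtype=tf.int64, trainable=False),
    )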
_current_epoch()

Get the current epoch using the (restored) variables.

Returns
    current_epoch (int)
_epoch_completed(epoch)

Handle the end of a training epoch.

Parameters
    epoch (int) – the training epoch that just completed.
_log(name, out)

Log the out tensor using name as its name in TensorBoard.

Parameters
    name – summary name.
    out – the tensor to log.
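The behavior can be approximated with tf.summary. A sketch assuming the trainer holds a summary file writer and the global step variable (the writer and its path here are hypothetical, not ashpy's actual implementation):

    import tensorflow as tf

    writer = tf.summary.create_file_writer("./log")  # assumed writer, path illustrative
    global_step = tf.Variable(0, dtype=tf.int64)


    def log(name, out):
        # Write the scalar-reduced tensor so it shows up in TensorBoard under `name`.
        with writer.as_default():
            tf.summary.scalar(name, tf.reduce_mean(out), step=global_step)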
_reduce(per_replica_tensor, reduce_op)

Reduce the input per-replica tensor in a distributed fashion, using the specified op.
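In TensorFlow terms this corresponds to tf.distribute.Strategy.reduce. A minimal sketch; the toy per-replica step is illustrative:

    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy()
    dataset = tf.data.Dataset.from_tensor_slices(tf.ones((8, 2))).batch(4)

    for batch in strategy.experimental_distribute_dataset(dataset):
        # Run a toy step on every replica, producing one value per replica.
        per_replica = strategy.run(tf.reduce_sum, args=(batch,))
        # Aggregate the per-replica values into a single tensor with the given op.
        total = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)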
_update_global_batch_size(dataset, executors=None)

Given a dataset and the current distribution strategy, set the self._global_batch_size variable where needed.

Parameters
    dataset – a dataset from which the batch size will be extracted.
    executors – a list of executors with the settable property global_batch_size.
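A sketch of the described extraction, not ashpy's actual implementation; it assumes the dataset was batched with drop_remainder=True so the batch dimension is static:

    import tensorflow as tf


    def update_global_batch_size(dataset, executors=None):
        # The leading dimension of the element spec is the batch size.
        batch_size = dataset.element_spec.shape[0]
        for executor in executors or []:
            executor.global_batch_size = batch_size  # property named in the docstring
        return batch_size


    dataset = tf.data.Dataset.from_tensor_slices(tf.zeros((100, 3))).batch(
        10, drop_remainder=True
    )
    assert update_global_batch_size(dataset) == 10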
abstract call(dataset)

Execute the training process.

Iterate over the elements of a tf.data.Dataset. The dataset must contain everything needed to train the model.

Parameters
    dataset (DatasetV2) – a tf.data.Dataset to loop over to train the model.
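A sketch of one way a subclass might implement call; the model and optimizer attributes and the classification loss are assumptions for illustration, not BaseTrainer features:

    import tensorflow as tf

    from ashpy.trainers.base_trainer import BaseTrainer


    class ClassifierTrainer(BaseTrainer):
        """Hypothetical subclass; model and optimizer are not BaseTrainer attributes."""

        def __init__(self, model, optimizer, **kwargs):
            super().__init__(**kwargs)
            self._model = model
            self._optimizer = optimizer

        def call(self, dataset):
            # The dataset must yield everything a step needs: here (features, labels).
            for features, labels in dataset:
                with tf.GradientTape() as tape:
                    logits = self._model(features, training=True)
                    loss = tf.reduce_mean(
                        tf.keras.losses.sparse_categorical_crossentropy(
                            labels, logits, from_logits=True
                        )
                    )
                grads = tape.gradient(loss, self._model.trainable_variables)
                self._optimizer.apply_gradients(
                    zip(grads, self._model.trainable_variables)
                )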