trainer
Primitive Trainer Interface.
Classes
Trainer | Trainer provides an interface for all trainers to inherit from.
class ashpy.trainers.trainer.Trainer(epochs, example_dim, logdir=PosixPath('/home/docs/checkouts/readthedocs.org/user_builds/ashpy/checkouts/master/docs/source/log'), log_eval_mode=<LogEvalMode.TEST: 1>, global_step=None, metrics=None, callbacks=None)
Bases: abc.ABC
Trainer provides an interface for all trainers to inherit from.
__init__(epochs, example_dim, logdir=PosixPath('/home/docs/checkouts/readthedocs.org/user_builds/ashpy/checkouts/master/docs/source/log'), log_eval_mode=<LogEvalMode.TEST: 1>, global_step=None, metrics=None, callbacks=None)
Primitive trainer interface. Handles model saving and restoring.
Parameters:
- epochs (int) – Number of training epochs.
- example_dim (Tuple[int, int]) – Dimension of an example. In the case of GANs the example has dimension (2, 1), since it is a tuple whose first element is itself a tuple with 2 components and whose second element is a single component. In the case of a classifier the example has dimension (1, 1), since it is composed of the example and the label.
- logdir (str) – Checkpoint and log directory.
- log_eval_mode (ashpy.modes.LogEvalMode) – Mode to use when evaluating and logging.
- global_step (Optional[tf.Variable]) – tf.Variable that keeps track of the training steps.
- metrics (Optional[List | Tuple[ashpy.metrics.Metric]]) – List or tuple of metrics.
- callbacks (Optional[List[ashpy.callbacks.Callback]]) – List or tuple of callbacks to handle events.
Return type: None
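To make the `example_dim` convention concrete, here is a minimal pure-Python sketch. The helper `infer_example_dim` is hypothetical (it is not part of ashpy), plain strings stand in for tensors, and the GAN layout `((real_image, label), noise)` is an assumed example of "a tuple whose first element has 2 components". The helper simply counts the components of the two top-level elements of an example.

```python
from typing import Any, Tuple


def infer_example_dim(example: Tuple[Any, Any]) -> Tuple[int, int]:
    """Hypothetical helper: count the components of each top-level element.

    A nested tuple contributes its length; anything else counts as one component.
    """
    first, second = example

    def count(element: Any) -> int:
        return len(element) if isinstance(element, tuple) else 1

    return count(first), count(second)


# Assumed GAN-style example: ((real_image, label), noise)
gan_example = (("real_image", "label"), "noise")
# Classifier example: (input, label)
classifier_example = ("image", "label")

print(infer_example_dim(gan_example))         # (2, 1)
print(infer_example_dim(classifier_example))  # (1, 1)
```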
_build_and_restore_models(dataset)
Build and restore a subclassed model by first calling it on some data.
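Subclassed Keras models typically create their weights lazily on the first call, so running one forward pass on real data guarantees the variables exist before the checkpoint values are loaded. The toy below illustrates only that build-then-restore ordering in pure Python; `LazyDense`, `restore` and `build_and_restore` are hypothetical stand-ins, not ashpy or TensorFlow APIs.

```python
from typing import Dict, List, Optional


class LazyDense:
    """Toy stand-in for a subclassed model: weights exist only after the first call."""

    def __init__(self, units: int) -> None:
        self.units = units
        self.kernel: Optional[List[List[float]]] = None

    def __call__(self, inputs: List[float]) -> List[float]:
        if self.kernel is None:
            # Weights are created lazily, sized from the first input seen.
            self.kernel = [[0.0] * self.units for _ in inputs]
        return [
            sum(x * self.kernel[i][j] for i, x in enumerate(inputs))
            for j in range(self.units)
        ]

    def restore(self, saved: Dict[str, List[List[float]]]) -> None:
        # In this toy, restoring is only allowed once the weights exist.
        if self.kernel is None:
            raise RuntimeError("model must be built before restoring")
        self.kernel = [row[:] for row in saved["kernel"]]


def build_and_restore(
    model: LazyDense, sample: List[float], saved: Dict[str, List[List[float]]]
) -> None:
    model(sample)         # 1. build: one forward pass creates the weights
    model.restore(saved)  # 2. restore: overwrite them with the saved values


saved = {"kernel": [[1.0], [2.0]]}
model = LazyDense(units=1)
build_and_restore(model, [1.0, 1.0], saved)
print(model([1.0, 1.0]))  # [3.0]
```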
_current_epoch()
Get the current epoch using the (restored) variables.
Return type: Tensor
Returns: current_epoch (tf.Tensor) – the current epoch of training.
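One plausible way to derive the epoch from restored state is integer division of the global step by the number of steps per epoch. This is an assumption for illustration: the function name and the `steps_per_epoch` counter are hypothetical, ashpy's actual restored variables may differ, and plain integers stand in for `tf.Variable`/`tf.Tensor`.

```python
def current_epoch(global_step: int, steps_per_epoch: int) -> int:
    """Hypothetical sketch: zero-based index of the epoch in progress.

    Assumes both counters were restored from a checkpoint.
    """
    if steps_per_epoch <= 0:
        raise ValueError("steps_per_epoch must be positive")
    return global_step // steps_per_epoch


print(current_epoch(global_step=250, steps_per_epoch=100))  # 2
```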
_dataset_from_example(example, dims)
Get a dataset from a given example.
Return type: DatasetV2
Returns: The dataset containing only the example.
_generate_checkpoint_map()
Generate a human-readable map of the ids and types of the objects stored in the checkpoint.
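A human-readable id/type map can be as simple as a JSON object. This sketch assumes the tracked objects are available as a dict keyed by their checkpoint id; the function name `generate_checkpoint_map` (without the leading underscore) and the JSON output format are illustrative choices, not ashpy's.

```python
import json
from typing import Any, Dict


def generate_checkpoint_map(tracked: Dict[str, Any]) -> str:
    """Hypothetical sketch: map each checkpoint id to the type name of its object."""
    mapping = {name: type(obj).__name__ for name, obj in tracked.items()}
    # Sorted keys and indentation keep the output stable and easy to read.
    return json.dumps(mapping, indent=2, sort_keys=True)


print(generate_checkpoint_map({"global_step": 0, "history": []}))
```

Writing the map as JSON keeps it diffable and readable with any text editor, which is the point of a human-readable companion to a binary checkpoint.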
_measure_performance_if_needed(example, measure_performance_freq)
Measure performance if needed, that is, when self._global_step % measure_performance_freq is 0.
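The documented trigger condition reduces to a modulo check. The sketch below isolates it in a hypothetical helper; treating a non-positive frequency as "never measure" is an assumption added to avoid a division by zero, not documented ashpy behavior.

```python
def should_measure_performance(global_step: int, measure_performance_freq: int) -> bool:
    """Sketch of the documented check: measure when global_step % freq == 0."""
    # Assumption: a non-positive frequency disables measurement.
    if measure_performance_freq <= 0:
        return False
    return global_step % measure_performance_freq == 0


print([step for step in range(25) if should_measure_performance(step, 10)])  # [0, 10, 20]
```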
_reduce(per_replica_tensor, reduce_op)
Reduce the input tensor in a distributed fashion, using the specified op.
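In TensorFlow this kind of reduction is usually expressed with `tf.distribute.Strategy.reduce` and a `tf.distribute.ReduceOp` such as `SUM` or `MEAN`. The pure-Python sketch below only mimics those two semantics over a list of per-replica values; `ReduceOp` here is a local stand-in enum and `reduce_per_replica` is a hypothetical name, not the TensorFlow or ashpy implementation.

```python
from enum import Enum
from typing import Sequence


class ReduceOp(Enum):
    """Local stand-in mirroring the SUM/MEAN semantics of tf.distribute.ReduceOp."""

    SUM = "sum"
    MEAN = "mean"


def reduce_per_replica(per_replica_values: Sequence[float], reduce_op: ReduceOp) -> float:
    """Combine one value per replica into a single value."""
    if not per_replica_values:
        raise ValueError("need at least one replica value")
    total = sum(per_replica_values)
    if reduce_op is ReduceOp.SUM:
        return total
    if reduce_op is ReduceOp.MEAN:
        return total / len(per_replica_values)
    raise ValueError(f"unsupported reduce op: {reduce_op}")


losses = [1.0, 2.0, 3.0]  # e.g. one loss value per replica
print(reduce_per_replica(losses, ReduceOp.SUM))   # 6.0
print(reduce_per_replica(losses, ReduceOp.MEAN))  # 2.0
```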
_update_global_batch_size(dataset, executors=None)
Set the self._global_batch_size variable where needed.
Parameters:
- dataset (tf.data.Dataset) – a dataset from which the batch size will be extracted.
- executors (Union[List[ashpy.losses.executor.Executor], ashpy.losses.executor.Executor]) – an executor or a list of executors, each with the property "global_batch_size".
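A sketch of the propagation step under simplifying assumptions: the dataset is modeled as a list of batches, the batch size is read from the first batch, and distribution across replicas is ignored (the real method works on a `tf.data.Dataset`). `FakeExecutor` and `update_global_batch_size` are hypothetical stand-ins; the sketch mirrors the documented `Union` type by accepting either a single executor or a list.

```python
from typing import List, Optional, Sequence, Union


class FakeExecutor:
    """Hypothetical stand-in for an executor exposing a global_batch_size property."""

    def __init__(self) -> None:
        self._global_batch_size = -1

    @property
    def global_batch_size(self) -> int:
        return self._global_batch_size

    @global_batch_size.setter
    def global_batch_size(self, value: int) -> None:
        self._global_batch_size = value


def update_global_batch_size(
    dataset: Sequence[Sequence[object]],
    executors: Optional[Union[List[FakeExecutor], FakeExecutor]] = None,
) -> int:
    """Read the batch size from the first batch and push it to every executor."""
    first_batch = next(iter(dataset))
    global_batch_size = len(first_batch)
    if executors is None:
        executors = []
    elif not isinstance(executors, list):
        executors = [executors]  # accept a single executor as well as a list
    for executor in executors:
        executor.global_batch_size = global_batch_size
    return global_batch_size


batches = [[1, 2, 3, 4], [5, 6, 7, 8]]
generator_exec, discriminator_exec = FakeExecutor(), FakeExecutor()
update_global_batch_size(batches, [generator_exec, discriminator_exec])
print(generator_exec.global_batch_size)  # 4
```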