Trainer

Inheritance Diagram

[Figure: inheritance diagram of ashpy.trainers.trainer.Trainer]

class ashpy.trainers.trainer.Trainer(epochs, example_dim, logdir=PosixPath('/home/docs/checkouts/readthedocs.org/user_builds/ashpy/checkouts/master/docs/source/log'), log_eval_mode=<LogEvalMode.TEST: 1>, global_step=None, metrics=None, callbacks=None)

Bases: abc.ABC

Trainer provides an interface for all trainers to inherit from.
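Because Trainer derives from abc.ABC, concrete trainers are expected to subclass it and supply the training logic. The following is a minimal, self-contained sketch of that pattern, not ashpy code: MiniTrainer and CountingTrainer are hypothetical names, and the assumption that calling a trainer instance delegates to call() is labeled as such.

```python
import abc


class MiniTrainer(abc.ABC):
    """Hypothetical, stripped-down stand-in for an abstract trainer."""

    def __init__(self, epochs):
        self._epochs = epochs

    @abc.abstractmethod
    def call(self, *args, **kwargs):
        """Execute the training process (subclasses must override this)."""

    def __call__(self, *args, **kwargs):
        # Assumption: calling the trainer instance delegates to call()
        return self.call(*args, **kwargs)


class CountingTrainer(MiniTrainer):
    """Hypothetical concrete trainer that only counts completed epochs."""

    def call(self, dataset):
        completed = 0
        for _ in range(self._epochs):
            for _batch in dataset:
                pass  # a real trainer would run a training step here
            completed += 1
        return completed
```

Instantiating MiniTrainer directly raises TypeError, because call is abstract; CountingTrainer(epochs=3)([1, 2, 3]) runs three epochs and returns 3.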

Methods

__init__(epochs, example_dim[, logdir, …]) Primitive trainer interface.
call(*args, **kwargs) Execute the training process.
local_example(example, dims) Return a local example from a distributed example.
measure_metrics() Measure the metrics.
model_selection() Use the metrics to perform model selection.

Attributes

ckpt_id_callbacks
ckpt_id_global_step
ckpt_id_steps_per_epoch
context Return the training context.

__init__(epochs, example_dim, logdir=PosixPath('/home/docs/checkouts/readthedocs.org/user_builds/ashpy/checkouts/master/docs/source/log'), log_eval_mode=<LogEvalMode.TEST: 1>, global_step=None, metrics=None, callbacks=None)

Primitive trainer interface. Handles model saving and restoring.

Parameters:
  • epochs (int) – Number of training epochs.
  • example_dim (Tuple[int, int]) – Dimension of an example. In the case of GANs, the example has dimension (2, 1) because it is a tuple whose first element is itself a tuple with 2 components and whose second element is a single element. In the case of a classifier, the example has dimension (1, 1) because it is composed of the example and the label.
  • logdir (str) – Checkpoint and log directory.
  • log_eval_mode (ashpy.modes.LogEvalMode) – Mode to use when evaluating and logging.
  • global_step (Optional[tf.Variable]) – tf.Variable that keeps track of the training steps.
  • metrics (Optional[List | Tuple[ashpy.metrics.Metric]]) – List or tuple of metrics.
  • callbacks (Optional[List[ashpy.callbacks.Callback]]) – List or tuple of callbacks to handle events.
Return type: None
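The example_dim convention can be illustrated by counting the components of each half of an example. This is a hypothetical helper written only to make the (2, 1) and (1, 1) cases concrete; infer_example_dim is not part of ashpy, and the component names below are illustrative placeholders.

```python
from typing import Any, Tuple


def infer_example_dim(example: Tuple[Any, Any]) -> Tuple[int, int]:
    """Hypothetical helper: count the components of each half of an example.

    A half that is itself a tuple contributes its length; any other
    object counts as a single element.
    """
    first, second = example

    def width(part: Any) -> int:
        return len(part) if isinstance(part, tuple) else 1

    return width(first), width(second)


# GAN-style example: a 2-component tuple followed by a single element
gan_example = (("real_data", "condition"), "noise")
# Classifier-style example: the example followed by its label
clf_example = ("features", "label")
```

With these placeholders, infer_example_dim(gan_example) gives (2, 1) and infer_example_dim(clf_example) gives (1, 1), matching the two cases described for example_dim.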

_build_and_restore_models(dataset)

Build and restore a subclassed model by first calling it on some data.

static _check_name_collision(objects, obj_type)

Check that all objects have a unique name.
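A uniqueness check of this kind can be sketched with collections.Counter. The Named class, the check_name_collision function, and the choice of ValueError are assumptions made for illustration; the exception actually raised by ashpy is not documented here.

```python
from collections import Counter
from typing import Iterable


class Named:
    """Hypothetical object exposing a ``name`` attribute (like a metric)."""

    def __init__(self, name: str):
        self.name = name


def check_name_collision(objects: Iterable[Named], obj_type: str) -> None:
    """Raise ValueError if two objects share a name (sketch only)."""
    counts = Counter(obj.name for obj in objects)
    duplicates = sorted(name for name, n in counts.items() if n > 1)
    if duplicates:
        raise ValueError(
            f"{obj_type} must have unique names; duplicated: {duplicates}"
        )
```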

_current_epoch()

Get the current epoch using the (restored) variables.

Return type: Tensor
Returns: current_epoch (tf.Tensor) – The current epoch of training.

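The attributes ckpt_id_global_step and ckpt_id_steps_per_epoch suggest that the epoch can be derived from restored step counters. The sketch below assumes one global step per batch and computes the epoch with integer division on plain ints rather than TensorFlow tensors; both the formula and the "unknown steps per epoch means epoch 0" rule are assumptions, not documented ashpy behavior.

```python
def current_epoch(global_step: int, steps_per_epoch: int) -> int:
    """Sketch: derive the epoch index from restored step counters.

    Assumption: one global step per batch. Assumption: a non-positive
    steps_per_epoch means it is not known yet, so epoch 0 is reported.
    """
    if steps_per_epoch <= 0:
        return 0
    return global_step // steps_per_epoch
```

For example, after 250 global steps with 100 steps per epoch, this returns 2 (the third epoch, counting from 0).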
_dataset_from_example(example, dims)

Get a dataset from a given example.

Return type: DatasetV2
Returns: The dataset containing only the example.

_generate_checkpoint_map()

Generate a human-readable map from each ID in the checkpoint to the type of the object stored under it.
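A minimal sketch of such an ID-to-type map, assuming the checkpointed objects are held in a plain dictionary keyed by ID; generate_checkpoint_map and the JSON output format are illustrative choices, not the documented ashpy output.

```python
import json
from typing import Any, Dict


def generate_checkpoint_map(ckpt_dict: Dict[str, Any]) -> str:
    """Sketch: map each checkpoint ID to the type name of its object."""
    mapping = {ckpt_id: type(obj).__name__ for ckpt_id, obj in ckpt_dict.items()}
    # indent and sort_keys keep the output stable and easy to read
    return json.dumps(mapping, indent=2, sort_keys=True)
```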

_log_metrics_and_reset()

Call log and then reset_states on each metric.
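The log-then-reset loop can be sketched with a stand-in metric. CountingMetric and log_metrics_and_reset are hypothetical, and the log(step) signature is an assumption; the point is the ordering: each metric is logged before its accumulated state is cleared.

```python
from typing import List


class CountingMetric:
    """Hypothetical metric with ``log`` and ``reset_states`` methods."""

    def __init__(self, name: str):
        self.name = name
        self.value = 0.0
        self.logged: List[tuple] = []

    def update(self, value: float) -> None:
        self.value += value

    def log(self, step: int) -> None:
        # A real metric would write a summary; here we just record it
        self.logged.append((step, self.value))

    def reset_states(self) -> None:
        self.value = 0.0


def log_metrics_and_reset(metrics: List[CountingMetric], step: int) -> None:
    """Log every metric at ``step``, then clear its accumulated state."""
    for metric in metrics:
        metric.log(step)
        metric.reset_states()
```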

_measure_performance()

Measure performance on the dataset.

_measure_performance_if_needed(example, measure_performance_freq)

Measure performance if needed.

Performance is measured only when self._global_step % measure_performance_freq equals 0.
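The modulo condition above can be written as a small predicate. should_measure is a hypothetical name used only to show which steps trigger a measurement; it assumes measure_performance_freq is a positive integer.

```python
def should_measure(global_step: int, measure_performance_freq: int) -> bool:
    """Return True when the step counter is a multiple of the frequency.

    Assumes measure_performance_freq > 0 (zero would divide by zero).
    """
    return global_step % measure_performance_freq == 0


# Steps in 0..9 at which measurement would run with a frequency of 3
measured_steps = [step for step in range(10) if should_measure(step, 3)]
```

Here measured_steps is [0, 3, 6, 9]; note that step 0 also satisfies the condition.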

_on_batch_end()

Handle the end of a training batch.

Return type: None

_on_batch_start()

Handle the start of a training batch.

Return type: None

_on_epoch_end()

Handle the end of the training epoch.

Return type: None

_on_epoch_start()

Handle the start of the training epoch.

Return type: None

_on_exception()

Handle the exception.

Return type: None

_on_train_end()

Handle the end of training.

Return type: None

_on_train_start()

Handle the start of training.

Return type: None

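The _on_* hooks above mark points in the training loop where registered callbacks are notified. The sketch below shows one plausible ordering of those events in a plain Python loop; RecordingCallback, HookedLoop, the single on_event method, and the dict used as a context are all hypothetical simplifications, not the ashpy callback API.

```python
from typing import Any, List


class RecordingCallback:
    """Hypothetical callback that records every event it receives."""

    def __init__(self) -> None:
        self.events: List[str] = []

    def on_event(self, event: str, context: Any) -> None:
        self.events.append(event)


class HookedLoop:
    """Sketch of how _on_* hooks could fan events out to callbacks."""

    def __init__(self, callbacks: List[RecordingCallback], epochs: int) -> None:
        self._callbacks = callbacks
        self._epochs = epochs
        self.context = {"epoch": 0}  # stand-in for a real training context

    def _notify(self, event: str) -> None:
        for callback in self._callbacks:
            callback.on_event(event, self.context)

    def run(self, dataset) -> None:
        self._notify("train_start")
        try:
            for epoch in range(self._epochs):
                self.context["epoch"] = epoch
                self._notify("epoch_start")
                for _batch in dataset:
                    self._notify("batch_start")
                    # a real trainer would run one training step here
                    self._notify("batch_end")
                self._notify("epoch_end")
        except Exception:
            self._notify("exception")
            raise
        self._notify("train_end")
```

Running one epoch over a two-batch dataset produces train_start, epoch_start, then batch_start/batch_end twice, then epoch_end and train_end.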
_reduce(per_replica_tensor, reduce_op)

Reduce the input tensor in a distributed fashion, using the specified op.
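In TensorFlow, combining per-replica values is typically done with tf.distribute.Strategy.reduce and a tf.distribute.ReduceOp such as SUM or MEAN. The dependency-free sketch below only illustrates what those two reductions compute over one value per replica; reduce_per_replica and its string op names are hypothetical.

```python
from typing import Sequence


def reduce_per_replica(values: Sequence[float], reduce_op: str) -> float:
    """Sketch: combine one value per replica with a SUM or MEAN reduction."""
    if reduce_op == "SUM":
        return float(sum(values))
    if reduce_op == "MEAN":
        return float(sum(values)) / len(values)
    raise ValueError(f"unsupported reduce_op: {reduce_op!r}")
```

With three replicas holding 1.0, 2.0 and 3.0, SUM gives 6.0 and MEAN gives 2.0.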

_restore_or_init()

Restore or initialize the persistence layer (checkpoint).

_save()

Save the current state of the checkpointable objects.
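The restore-or-initialize and save steps follow a common pattern: resume from persisted state when it exists, otherwise start from defaults. This sketch uses a JSON file instead of a TensorFlow checkpoint to stay self-contained; restore_or_init, save, and the file layout are hypothetical.

```python
import json
import os
from typing import Dict


def restore_or_init(path: str, defaults: Dict[str, int]) -> Dict[str, int]:
    """Load state from ``path`` if it exists, otherwise start from defaults."""
    if os.path.exists(path):
        with open(path, "r", encoding="utf-8") as handle:
            return json.load(handle)
    return dict(defaults)  # copy so callers cannot mutate the defaults


def save(path: str, state: Dict[str, int]) -> None:
    """Persist ``state`` so a later run can resume from it."""
    with open(path, "w", encoding="utf-8") as handle:
        json.dump(state, handle)
```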

_update_checkpoint(ckpt_dict)

Update the checkpoint with the new checkpoint dictionary.

_update_global_batch_size(dataset, executors=None)

Set the self._global_batch_size variable where needed.

Parameters:
  • dataset (tf.data.Dataset) – Dataset from which the batch size will be extracted.
  • executors (Union[List[ashpy.losses.executor.Executor], ashpy.losses.executor.Executor]) – An executor, or a list of executors, with the property "global_batch_size".

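Propagating the batch size to one or many executors can be sketched as follows. Extracting the batch size from a tf.data.Dataset is omitted, so the function takes an already-known integer; the Executor stand-in, its -1 "unset" placeholder, and update_global_batch_size are hypothetical.

```python
from typing import List, Optional, Union


class Executor:
    """Hypothetical stand-in exposing a ``global_batch_size`` property."""

    def __init__(self) -> None:
        self._global_batch_size = -1  # placeholder meaning "not set yet"

    @property
    def global_batch_size(self) -> int:
        return self._global_batch_size

    @global_batch_size.setter
    def global_batch_size(self, value: int) -> None:
        self._global_batch_size = value


def update_global_batch_size(
    batch_size: int,
    executors: Optional[Union[List[Executor], Executor]] = None,
) -> int:
    """Propagate an already-extracted batch size to every executor."""
    if executors is None:
        executors = []
    elif not isinstance(executors, list):
        executors = [executors]  # accept a single executor as well
    for executor in executors:
        executor.global_batch_size = batch_size
    return batch_size
```

Normalizing a single executor into a one-element list lets the same loop handle both accepted argument forms.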
_validate_callbacks()

Check if every callback is an ashpy.callbacks.Callback.

_validate_metrics()

Check if every metric is an ashpy.metrics.Metric.
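Both validation methods amount to an isinstance check over a collection. The sketch below uses a hypothetical Metric base class in place of ashpy.metrics.Metric, and raising TypeError is an assumption about how a failed check might be reported.

```python
from typing import Iterable


class Metric:
    """Hypothetical base class standing in for ashpy.metrics.Metric."""


class Accuracy(Metric):
    """Hypothetical concrete metric."""


def validate_all(objects: Iterable[object], expected: type) -> None:
    """Raise TypeError if any object is not an instance of ``expected``."""
    for obj in objects:
        if not isinstance(obj, expected):
            raise TypeError(
                f"{obj!r} is a {type(obj).__name__}, expected {expected.__name__}"
            )
```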

call(*args, **kwargs)

Execute the training process.

context

Return the training context.

Return type: Context

local_example(example, dims)

Return a local example from a distributed example.

Returns: A local example from a distributed example.

measure_metrics()

Measure the metrics.

Return type: None

model_selection()

Use the metrics to perform model selection.

Return type: None
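The source only states that model_selection uses the metrics. One common interpretation, assumed here, is keeping the model that achieves the best metric value seen so far. BestTracker is a hypothetical helper; the comparison function lets the same code handle greater-is-better and lower-is-better metrics.

```python
from typing import Callable, Optional


class BestTracker:
    """Sketch: remember the best value seen so far for one metric."""

    def __init__(self, better: Callable[[float, float], bool]) -> None:
        self._better = better  # e.g. new > old for accuracy, new < old for loss
        self.best: Optional[float] = None

    def update(self, value: float) -> bool:
        """Return True when ``value`` improves on the best so far."""
        if self.best is None or self._better(value, self.best):
            self.best = value
            return True  # a real trainer might save the model here
        return False


tracker = BestTracker(lambda new, old: new > old)
improved = [tracker.update(v) for v in [0.5, 0.7, 0.6, 0.8]]
```

Here improved is [True, True, False, True] and tracker.best ends at 0.8.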