Trainer

Inheritance Diagram

[Figure: inheritance diagram of ashpy.trainers.trainer.Trainer]

class ashpy.trainers.trainer.Trainer(epochs, example_dim, logdir='/home/docs/checkouts/readthedocs.org/user_builds/ashpy/checkouts/latest/docs/source/log', log_eval_mode=<LogEvalMode.TEST: 1>, global_step=None, metrics=None, callbacks=None)

Bases: abc.ABC

Trainer provides an interface for all trainers to inherit from.

Methods

__init__(epochs, example_dim[, logdir, …])

Primitive trainer interface.

call(*args, **kwargs)

Execute the training process.

local_example(example, dims)

Return a local example from a distributed example.

measure_metrics()

Measure the metrics.

model_selection()

Use the metrics to perform model selection.

Attributes

context

Return the training context.

__init__(epochs, example_dim, logdir='/home/docs/checkouts/readthedocs.org/user_builds/ashpy/checkouts/latest/docs/source/log', log_eval_mode=<LogEvalMode.TEST: 1>, global_step=None, metrics=None, callbacks=None)

Primitive trainer interface. Handles model saving and restore.

Parameters
  • epochs (int) – Number of training epochs.

  • example_dim (Tuple[int, int]) – Dimension of an example. In the case of GANs the example has dimension (2, 1), since it is composed of a tuple whose first element is a tuple with 2 components and whose second element is a single element. In the case of a classifier the example has dimension (1, 1), since it is composed of the example and the label.

  • logdir (str) – Checkpoint and log directory.

  • log_eval_mode (ashpy.modes.LogEvalMode) – Mode to use when evaluating and logging.

  • global_step (Optional[tf.Variable]) – tf.Variable that keeps track of the training steps.

  • metrics (Optional[List[ashpy.metrics.Metric]]) – List of metrics.

  • callbacks (Optional[List[ashpy.callbacks.Callback]]) – List of callbacks to handle events.
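To make the example_dim convention concrete, here is a minimal pure-Python sketch. The helper infer_example_dim is hypothetical (it is not part of ashpy): it simply counts the components of the two top-level elements of an example, treating any non-tuple element as a single component. Plain strings stand in for tensors so the sketch runs without TensorFlow.

```python
from typing import Any, Tuple


def infer_example_dim(example: Tuple[Any, Any]) -> Tuple[int, int]:
    """Hypothetical helper: count the components of each top-level element.

    A tuple element contributes its length; any other element counts as 1.
    """

    def width(element: Any) -> int:
        return len(element) if isinstance(element, tuple) else 1

    first, second = example
    return (width(first), width(second))


# GAN-style example: a 2-component tuple paired with a single element -> (2, 1)
gan_example = (("component_a", "component_b"), "single_element")
# Classifier-style example: (example, label) -> (1, 1)
classifier_example = ("example", "label")

print(infer_example_dim(gan_example))         # (2, 1)
print(infer_example_dim(classifier_example))  # (1, 1)
```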

_current_epoch()

Get the current epoch using the (restored) variables.

Return type

Tensor

Returns

current_epoch (tf.Tensor) – the current epoch of training.
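One plausible way to derive the epoch from restored variables is integer division of the restored global step by the number of steps per epoch. The sketch below assumes that rule; it is an illustration in plain Python, not ashpy's implementation, and current_epoch is a hypothetical name.

```python
def current_epoch(global_step: int, steps_per_epoch: int) -> int:
    """Assumed rule: completed epochs = global step // steps per epoch."""
    if steps_per_epoch <= 0:
        raise ValueError("steps_per_epoch must be positive")
    return global_step // steps_per_epoch


# After restoring a checkpoint at global step 2500 with 1000 steps per epoch,
# training would resume inside epoch index 2.
print(current_epoch(2500, 1000))  # 2
```

The appeal of deriving the epoch this way is that only the global step needs to be checkpointed; the epoch counter follows from it after a restore.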

_dataset_from_example(example, dims)

Get a dataset from a given example.

Return type

DatasetV2

Returns

The dataset containing only the example.

_log_metrics_and_reset()

Call log and then reset_states on each metric.
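The log-then-reset pattern can be sketched in plain Python. ToyMetric is a hypothetical stand-in for a metric object exposing log and reset_states; instead of writing a summary it records (step, value) pairs so the effect is easy to inspect.

```python
from typing import List, Tuple


class ToyMetric:
    """Hypothetical stand-in for a metric: accumulates values, logs, resets."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.values: List[float] = []
        self.logged: List[Tuple[int, float]] = []

    def update_state(self, value: float) -> None:
        self.values.append(value)

    def result(self) -> float:
        return sum(self.values) / len(self.values) if self.values else 0.0

    def log(self, step: int) -> None:
        # Record (step, value) instead of writing a TensorBoard summary.
        self.logged.append((step, self.result()))

    def reset_states(self) -> None:
        self.values.clear()


def log_metrics_and_reset(metrics: List[ToyMetric], step: int) -> None:
    """For every metric: log its current value, then clear its state."""
    for metric in metrics:
        metric.log(step)
        metric.reset_states()


loss = ToyMetric("loss")
loss.update_state(1.0)
loss.update_state(3.0)
log_metrics_and_reset([loss], step=10)
print(loss.logged, loss.values)  # [(10, 2.0)] []
```

Logging before resetting matters: reversing the order would log the value of an empty accumulator.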

_measure_performance()

Measure performance on the dataset.

_measure_performance_if_needed(example, measure_performance_freq)

Measure performance if needed.

Measure performance if self._global_step % measure_performance_freq is 0.
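The frequency rule stated above reduces to a modulo check. This plain-Python sketch (should_measure is a hypothetical name) lists the steps at which a measurement would fire for a frequency of 100.

```python
from typing import List


def should_measure(global_step: int, measure_performance_freq: int) -> bool:
    """True when global_step % measure_performance_freq == 0."""
    return global_step % measure_performance_freq == 0


# Steps in 1..500 at which performance would be measured with frequency 100.
measured: List[int] = [step for step in range(1, 501) if should_measure(step, 100)]
print(measured)  # [100, 200, 300, 400, 500]
```

Note that under this rule step 0 also satisfies the condition, since 0 % n is 0 for any positive n.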

_on_batch_end()

Handle the end of a training batch.

Return type

None

_on_batch_start()

Handle the start of a training batch.

Return type

None

_on_epoch_end()

Handle the end of the training epoch.

Return type

None

_on_epoch_start()

Handle the start of the training epoch.

Return type

None

_on_exception()

Handle the exception.

Return type

None

_on_train_end()

Handle the end of training.

Return type

None

_on_train_start()

Handle the start of training.

Return type

None
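The _on_* hooks above describe training events. The sketch below shows one plausible dispatch order in a generic training loop; the order, the run_training function, and the HookRecorder callback are assumptions made for illustration, not ashpy's actual control flow.

```python
from typing import Callable, List


class HookRecorder:
    """Hypothetical callback that records the name of every event it receives."""

    def __init__(self) -> None:
        self.events: List[str] = []

    def on_event(self, name: str) -> None:
        self.events.append(name)


def run_training(epochs: int, batches_per_epoch: int, recorder: HookRecorder,
                 step: Callable[[], None] = lambda: None) -> None:
    """Fire the hooks in an assumed order around a toy training loop."""
    recorder.on_event("train_start")
    try:
        for _epoch in range(epochs):
            recorder.on_event("epoch_start")
            for _batch in range(batches_per_epoch):
                recorder.on_event("batch_start")
                step()  # the actual optimization step would run here
                recorder.on_event("batch_end")
            recorder.on_event("epoch_end")
    except Exception:
        recorder.on_event("exception")
        raise
    recorder.on_event("train_end")


recorder = HookRecorder()
run_training(epochs=1, batches_per_epoch=2, recorder=recorder)
print(recorder.events)
# ['train_start', 'epoch_start', 'batch_start', 'batch_end',
#  'batch_start', 'batch_end', 'epoch_end', 'train_end']
```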

_reduce(per_replica_tensor, reduce_op)

Reduce the input tensor in a distributed fashion, using the specified op.
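Conceptually, a distributed reduction combines one value per replica into a single value using an operation such as a sum or a mean (TensorFlow exposes these options as tf.distribute.ReduceOp). The plain-Python sketch below uses a hypothetical ReduceOp enum and a list of floats in place of a per-replica tensor.

```python
from enum import Enum
from typing import Sequence


class ReduceOp(Enum):
    """Hypothetical mirror of the SUM and MEAN options of tf.distribute.ReduceOp."""
    SUM = "sum"
    MEAN = "mean"


def reduce_across_replicas(per_replica: Sequence[float], reduce_op: ReduceOp) -> float:
    """Combine one value per replica into a single value."""
    total = sum(per_replica)
    if reduce_op is ReduceOp.SUM:
        return total
    if reduce_op is ReduceOp.MEAN:
        return total / len(per_replica)
    raise ValueError(f"unsupported reduce_op: {reduce_op}")


# Per-replica losses from two replicas.
print(reduce_across_replicas([0.5, 1.5], ReduceOp.SUM))   # 2.0
print(reduce_across_replicas([0.5, 1.5], ReduceOp.MEAN))  # 1.0
```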

_restore_or_init()

Restore or initialize the persistence layer (checkpoint).

_save()

Save the current checkpointable object status.
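The restore-or-initialize and save steps follow a common resume pattern: if a checkpoint exists, load it, otherwise start from a fresh state. In TensorFlow this is typically built on tf.train.Checkpoint; the sketch below swaps in a JSON file so it runs with the standard library only, and the function names and state layout are hypothetical.

```python
import json
import os
import tempfile
from typing import Dict


def restore_or_init(ckpt_path: str) -> Dict[str, int]:
    """Load the saved state if a checkpoint exists, otherwise start from scratch."""
    if os.path.exists(ckpt_path):
        with open(ckpt_path) as f:
            return json.load(f)
    return {"global_step": 0}


def save(ckpt_path: str, state: Dict[str, int]) -> None:
    """Persist the current state so a later run can resume from it."""
    with open(ckpt_path, "w") as f:
        json.dump(state, f)


logdir = tempfile.mkdtemp()
ckpt = os.path.join(logdir, "ckpt.json")

state = restore_or_init(ckpt)  # no checkpoint yet: {'global_step': 0}
state["global_step"] += 100
save(ckpt, state)

print(restore_or_init(ckpt))  # {'global_step': 100}
```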

_update_global_batch_size(dataset, executors=None)

Set the self._global_batch_size variable where needed.

Parameters
  • dataset (tf.data.Dataset) – A dataset from which the batch size will be extracted.

  • executors (Union[List[ashpy.losses.executor.Executor], ashpy.losses.executor.Executor]) – An executor, or a list of executors, with the property “global_batch_size”.
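The sketch below illustrates propagating a batch size to one or more executors. It is plain Python under stated assumptions: a dataset is modeled as a list of equally sized batches, the batch size is read from the first batch, and ToyExecutor is a hypothetical stand-in for an object exposing a global_batch_size attribute.

```python
from typing import List, Optional, Sequence, Union


class ToyExecutor:
    """Hypothetical stand-in for an executor exposing global_batch_size."""

    def __init__(self) -> None:
        self.global_batch_size: Optional[int] = None


def update_global_batch_size(
    dataset: Sequence[Sequence[float]],
    executors: Union[ToyExecutor, List[ToyExecutor], None] = None,
) -> int:
    """Read the batch size from the first batch and propagate it to executors."""
    global_batch_size = len(dataset[0])
    if executors is None:
        return global_batch_size
    if isinstance(executors, ToyExecutor):
        executors = [executors]  # accept a single executor as well as a list
    for executor in executors:
        executor.global_batch_size = global_batch_size
    return global_batch_size


batches = [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]
generator_loss, discriminator_loss = ToyExecutor(), ToyExecutor()
update_global_batch_size(batches, [generator_loss, discriminator_loss])
print(generator_loss.global_batch_size)  # 4
```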

_validate_callbacks()

Check if every callback is an ashpy.callbacks.Callback.

_validate_metrics()

Check if every metric is an ashpy.metrics.Metric.
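Both validators amount to an isinstance check over a list. In this plain-Python sketch, Callback and Metric are hypothetical stand-ins for the ashpy classes, and raising ValueError is an assumption: the exception type ashpy actually raises is not stated here.

```python
from typing import Iterable, Type


class Callback:
    """Hypothetical stand-in for ashpy.callbacks.Callback."""


class Metric:
    """Hypothetical stand-in for ashpy.metrics.Metric."""


def validate_all(items: Iterable[object], expected: Type) -> None:
    """Raise ValueError if any item is not an instance of the expected class."""
    for item in items:
        if not isinstance(item, expected):
            raise ValueError(f"{item!r} is not an instance of {expected.__name__}")


validate_all([Callback(), Callback()], Callback)  # passes silently
validate_all([Metric()], Metric)                  # passes silently

rejected = False
try:
    validate_all([Metric()], Callback)  # a Metric is not a Callback
except ValueError:
    rejected = True
print(rejected)  # True
```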

abstract call(*args, **kwargs)

Execute the training process.
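Because call is abstract, concrete trainers must override it, and the base class itself cannot be instantiated. The sketch below shows that abc pattern with a hypothetical MiniTrainer; having __call__ delegate to call() is an assumption made for the illustration.

```python
from abc import ABC, abstractmethod
from typing import Any, List


class MiniTrainer(ABC):
    """Hypothetical minimal analogue of an abstract trainer base class."""

    def __init__(self, epochs: int) -> None:
        self._epochs = epochs

    @abstractmethod
    def call(self, *args: Any, **kwargs: Any) -> Any:
        """Concrete trainers must implement the training process."""

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Assumption: invoking the instance delegates to call().
        return self.call(*args, **kwargs)


class CountingTrainer(MiniTrainer):
    """Toy subclass: 'trains' by counting the examples seen across epochs."""

    def call(self, dataset: List[int]) -> int:
        seen = 0
        for _epoch in range(self._epochs):
            for _example in dataset:
                seen += 1
        return seen


trainer = CountingTrainer(epochs=3)
print(trainer([10, 20]))  # 6

instantiated = True
try:
    MiniTrainer(epochs=1)  # abstract method not implemented -> TypeError
except TypeError:
    instantiated = False
print(instantiated)  # False
```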

property context

Return the training context.

Return type

Context

local_example(example, dims)

Return a local example from a distributed example.

Returns

A local example from a distributed example.

measure_metrics()

Measure the metrics.

Return type

None

model_selection()

Use the metrics to perform model selection.

Return type

None
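Metric-driven model selection usually means remembering the best value seen so far and saving the model whenever it improves. The plain-Python sketch below assumes lower is better (as for a loss); BestTracker is hypothetical, and a list of steps stands in for the actual saving of a model.

```python
from typing import List, Optional


class BestTracker:
    """Hypothetical model-selection helper that remembers the best metric value.

    Assumption: a lower value is better (e.g. a loss); flip the comparison for
    metrics where higher is better.
    """

    def __init__(self) -> None:
        self.best: Optional[float] = None
        self.saved_at: List[int] = []

    def update(self, step: int, value: float) -> bool:
        """Return True (and record a save) when value improves on the best so far."""
        if self.best is None or value < self.best:
            self.best = value
            self.saved_at.append(step)  # a real trainer would save the model here
            return True
        return False


tracker = BestTracker()
for step, loss in [(100, 0.9), (200, 0.7), (300, 0.8), (400, 0.5)]:
    tracker.update(step, loss)
print(tracker.best, tracker.saved_at)  # 0.5 [100, 200, 400]
```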