ClassifierTrainer

Inheritance Diagram

Inheritance diagram of ashpy.trainers.classifier.ClassifierTrainer

class ashpy.trainers.classifier.ClassifierTrainer(model, optimizer, loss, epochs, metrics=None, callbacks=None, logdir=PosixPath('log'), global_step=None)[source]

Bases: ashpy.trainers.trainer.Trainer

ClassifierTrainer provides the standard training loop for a classifier.

Methods

  • __init__(model, optimizer, loss, epochs[, …]) – Instantiate the ClassifierTrainer trainer.
  • call(training_set, validation_set[, …]) – Start the training.
  • train_step(features, labels) – Perform a single training step.

Attributes

  • ckpt_id_callbacks
  • ckpt_id_global_step
  • ckpt_id_model
  • ckpt_id_optimizer
  • ckpt_id_steps_per_epoch
  • context – Return the training context.
__init__(model, optimizer, loss, epochs, metrics=None, callbacks=None, logdir=PosixPath('log'), global_step=None)[source]

Instantiate the ClassifierTrainer trainer.

Parameters:
  • model (tf.keras.Model) – The Keras model to train.
  • optimizer (tf.optimizers.Optimizer) – Optimizer used to update the model.
  • loss (ClassifierLoss) – Loss executor used to compute and minimize the training loss.
  • epochs (int) – Number of training epochs.
  • metrics – List of ClassifierMetric objects measured during training and validation.
  • callbacks – List of callbacks invoked during the training loop.
  • logdir – Directory where checkpoints and logs are written.
  • global_step – Variable tracking the number of completed training steps.

Examples

import pathlib
import shutil

import tensorflow as tf

from ashpy.losses.classifier import ClassifierLoss
from ashpy.metrics import ClassifierMetric
from ashpy.trainers.classifier import ClassifierTrainer


def toy_dataset():
    # Label even inputs as 1 and odd inputs as 0.
    inputs = tf.expand_dims(tf.range(1, 1000.0), -1)
    labels = tf.expand_dims(
        [1 if tf.equal(tf.math.mod(tf.squeeze(i), 2), 0) else 0 for i in inputs], -1
    )
    return tf.data.Dataset.from_tensor_slices((inputs, labels)).shuffle(10).batch(2)


# A small two-layer network producing two-class logits.
model = tf.keras.Sequential(
    [tf.keras.layers.Dense(10, activation=tf.nn.sigmoid), tf.keras.layers.Dense(2)]
)
optimizer = tf.optimizers.Adam(1e-3)

# Wrap a Keras loss in ashpy's ClassifierLoss executor.
loss = ClassifierLoss(tf.losses.SparseCategoricalCrossentropy(from_logits=True))
logdir = "testlog"
epochs = 2

# Start from a clean log directory.
if pathlib.Path(logdir).exists():
    shutil.rmtree(logdir)

# Wrap Keras metrics so the trainer can measure and log them.
metrics = [
    ClassifierMetric(tf.metrics.Accuracy()),
    ClassifierMetric(tf.metrics.BinaryAccuracy()),
]

trainer = ClassifierTrainer(model=model,
                            optimizer=optimizer,
                            loss=loss,
                            epochs=epochs,
                            metrics=metrics,
                            logdir=logdir)
train, validation = toy_dataset(), toy_dataset()
trainer(train, validation)  # the trainer is callable: runs training and validation

shutil.rmtree(logdir)
Expected output:

Initializing checkpoint.
Starting epoch 1.
[500] Saved checkpoint: testlog/ckpts/ckpt-1
Epoch 1 completed.
Starting epoch 2.
[1000] Saved checkpoint: testlog/ckpts/ckpt-2
Epoch 2 completed.
Training finished after 2 epochs.
_build_and_restore_models(dataset)[source]

Build and restore a subclassed model by first calling it on some data.
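
This is needed because a subclassed Keras model creates its variables lazily, on its first call; until then a checkpoint restore has nothing to map saved values onto. A minimal sketch of the general pattern in plain TensorFlow (not ashpy's internal code; the checkpoint path reuses testlog/ckpts from the example above):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
# Variables do not exist until the model has been called once.
_ = model(tf.zeros((1, 1)))
# Only now can the checkpoint match saved values to the model's variables.
checkpoint = tf.train.Checkpoint(model=model)
checkpoint.restore(tf.train.latest_checkpoint("testlog/ckpts"))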

_train_step[source]

Perform the training step using the distribution strategy.
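
This method is internal. For intuition only, a distributed step in TensorFlow typically wraps a per-replica step with strategy.run and reduces the per-replica losses; the sketch below is a generic illustration, not ashpy's implementation (step_fn is a hypothetical stand-in):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def step_fn(features, labels):
    # Hypothetical per-replica step; a realistic one appears under train_step below.
    return tf.reduce_mean(tf.square(tf.cast(labels, tf.float32)))

@tf.function
def distributed_step(features, labels):
    # Run the per-replica step on every device, then sum the per-replica losses.
    per_replica_loss = strategy.run(step_fn, args=(features, labels))
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None)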

call(training_set, validation_set, log_freq=10, measure_performance_freq=10)[source]

Start the training.

Parameters:
  • training_set (tf.data.Dataset) – Training dataset.
  • validation_set (tf.data.Dataset) – Validation dataset.
  • log_freq (int) – Specifies how many steps to run before logging the losses, e.g. log_freq=10 logs every 10 steps of training. Pass log_freq<=0 if you don’t want to log.
  • measure_performance_freq (int) – Specifies how many steps to run before measuring the performance, e.g. measure_performance_freq=10 measures performance every 10 steps of training. Pass measure_performance_freq<=0 if you don’t want to measure performance.
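
For example, reusing the trainer and datasets from the example above, both frequencies can be tuned when starting the training (values here are illustrative):

trainer(
    train,
    validation,
    log_freq=50,                  # log losses every 50 steps
    measure_performance_freq=-1,  # <= 0 disables periodic performance measurement
)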
train_step(features, labels)[source]

Perform a single training step on a batch of features and labels.

Parameters:
  • features – Input features for the batch.
  • labels – Ground-truth labels for the batch.
Returns:

The loss value computed on the batch.
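
ashpy implements this step for you. For intuition, a classifier train step in plain TensorFlow usually follows the pattern below; this is a generic sketch, not ashpy's actual implementation, with model and optimizer defined as in the example above:

import tensorflow as tf

model = tf.keras.Sequential(
    [tf.keras.layers.Dense(10, activation=tf.nn.sigmoid), tf.keras.layers.Dense(2)]
)
optimizer = tf.optimizers.Adam(1e-3)
loss_fn = tf.losses.SparseCategoricalCrossentropy(from_logits=True)

def train_step(features, labels):
    # Forward pass under a gradient tape, then apply gradients.
    with tf.GradientTape() as tape:
        logits = model(features, training=True)
        loss_value = loss_fn(labels, logits)
    grads = tape.gradient(loss_value, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss_value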