ClassifierTrainer

Inheritance Diagram

[Figure: inheritance diagram of ashpy.trainers.classifier.ClassifierTrainer]

class ashpy.trainers.classifier.ClassifierTrainer(model, optimizer, loss, epochs, metrics=None, callbacks=None, logdir='/home/docs/checkouts/readthedocs.org/user_builds/ashpy/checkouts/v0.2.0/docs/source/log', global_step=None)

Bases: ashpy.trainers.trainer.Trainer

ClassifierTrainer provides the standard training loop for a classifier.
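As a rough picture of what a "standard training loop" means here, the pure-Python sketch below runs one pass over the training set per epoch, calls a train step on every batch, and saves a checkpoint tagged with the global step at the end of each epoch, which is the pattern visible in the example output further down. The function `run_training` and its `train_step` and `save_checkpoint` callables are hypothetical stand-ins, not ashpy's API; validation, logging, and metrics are omitted.

```python
from typing import Callable, Iterable, List, Tuple

Batch = Tuple[list, list]  # (features, labels)


def run_training(
    training_set: Iterable[Batch],
    epochs: int,
    train_step: Callable[[list, list], float],
    save_checkpoint: Callable[[int], None],
) -> List[float]:
    """Hypothetical sketch of a classifier training loop (not ashpy's code).

    `training_set` must be re-iterable (e.g. a list), since it is
    traversed once per epoch.
    """
    global_step = 0
    losses = []
    for _ in range(epochs):
        # One full pass over the training set per epoch.
        for features, labels in training_set:
            losses.append(train_step(features, labels))
            global_step += 1
        # Checkpoint tagged with the global step at the end of the epoch.
        save_checkpoint(global_step)
    return losses


# Usage: three batches per epoch, two epochs, a constant dummy loss.
saved = []
batches = [([1.0, 2.0], [0, 1]), ([3.0, 4.0], [0, 1]), ([5.0], [0])]
losses = run_training(batches, epochs=2,
                      train_step=lambda f, l: 0.5,
                      save_checkpoint=saved.append)
print(saved)  # [3, 6]
```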

Methods

__init__(model, optimizer, loss, epochs[, …])    Instantiate the ClassifierTrainer.
call(training_set, validation_set[, …])          Start the training.
train_step(features, labels)                     Perform a single training step.

Attributes

context                                          Return the training context.

__init__(model, optimizer, loss, epochs, metrics=None, callbacks=None, logdir='/home/docs/checkouts/readthedocs.org/user_builds/ashpy/checkouts/v0.2.0/docs/source/log', global_step=None)

Instantiate the ClassifierTrainer.

Parameters

Examples

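import shutil

import tensorflow as tf

from ashpy.losses.classifier import ClassifierLoss
from ashpy.metrics import ClassifierMetric
from ashpy.trainers.classifier import ClassifierTrainer


def toy_dataset():
    # Inputs 1.0, 2.0, ..., 999.0, each as a one-element feature vector.
    inputs = tf.expand_dims(tf.range(1, 1000.0), -1)
    # Label 1 for even inputs, 0 for odd inputs.
    labels = tf.expand_dims(
        [1 if tf.equal(tf.math.mod(tf.squeeze(i), 2), 0) else 0 for i in inputs], -1
    )
    return tf.data.Dataset.from_tensor_slices((inputs, labels)).shuffle(10).batch(2)


# Two-class model; the last layer has no activation, so it outputs logits.
model = tf.keras.Sequential(
    [tf.keras.layers.Dense(10, activation=tf.nn.sigmoid), tf.keras.layers.Dense(2)]
)
optimizer = tf.optimizers.Adam(1e-3)

# from_logits=True because the model outputs raw logits.
loss = ClassifierLoss(tf.losses.SparseCategoricalCrossentropy(from_logits=True))
logdir = "testlog"
epochs = 2

metrics = [
    ClassifierMetric(tf.metrics.Accuracy()),
    ClassifierMetric(tf.metrics.BinaryAccuracy()),
]

trainer = ClassifierTrainer(model=model,
                            optimizer=optimizer,
                            loss=loss,
                            epochs=epochs,
                            metrics=metrics,
                            logdir=logdir)
train, validation = toy_dataset(), toy_dataset()
# Calling the trainer starts training; checkpoints are written under logdir.
trainer(train, validation)
# Remove the logs and checkpoints written by this example.
shutil.rmtree(logdir)

Output:

Initializing checkpoint.
Starting epoch 1.
[500] Saved checkpoint: testlog/ckpts/ckpt-1
Epoch 1 completed.
Starting epoch 2.
[1000] Saved checkpoint: testlog/ckpts/ckpt-2
Epoch 2 completed.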
Training finished after 2 epochs.
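The step numbers in the checkpoint lines follow from the toy dataset's size: `tf.range(1, 1000.0)` yields 999 samples, and `batch(2)` keeps the final, smaller batch by default (`drop_remainder=False`), so one epoch is ceil(999 / 2) = 500 steps. The arithmetic below assumes, as the output suggests, that the step counter keeps increasing across epochs and that a checkpoint is saved at the end of each epoch.

```python
import math

num_samples = len(range(1, 1000))  # tf.range(1, 1000.0) yields 1.0 ... 999.0
batch_size = 2
epochs = 2

# batch() keeps the last partial batch by default (drop_remainder=False).
steps_per_epoch = math.ceil(num_samples / batch_size)

# Global step at the end of each epoch, i.e. the number in each checkpoint line.
checkpoint_steps = [steps_per_epoch * (epoch + 1) for epoch in range(epochs)]
print(steps_per_epoch, checkpoint_steps)  # 500 [500, 1000]
```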
_train_step

Perform the training step using the distribution strategy.
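Under a `tf.distribute` strategy the global batch is split across replicas and the per-replica results are reduced. A common convention in TensorFlow custom training loops is to scale each replica's summed per-example loss by the global batch size, so that a SUM reduction across replicas yields the mean over the whole batch. Whether `_train_step` follows exactly this convention is an assumption; the dependency-free sketch below, with hypothetical `split_batch` and `replica_loss` helpers, only checks the arithmetic.

```python
from typing import List


def split_batch(values: List[float], num_replicas: int) -> List[List[float]]:
    """Split a non-empty global batch into contiguous per-replica shards."""
    shard = -(-len(values) // num_replicas)  # ceiling division
    return [values[i : i + shard] for i in range(0, len(values), shard)]


def replica_loss(per_example: List[float], global_batch_size: int) -> float:
    """Sum this replica's per-example losses, scaled by the *global* batch size."""
    return sum(per_example) / global_batch_size


per_example_losses = [0.2, 0.4, 0.6, 0.8, 1.0, 1.2]
shards = split_batch(per_example_losses, num_replicas=2)
global_size = len(per_example_losses)

# Summing the scaled replica losses (a SUM reduction) gives the global mean.
total = sum(replica_loss(s, global_size) for s in shards)
print(round(total, 6))  # 0.7
```

Dividing by the per-replica batch size instead would make the summed result grow with the number of replicas, which is why the global batch size is used.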

call(training_set, validation_set, log_freq=10, measure_performance_freq=10)

Start the training.

Parameters
  • training_set (tf.data.Dataset) – Training dataset.

  • validation_set (tf.data.Dataset) – Validation dataset.

  • log_freq (int) – Specifies how many steps to run before logging the losses, e.g. log_freq=10 logs every 10 steps of training. Pass log_freq<=0 to disable logging.

  • measure_performance_freq (int) – Specifies how many steps to run before measuring the performance, e.g. measure_performance_freq=10 measures performance every 10 steps of training. Pass measure_performance_freq<=0 to disable performance measurement.
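The semantics of `log_freq` and `measure_performance_freq` can be captured by a single predicate. The sketch below assumes the check is a modulo test on a 1-based step counter; the helper name `should_run` is hypothetical, and the exact step indexing inside ashpy may differ.

```python
def should_run(step: int, freq: int) -> bool:
    """Return True if an action scheduled every `freq` steps fires at `step`.

    A non-positive `freq` disables the action entirely.
    """
    if freq <= 0:
        return False
    return step % freq == 0


# With log_freq=10, logging fires on steps 10, 20 and 30 of a 30-step run.
log_steps = [s for s in range(1, 31) if should_run(s, 10)]
print(log_steps)  # [10, 20, 30]

# measure_performance_freq=0 never fires.
print(any(should_run(s, 0) for s in range(1, 31)))  # False
```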

train_step(features, labels)

Perform a single training step on the given features and labels.

Parameters
  • features – Input features.

  • labels – The labels.

Returns

Loss value.
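A classifier training step computes predictions for a batch, evaluates the loss against the labels, applies one optimizer update, and returns the loss value. In ashpy the model, loss, and optimizer are the TensorFlow objects passed to the constructor; the dependency-free sketch below swaps in a one-feature logistic-regression model with hand-derived gradients and plain gradient descent, purely to show the shape of such a step. The `train_step` function, its parameters, and the learning rate are illustrative and are not ashpy's implementation.

```python
import math
from typing import List, Tuple


def train_step(
    params: Tuple[float, float],
    features: List[float],
    labels: List[int],
    learning_rate: float = 0.1,
) -> Tuple[Tuple[float, float], float]:
    """One gradient-descent step of binary logistic regression.

    Returns the updated (weight, bias) and the mean cross-entropy loss
    computed *before* the update.
    """
    w, b = params
    n = len(features)
    loss = grad_w = grad_b = 0.0
    for x, y in zip(features, labels):
        p = 1.0 / (1.0 + math.exp(-(w * x + b)))  # sigmoid prediction
        loss += -(y * math.log(p) + (1 - y) * math.log(1 - p))
        # For sigmoid + cross-entropy, d(loss)/d(logit) is (p - y).
        grad_w += (p - y) * x
        grad_b += p - y
    new_params = (w - learning_rate * grad_w / n, b - learning_rate * grad_b / n)
    return new_params, loss / n


params = (0.0, 0.0)
features, labels = [-2.0, -1.0, 1.0, 2.0], [0, 0, 1, 1]
params, first_loss = train_step(params, features, labels)
for _ in range(50):
    params, last_loss = train_step(params, features, labels)
print(round(first_loss, 4), last_loss < first_loss)  # 0.6931 True
```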