classifier

Primitive Trainer Interface.

Classes

ClassifierTrainer

ClassifierTrainer provides the standard training loop for a classifier.

class ashpy.trainers.classifier.ClassifierTrainer(model, optimizer, loss, epochs, metrics=None, logdir='/home/docs/checkouts/readthedocs.org/user_builds/ashpy/checkouts/v0.1.3/docs/source/log', global_step=<tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>, post_process_callback=None)[source]

Bases: ashpy.trainers.base_trainer.BaseTrainer

ClassifierTrainer provides the standard training loop for a classifier.

__init__(model, optimizer, loss, epochs, metrics=None, logdir='/home/docs/checkouts/readthedocs.org/user_builds/ashpy/checkouts/v0.1.3/docs/source/log', global_step=<tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>, post_process_callback=None)[source]

Instantiate the ClassifierTrainer.

Parameters
  • model (tf.keras.Model) – The tf.keras.Model to train.

  • optimizer (tf.optimizers.Optimizer) – A tf.optimizers.Optimizer.

  • loss (callable) – A loss function built following tf.losses.

  • epochs (int) – Number of training epochs.

  • metrics (List) – List of Python objects (dictionaries or tf.metrics objects) to measure on training and validation data.

  • logdir (str) – Checkpoint and log directory.

  • global_step (tf.Variable) – Variable that keeps track of the training steps.

  • post_process_callback (callable) – The function used to post-process the model output, if needed.
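
As an illustration of what a post_process_callback might do, the sketch below turns per-class logits into predicted class indices. The name logits_to_class_ids is hypothetical, and plain Python lists stand in for tensors so the snippet runs without TensorFlow; with tensors the same idea is usually written as tf.argmax(logits, axis=-1). When exactly the trainer applies the callback is not stated on this page.

```python
from typing import List, Sequence


def logits_to_class_ids(logits: Sequence[Sequence[float]]) -> List[int]:
    """Return, for each row of logits, the index of the largest value."""
    return [max(range(len(row)), key=row.__getitem__) for row in logits]


# Two examples, two classes each (hypothetical values).
batch_logits = [[2.0, -1.0], [0.3, 0.9]]
predicted = logits_to_class_ids(batch_logits)
print(predicted)
```

A callback of this shape is useful when a metric expects class indices while the model emits raw logits.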

Examples

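import shutil

import tensorflow as tf

# Import paths assumed from the ashpy package layout.
from ashpy.losses.classifier import ClassifierLoss
from ashpy.metrics.classifier import ClassifierMetric
from ashpy.trainers.classifier import ClassifierTrainer


def toy_dataset():
    # Inputs are the floats 1.0 to 999.0; the label is 1 for even inputs, 0 for odd ones.
    inputs = tf.expand_dims(tf.range(1, 1000.0), -1)
    labels = tf.expand_dims(
        [1 if tf.equal(tf.math.mod(tf.squeeze(i), 2), 0) else 0 for i in inputs], -1
    )
    return tf.data.Dataset.from_tensor_slices((inputs, labels)).shuffle(10).batch(2)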


model = tf.keras.Sequential(
    [tf.keras.layers.Dense(10, activation=tf.nn.sigmoid), tf.keras.layers.Dense(2)]
)
optimizer = tf.optimizers.Adam(1e-3)

loss = ClassifierLoss(tf.losses.SparseCategoricalCrossentropy(from_logits=True))
logdir = "testlog"
epochs = 2

metrics = [
    ClassifierMetric(tf.metrics.Accuracy()),
    ClassifierMetric(tf.metrics.BinaryAccuracy()),
]

trainer = ClassifierTrainer(model, optimizer, loss, epochs, metrics, logdir=logdir)
train, validation = toy_dataset(), toy_dataset()
trainer(train, validation)
shutil.rmtree(logdir)

Initializing checkpoint.
[500] Saved checkpoint: testlog/ckpts/ckpt-1
Epoch 1 completed.
[1000] Saved checkpoint: testlog/ckpts/ckpt-2
Epoch 2 completed.
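
The step numbers in the output above follow from the size of the toy dataset. The short calculation below reproduces them, assuming the global step advances once per batch and a checkpoint is written at the end of every epoch (both inferred from the log, not from the trainer's source). The helper name steps_per_epoch is hypothetical.

```python
import math


def steps_per_epoch(num_samples: int, batch_size: int) -> int:
    """Number of batches per epoch when the last, partial batch is kept."""
    return math.ceil(num_samples / batch_size)


num_samples = len(range(1, 1000))  # tf.range(1, 1000.0) yields 999 values
batch_size = 2                     # .batch(2) in toy_dataset()
epochs = 2

per_epoch = steps_per_epoch(num_samples, batch_size)
checkpoint_steps = [per_epoch * epoch for epoch in range(1, epochs + 1)]
print(per_epoch, checkpoint_steps)  # 500 [500, 1000]
```
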
_measure_performance(dataset)[source]

Measure and log metrics on the dataset.

_train_step[source]

The training step that uses the distribution strategy.

call(train_set, validation_set)[source]

Start the training.

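Parameters
  • train_set (tf.data.Dataset) – Training dataset.

  • validation_set (tf.data.Dataset) – Validation dataset.

train_step(features, labels)[source]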

Run one training step on a batch of features and labels.

Parameters
  • features – Input features.

  • labels – The labels.

Returns

Loss value.
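
To make the contract of a train step concrete (take features and labels, update the model, return the loss), here is a minimal plain-Python stand-in. It is not the ClassifierTrainer implementation: toy_train_step, the one-weight model y = weight * x, and the learning rate are all assumptions chosen so the snippet runs without TensorFlow. In TensorFlow 2 the same steps are typically written with tf.GradientTape and optimizer.apply_gradients.

```python
from typing import List, Tuple


def toy_train_step(
    weight: float,
    features: List[float],
    labels: List[float],
    learning_rate: float = 0.1,
) -> Tuple[float, float]:
    """One gradient-descent step for the toy model y = weight * x.

    Returns the updated weight and the mean squared error measured
    before the update.
    """
    n = len(features)
    predictions = [weight * x for x in features]            # forward pass
    errors = [p - y for p, y in zip(predictions, labels)]
    loss = sum(e * e for e in errors) / n                   # mean squared error
    # d(loss)/d(weight) = (2 / n) * sum(error * x)
    gradient = 2.0 * sum(e * x for e, x in zip(errors, features)) / n
    return weight - learning_rate * gradient, loss          # parameter update


weight = 0.0
features, labels = [1.0, 2.0], [2.0, 4.0]  # generated by y = 2 * x
losses = []
for _ in range(5):
    weight, loss = toy_train_step(weight, features, labels)
    losses.append(loss)
print(weight, losses)
```

The loss returned by each call shrinks as the weight approaches 2, which is the kind of value a trainer can log per step.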