classifier

Primitive Trainer Interface.

Classes
class ashpy.trainers.classifier.ClassifierTrainer(model, optimizer, loss, epochs, metrics=None, callbacks=None, logdir='/home/docs/checkouts/readthedocs.org/user_builds/ashpy/checkouts/v0.2.0/docs/source/log', global_step=None)

    Bases: ashpy.trainers.trainer.Trainer

    ClassifierTrainer provides the standard training loop for a classifier.
    __init__(model, optimizer, loss, epochs, metrics=None, callbacks=None, logdir='/home/docs/checkouts/readthedocs.org/user_builds/ashpy/checkouts/v0.2.0/docs/source/log', global_step=None)

        Instantiate the ClassifierTrainer trainer.

        Parameters
            model (tf.keras.Model) – A tf.keras.Model model.
            optimizer (tf.optimizers.Optimizer) – A tf.optimizers.Optimizer.
            loss (ashpy.losses.classifier.ClassifierLoss) – A loss function built following ashpy.executors.
            epochs (int) – Number of training epochs.
            metrics (Optional[List[Metric]]) – List of ashpy.metrics.metric.Metric to measure on training and validation data.
            callbacks (List) – List of ashpy.callbacks.callback.Callback to call on events.
            logdir (str) – Checkpoint and log directory.
            global_step (Optional[tf.Variable]) – tf.Variable that keeps track of the training steps.
Examples
    def toy_dataset():
        inputs = tf.expand_dims(tf.range(1, 1000.0), -1)
        labels = tf.expand_dims(
            [1 if tf.equal(tf.math.mod(tf.squeeze(i), 2), 0) else 0 for i in inputs], -1
        )
        return tf.data.Dataset.from_tensor_slices((inputs, labels)).shuffle(10).batch(2)

    model = tf.keras.Sequential(
        [tf.keras.layers.Dense(10, activation=tf.nn.sigmoid), tf.keras.layers.Dense(2)]
    )
    optimizer = tf.optimizers.Adam(1e-3)
    loss = ClassifierLoss(tf.losses.SparseCategoricalCrossentropy(from_logits=True))
    logdir = "testlog"
    epochs = 2
    metrics = [
        ClassifierMetric(tf.metrics.Accuracy()),
        ClassifierMetric(tf.metrics.BinaryAccuracy()),
    ]
    trainer = ClassifierTrainer(
        model=model,
        optimizer=optimizer,
        loss=loss,
        epochs=epochs,
        metrics=metrics,
        logdir=logdir,
    )
    train, validation = toy_dataset(), toy_dataset()
    trainer(train, validation)
    shutil.rmtree(logdir)
    Initializing checkpoint.
    Starting epoch 1.
    [500] Saved checkpoint: testlog/ckpts/ckpt-1
    Epoch 1 completed.
    Starting epoch 2.
    [1000] Saved checkpoint: testlog/ckpts/ckpt-2
    Epoch 2 completed.
    Training finished after 2 epochs.
    call(training_set, validation_set, log_freq=10, measure_performance_freq=10)

        Start the training.

        Parameters

            training_set (tf.data.Dataset) – Training dataset.
            validation_set (tf.data.Dataset) – Validation dataset.
            log_freq (int) – Specifies how many steps to run before logging the losses, e.g. log_freq=10 logs every 10 steps of training. Pass log_freq<=0 if you don't want to log.
            measure_performance_freq (int) – Specifies how many steps to run before measuring the performance, e.g. measure_performance_freq=10 measures performance every 10 steps of training. Pass measure_performance_freq<=0 if you don't want to measure performance.
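The two frequency parameters both follow the same convention: an action fires every N steps, and a non-positive value disables it entirely. The sketch below illustrates that gating logic in plain Python; it is not ashpy's implementation, and the function name run_steps is a hypothetical stand-in for the trainer's internal loop.

```python
def run_steps(total_steps, log_freq=10, measure_performance_freq=10):
    """Illustrative only: show which steps would trigger logging and
    performance measurement under the frequency convention described
    above (every N steps; <=0 disables)."""
    logged, measured = [], []
    for step in range(1, total_steps + 1):
        # Fire only when the frequency is positive and the step is a multiple of it.
        if log_freq > 0 and step % log_freq == 0:
            logged.append(step)
        if measure_performance_freq > 0 and step % measure_performance_freq == 0:
            measured.append(step)
    return logged, measured
```

For example, run_steps(30, log_freq=10, measure_performance_freq=15) logs at steps 10, 20, 30 and measures at steps 15, 30, while a non-positive frequency yields no events at all.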