classifier
Primitive Trainer Interface.
Classes
class ashpy.trainers.classifier.ClassifierTrainer(model, optimizer, loss, epochs, metrics=None, logdir='/home/docs/checkouts/readthedocs.org/user_builds/ashpy/checkouts/v0.1.3/docs/source/log', global_step=<tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>, post_process_callback=None)

Bases: ashpy.trainers.base_trainer.BaseTrainer
ClassifierTrainer provides the standard training loop for a classifier.
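To make "standard training loop" concrete, here is a minimal framework-free sketch of the structure such a trainer automates: iterate over epochs, take one optimization step per mini-batch while incrementing a global step counter, and evaluate on validation data at the end of each epoch. It is written in NumPy (softmax regression trained with plain SGD on sparse cross-entropy computed from logits) rather than TensorFlow. The function names, learning rate, and toy data are assumptions for illustration; this is not ashpy's implementation.

```python
import numpy as np


def sparse_xent_from_logits(logits, labels):
    """Mean sparse categorical cross-entropy computed from raw logits."""
    # Subtract the row-wise max for numerical stability before log-softmax.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    loss = -log_probs[np.arange(len(labels)), labels].mean()
    return loss, log_probs


def train_classifier(train, validation, epochs, batch_size=2, lr=0.1,
                     num_classes=2, seed=0):
    """Hypothetical sketch of an epoch/batch training loop with validation."""
    rng = np.random.default_rng(seed)
    x_train, y_train = train
    x_val, y_val = validation
    w = np.zeros((x_train.shape[1], num_classes))
    b = np.zeros(num_classes)
    global_step = 0  # counts optimization steps (batches) across all epochs
    history = []
    for epoch in range(epochs):
        order = rng.permutation(len(x_train))  # reshuffle every epoch
        for start in range(0, len(order), batch_size):
            idx = order[start:start + batch_size]
            xb, yb = x_train[idx], y_train[idx]
            _, log_probs = sparse_xent_from_logits(xb @ w + b, yb)
            # d(mean loss)/d(logits) = (softmax - one_hot) / batch_size
            grad = np.exp(log_probs)
            grad[np.arange(len(yb)), yb] -= 1.0
            grad /= len(yb)
            w -= lr * (xb.T @ grad)
            b -= lr * grad.sum(axis=0)
            global_step += 1
        # End of epoch: measure accuracy on the validation data.
        val_acc = float(np.mean(np.argmax(x_val @ w + b, axis=1) == y_val))
        history.append((global_step, val_acc))
        print(f"[{global_step}] Epoch {epoch + 1} completed, "
              f"validation accuracy {val_acc:.3f}")
    return w, b, history


# Toy data: 999 scalar inputs in [-1, 1], label 1 when the input is positive.
x = np.linspace(-1.0, 1.0, 999).reshape(-1, 1)
y = (x[:, 0] > 0).astype(np.int64)
w, b, history = train_classifier((x, y), (x, y), epochs=2)
```

The global step advances once per batch, not once per epoch, so with 999 examples and a batch size of 2 it reaches 500 after the first epoch and 1000 after the second.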
__init__(model, optimizer, loss, epochs, metrics=None, logdir='/home/docs/checkouts/readthedocs.org/user_builds/ashpy/checkouts/v0.1.3/docs/source/log', global_step=<tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>, post_process_callback=None)

Instantiate the ClassifierTrainer.

Parameters
- model (tf.keras.Model) – The tf.keras.Model to train.
- optimizer (tf.optimizers.Optimizer) – A tf.optimizers.Optimizer.
- loss (callable) – A loss function built following tf.losses.
- epochs (int) – Number of training epochs.
- metrics (list) – List of Python objects (dictionaries or tf.metrics objects) to measure on training and validation data.
- logdir (str) – Checkpoint and log directory.
- global_step (tf.Variable) – Variable that keeps track of the training steps.
- post_process_callback (callable) – Function used to post-process the model output, if needed.
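The model in the Examples section ends in a Dense(2) layer, so it emits raw logits, while an accuracy metric compares class ids. The NumPy sketch below shows the kind of function one might pass as post_process_callback to bridge that gap. The name argmax_postprocess, and the assumption that the callback receives the raw model output and returns the processed output, are illustrative only: this page does not specify the callback's exact signature.

```python
import numpy as np


def argmax_postprocess(logits):
    """Hypothetical post-processing callback: map raw logits to class ids."""
    # Pick the index of the largest logit along the class axis.
    return np.argmax(logits, axis=-1)


def accuracy(labels, predictions):
    """Fraction of predictions equal to the labels."""
    return float(np.mean(np.asarray(labels) == np.asarray(predictions)))


logits = np.array([[2.0, -1.0], [0.3, 0.9], [-0.5, 0.1], [1.2, 1.1]])
labels = np.array([0, 1, 1, 1])
predictions = argmax_postprocess(logits)  # -> array([0, 1, 1, 0])
print(accuracy(labels, predictions))      # -> 0.75
```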
Examples
```python
import shutil

import tensorflow as tf

from ashpy.losses.classifier import ClassifierLoss
from ashpy.metrics.classifier import ClassifierMetric
from ashpy.trainers.classifier import ClassifierTrainer


def toy_dataset():
    # Inputs 1.0 ... 999.0, one feature per example.
    inputs = tf.expand_dims(tf.range(1, 1000.0), -1)
    # Label 1 for even inputs, 0 for odd inputs.
    labels = tf.expand_dims(
        [1 if tf.equal(tf.math.mod(tf.squeeze(i), 2), 0) else 0 for i in inputs], -1
    )
    return tf.data.Dataset.from_tensor_slices((inputs, labels)).shuffle(10).batch(2)


# Two-layer model that outputs raw logits for 2 classes.
model = tf.keras.Sequential(
    [tf.keras.layers.Dense(10, activation=tf.nn.sigmoid), tf.keras.layers.Dense(2)]
)
optimizer = tf.optimizers.Adam(1e-3)
loss = ClassifierLoss(tf.losses.SparseCategoricalCrossentropy(from_logits=True))
logdir = "testlog"
epochs = 2
metrics = [
    ClassifierMetric(tf.metrics.Accuracy()),
    ClassifierMetric(tf.metrics.BinaryAccuracy()),
]

trainer = ClassifierTrainer(model, optimizer, loss, epochs, metrics, logdir=logdir)
train, validation = toy_dataset(), toy_dataset()
trainer(train, validation)

# Remove the checkpoints and logs written during the run.
shutil.rmtree(logdir)
```

Output:

```text
Initializing checkpoint.
[500] Saved checkpoint: testlog/ckpts/ckpt-1
Epoch 1 completed.
[1000] Saved checkpoint: testlog/ckpts/ckpt-2
Epoch 2 completed.
```
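The step numbers in the log follow from the toy dataset: tf.range(1, 1000.0) yields the 999 values 1.0 to 999.0, and batch(2) groups them into ceil(999 / 2) = 500 batches, the last one holding a single element (shuffle(10) reorders elements but does not change the count). Assuming one optimization step per batch, which is consistent with the log, the global step is 500 after epoch 1 and 1000 after epoch 2. A quick check of the arithmetic:

```python
import math

num_samples = len(range(1, 1000))  # same count as tf.range(1, 1000.0): 999 values
batch_size = 2
steps_per_epoch = math.ceil(num_samples / batch_size)  # last batch is partial

# Global step reached at the end of each epoch, assuming one step per batch.
steps_at_epoch_end = [steps_per_epoch * (e + 1) for e in range(2)]
print(steps_per_epoch, steps_at_epoch_end)  # -> 500 [500, 1000]
```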
call(train_set, validation_set)

Start the training.

Parameters
- train_set (tf.data.Dataset) – Training dataset.
- validation_set (tf.data.Dataset) – Validation dataset.
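The Examples section invokes the trainer instance directly, trainer(train, validation), while the method documented here is call. One common Python way to make an object callable while keeping the logic in a named method is a base class whose __call__ forwards to call. The classes below are hypothetical and only illustrate that pattern; this page does not show how BaseTrainer actually dispatches the call.

```python
class SketchBaseTrainer:
    """Hypothetical base class: calling an instance forwards to call()."""

    def __call__(self, *args, **kwargs):
        return self.call(*args, **kwargs)

    def call(self, *args, **kwargs):
        raise NotImplementedError("subclasses implement the training loop")


class SketchClassifierTrainer(SketchBaseTrainer):
    """Hypothetical subclass that records which datasets it was given."""

    def call(self, train_set, validation_set):
        # A real trainer would run its training loop here.
        return {"train": train_set, "validation": validation_set}


trainer = SketchClassifierTrainer()
result = trainer([1, 2, 3], [4, 5])  # same as trainer.call([1, 2, 3], [4, 5])
```

Keeping the work in call and the dispatch in __call__ lets every subclass define only its own loop while all trainers share the same invocation syntax.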