save_callback

Save weights callback.

Classes

SaveCallback

Save Callback implementation.

SaveFormat

Save Format enum.

SaveSubFormat

Save Sub-Format enum.

class ashpy.callbacks.save_callback.SaveCallback(save_dir, models, event=<Event.ON_EPOCH_END: 'ON_EPOCH_END'>, event_freq=1, max_to_keep=1, save_format=<SaveFormat.MODEL|WEIGHTS: 3>, save_sub_format=<SaveSubFormat.TF: 'tf'>, verbose=0, name='SaveCallback')

Bases: ashpy.callbacks.counter_callback.CounterCallback

Save Callback implementation.

Examples

from ashpy import callbacks, models

generator = models.gans.ConvGenerator(
    layer_spec_input_res=(7, 7),
    layer_spec_target_res=(28, 28),
    kernel_size=(5, 5),
    initial_filters=32,
    filters_cap=16,
    channels=1,
)

discriminator = models.gans.ConvDiscriminator(
    layer_spec_input_res=(28, 28),
    layer_spec_target_res=(7, 7),
    kernel_size=(5, 5),
    initial_filters=16,
    filters_cap=32,
    output_shape=1,
)

# Use a distinct name so the `models` module is not shadowed
models_to_save = [generator, discriminator]

# Save only the weights of both models, using the TensorFlow sub-format
save_callback = callbacks.SaveCallback(
    save_dir="testlog/savedir",
    models=models_to_save,
    save_format=callbacks.SaveFormat.WEIGHTS,
    save_sub_format=callbacks.SaveSubFormat.TF,
)

# Initialize the trainer, passing save_callback in its list of callbacks
__init__(save_dir, models, event=<Event.ON_EPOCH_END: 'ON_EPOCH_END'>, event_freq=1, max_to_keep=1, save_format=<SaveFormat.MODEL|WEIGHTS: 3>, save_sub_format=<SaveSubFormat.TF: 'tf'>, verbose=0, name='SaveCallback')

Build a Save Callback.

Save callbacks save the models when the configured event occurs (every event_freq occurrences). Two save formats are available, weights and model, and they can be combined: the default, SaveFormat.MODEL | SaveFormat.WEIGHTS, saves both. Independently, you can choose one of two save sub-formats: tf or h5. The saved files are placed in save_dir, under a weights or a model subdirectory depending on the format.
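As a rough illustration of how a combined format maps onto subdirectories, the sketch below re-creates the two flag values listed for SaveFormat with the standard library's enum.Flag and derives the directories a format would write to. The stand-in class, the helper target_dirs, and the exact path layout are assumptions for illustration, not the ashpy implementation.

```python
import enum
import os
from typing import List


class SaveFormat(enum.Flag):
    """Local stand-in using the documented values (not ashpy's class)."""

    WEIGHTS = 1
    MODEL = 2


def target_dirs(save_dir: str, save_format: SaveFormat) -> List[str]:
    """Hypothetical helper: subdirectories a given format would write to."""
    dirs = []
    for member in (SaveFormat.WEIGHTS, SaveFormat.MODEL):
        # Flag membership: True if this bit is set in the (possibly combined) format
        if member in save_format:
            dirs.append(os.path.join(save_dir, member.name.lower()))
    return dirs


default_format = SaveFormat.MODEL | SaveFormat.WEIGHTS
print(default_format.value)  # 3, matching the default in the signature
print(target_dirs("testlog/savedir", default_format))
print(target_dirs("testlog/savedir", SaveFormat.WEIGHTS))
```

Using a Flag rather than a plain Enum is what lets a single save_format argument request both outputs at once.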

Parameters
  • save_dir (str) – directory in which to save the weights or the model.

  • models (List[tf.keras.models.Model]) – list of models to save.

  • event (ashpy.callbacks.events.Event) – event on which to trigger the saving operation.

  • event_freq (int) – how often to save: the models are saved every event_freq occurrences of event.

  • name (str) – name of the callback.

  • verbose (int) – verbosity of the callback (0 or 1).

  • max_to_keep (int) – maximum number of saves to keep. With max_to_keep=1 only the most recent save is kept; in general, the max_to_keep most recent saves are kept.

  • save_format (ashpy.callbacks.save_callback.SaveFormat) – format to save in: SaveFormat.WEIGHTS, SaveFormat.MODEL, or both combined with |.

  • save_sub_format (ashpy.callbacks.save_callback.SaveSubFormat) – sub-format to save in: SaveSubFormat.TF or SaveSubFormat.H5.
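The interaction between event and event_freq can be pictured as a counter that fires on every event_freq-th occurrence of the event. The sketch below assumes the count starts at 1 and that a save fires whenever the count is divisible by event_freq; the class EventCounter is hypothetical and only approximates the CounterCallback base class.

```python
from typing import List


class EventCounter:
    """Hypothetical approximation of a counter-based trigger."""

    def __init__(self, event_freq: int = 1) -> None:
        if event_freq < 1:
            raise ValueError("event_freq must be a positive integer")
        self.event_freq = event_freq
        self._count = 0  # number of times the event has occurred so far

    def on_event(self) -> bool:
        """Register one occurrence of the event; return True if a save should fire."""
        self._count += 1
        return self._count % self.event_freq == 0


# With event=ON_EPOCH_END and event_freq=3, saves would fire on epochs 3, 6 and 9
counter = EventCounter(event_freq=3)
fired: List[int] = [epoch for epoch in range(1, 11) if counter.on_event()]
print(fired)  # [3, 6, 9]
```

With the default event_freq=1, every occurrence of the event triggers a save.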

_cleanup()

Remove older saves so that at most max_to_keep are kept.

_save_weights_fn(step)

Save weights.

Parameters

step (int) – current step.

save_weights_fn(context)

Save weights and clean up if needed.
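A possible shape for the clean-up step that enforces max_to_keep on disk is sketched below. It assumes every save lives in a subdirectory named after its integer step (for example save_dir/100), which is an illustrative naming scheme and not necessarily what SaveCallback writes; cleanup_old_saves is a hypothetical name.

```python
import os
import shutil
import tempfile
from typing import List


def cleanup_old_saves(save_dir: str, max_to_keep: int) -> List[str]:
    """Hypothetical clean-up: delete all but the max_to_keep newest step directories."""
    if max_to_keep < 0:
        raise ValueError("max_to_keep must be non-negative")
    # Sort numerically so that "1000" comes after "200"
    steps = sorted(
        int(entry)
        for entry in os.listdir(save_dir)
        if entry.isdigit() and os.path.isdir(os.path.join(save_dir, entry))
    )
    # steps[:-0] would be empty, so handle max_to_keep == 0 explicitly
    to_delete = steps[:-max_to_keep] if max_to_keep > 0 else steps
    for step in to_delete:
        shutil.rmtree(os.path.join(save_dir, str(step)))
    return [str(step) for step in to_delete]


with tempfile.TemporaryDirectory() as tmp:
    for step in (100, 200, 1000):
        os.makedirs(os.path.join(tmp, str(step)))
    removed = cleanup_old_saves(tmp, max_to_keep=2)
    remaining = sorted(os.listdir(tmp), key=int)
    print(removed)  # ['100']
    print(remaining)  # ['200', '1000']
```

Sorting by the integer step rather than by name avoids deleting the wrong directory once step numbers gain extra digits.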

class ashpy.callbacks.save_callback.SaveFormat

Bases: enum.Flag

Save Format enum.

MODEL = 2

Model format (weights and architecture), saved using model.save().

WEIGHTS = 1

Weights format, saved using model.save_weights().

name()

Name of the format.

Return type

str

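save(model, save_dir, save_sub_format=<SaveSubFormat.TF: 'tf'>)

Save the model using the correct format and sub-format.

Parameters
  • model – model to save.

  • save_dir – directory in which to save the model.

  • save_sub_format (ashpy.callbacks.save_callback.SaveSubFormat) – sub-format of the saving (tf or h5).

Return type

None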
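One way such a dispatch could look is sketched below: each set flag maps to a Keras saving method (model.save() for MODEL and model.save_weights() for WEIGHTS, as stated above), and the sub-format value is forwarded as the save_format string, which tf.keras in TensorFlow 2.x accepts as "tf" or "h5". The function save_with_format, the path layout, and the stand-in RecordingModel (used so the sketch runs without TensorFlow) are assumptions, not ashpy's code.

```python
import enum
import os
from typing import Any, List, Tuple


class SaveFormat(enum.Flag):
    """Local stand-in with the documented values."""

    WEIGHTS = 1
    MODEL = 2


class SaveSubFormat(enum.Enum):
    """Local stand-in with the documented values."""

    TF = "tf"
    H5 = "h5"


def save_with_format(
    model: Any,
    save_dir: str,
    save_format: SaveFormat,
    save_sub_format: SaveSubFormat = SaveSubFormat.TF,
) -> None:
    """Hypothetical dispatcher: call the Keras method matching each set flag."""
    if SaveFormat.WEIGHTS in save_format:
        model.save_weights(
            os.path.join(save_dir, "weights"), save_format=save_sub_format.value
        )
    if SaveFormat.MODEL in save_format:
        model.save(os.path.join(save_dir, "model"), save_format=save_sub_format.value)


class RecordingModel:
    """Stand-in for tf.keras.Model that records calls instead of writing files."""

    def __init__(self) -> None:
        self.calls: List[Tuple[str, str, str]] = []

    def save(self, filepath: str, save_format: str = "tf") -> None:
        self.calls.append(("save", filepath, save_format))

    def save_weights(self, filepath: str, save_format: str = "tf") -> None:
        self.calls.append(("save_weights", filepath, save_format))


model = RecordingModel()
save_with_format(model, "out", SaveFormat.MODEL | SaveFormat.WEIGHTS, SaveSubFormat.H5)
for call in model.calls:
    print(call)
```

With a real tf.keras model the same calls would write files; Keras 3 no longer takes a save_format argument, so this mapping applies to TensorFlow 2.x tf.keras.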

class ashpy.callbacks.save_callback.SaveSubFormat

Bases: enum.Enum

Save Sub-Format enum.

H5 = 'h5'

HDF5 format.

TF = 'tf'

TensorFlow format.
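Because SaveSubFormat is a plain enum.Enum whose values are the strings 'tf' and 'h5', a configuration string can be turned into a member by value lookup. The sketch uses a local stand-in class and a hypothetical parse_sub_format helper; it does not import ashpy.

```python
import enum


class SaveSubFormat(enum.Enum):
    """Local stand-in with the documented values (not ashpy's class)."""

    H5 = "h5"
    TF = "tf"


def parse_sub_format(text: str) -> SaveSubFormat:
    """Hypothetical helper: map a user-supplied string to a sub-format member."""
    try:
        # Calling the Enum class looks a member up by its value
        return SaveSubFormat(text.strip().lower())
    except ValueError:
        allowed = ", ".join(member.value for member in SaveSubFormat)
        raise ValueError(
            f"unknown sub-format {text!r}; expected one of: {allowed}"
        ) from None


print(parse_sub_format("H5"))  # SaveSubFormat.H5
```

Normalizing case and whitespace before the lookup is a design choice of this sketch; the enum itself only matches the exact lowercase values.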