SaveCallback

Inheritance Diagram

Inheritance diagram of ashpy.callbacks.save_callback.SaveCallback

class ashpy.callbacks.save_callback.SaveCallback(save_dir, models, event=<Event.ON_EPOCH_END: 'ON_EPOCH_END'>, event_freq=1, max_to_keep=1, save_format=<SaveFormat.MODEL|WEIGHTS: 3>, save_sub_format=<SaveSubFormat.TF: 'tf'>, verbose=0, name='SaveCallback')

Bases: ashpy.callbacks.counter_callback.CounterCallback

Callback that saves one or more models, as weights, as full models, or both, when a chosen event occurs.

Examples

from ashpy import callbacks, models

generator = models.gans.ConvGenerator(
    layer_spec_input_res=(7, 7),
    layer_spec_target_res=(28, 28),
    kernel_size=(5, 5),
    initial_filters=32,
    filters_cap=16,
    channels=1,
)

discriminator = models.gans.ConvDiscriminator(
    layer_spec_input_res=(28, 28),
    layer_spec_target_res=(7, 7),
    kernel_size=(5, 5),
    initial_filters=16,
    filters_cap=32,
    output_shape=1,
)

# Use a distinct name for the list so it does not shadow the ashpy.models module.
models_to_save = [generator, discriminator]

save_callback = callbacks.SaveCallback(
    save_dir="testlog/savedir",
    models=models_to_save,
    save_format=callbacks.SaveFormat.WEIGHTS,
    save_sub_format=callbacks.SaveSubFormat.TF,
)

# Initialize a trainer, passing save_callback among its callbacks.

Methods

__init__(save_dir, models[, event, …]) – Build a Save Callback.
save_weights_fn(context) – Save weights and clean up if needed.

Attributes

name – Returns the name of this module as passed or determined in the constructor.
name_scope – Returns a tf.name_scope instance for this class.
submodules – Sequence of all sub-modules.
trainable_variables – Sequence of trainable variables owned by this module and its submodules.
variables – Sequence of variables owned by this module and its submodules.

__init__(save_dir, models, event=<Event.ON_EPOCH_END: 'ON_EPOCH_END'>, event_freq=1, max_to_keep=1, save_format=<SaveFormat.MODEL|WEIGHTS: 3>, save_sub_format=<SaveSubFormat.TF: 'tf'>, verbose=0, name='SaveCallback')

Build a Save Callback.

A Save Callback saves the given models when the chosen event occurs (subject to event_freq). Two save formats are available, weights and model, and they can be combined: the default, SaveFormat.MODEL | SaveFormat.WEIGHTS, saves both. Each format can be written in one of two sub-formats, tf or h5. The saved files are placed in save_dir, under a weights or model directory depending on the format.
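The default save_format is shown in the signature as <SaveFormat.MODEL|WEIGHTS: 3>, which indicates that SaveFormat behaves like a flag enum whose members combine with |. The pure-Python sketch below mirrors that idea to show which directories a given format would target. The stand-in SaveFormat class, its member values, and the target_dirs helper are assumptions for illustration, not ashpy's code.

```python
from enum import Flag, auto
from pathlib import Path
from typing import List


class SaveFormat(Flag):
    """Hypothetical stand-in for ashpy's SaveFormat (member values are assumed)."""

    WEIGHTS = auto()  # assumed value 1
    MODEL = auto()    # assumed value 2


def target_dirs(save_dir: str, save_format: SaveFormat) -> List[Path]:
    """Hypothetical helper: list the directories a save with save_format would write to."""
    dirs = []
    for member, subdir in ((SaveFormat.WEIGHTS, "weights"), (SaveFormat.MODEL, "model")):
        if member in save_format:  # flag membership: is this bit set?
            dirs.append(Path(save_dir) / subdir)
    return dirs


default = SaveFormat.MODEL | SaveFormat.WEIGHTS
print(default.value)  # 3
print([p.as_posix() for p in target_dirs("testlog/savedir", default)])
# ['testlog/savedir/weights', 'testlog/savedir/model']
```

Passing SaveFormat.WEIGHTS alone, as in the example above, would target only the weights directory.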

Parameters:
  • save_dir (str) – directory in which to save the weights and/or the models.
  • models (List[tf.keras.models.Model]) – list of models to save.
  • event (ashpy.callbacks.events.Event) – event on which to trigger the saving operation.
  • event_freq (int) – frequency of the saving operation: a save happens once every event_freq occurrences of event.
  • max_to_keep (int) – maximum number of saved files to keep. Only the max_to_keep most recent files are kept; with max_to_keep=1, only the most recent one is kept.
  • save_format (ashpy.callbacks.save_callback.SaveFormat) – what to save: weights, model, or both combined with |.
  • save_sub_format (ashpy.callbacks.save_callback.SaveSubFormat) – serialization sub-format, tf or h5.
  • verbose (int) – verbosity of the callback (0 or 1).
  • name (str) – name of the callback.
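The interplay of event_freq and max_to_keep can be modeled without TensorFlow. The sketch below assumes that a save fires on every event_freq-th occurrence of the event (counting from 1) and that only the newest max_to_keep saves survive; simulate_saves is a hypothetical helper, not part of ashpy.

```python
from collections import deque
from typing import List, Tuple


def simulate_saves(num_events: int, event_freq: int, max_to_keep: int) -> Tuple[List[int], List[int]]:
    """Hypothetical model of event_freq / max_to_keep semantics.

    Returns (saved, kept): the 1-based event indices at which a save happens,
    and the indices whose saves would still be on disk at the end.
    """
    saved: List[int] = []
    kept: deque = deque(maxlen=max_to_keep)  # appending beyond maxlen drops the oldest entry
    for event_index in range(1, num_events + 1):
        if event_index % event_freq == 0:  # fire once every event_freq events
            saved.append(event_index)
            kept.append(event_index)
    return saved, list(kept)


print(simulate_saves(num_events=10, event_freq=3, max_to_keep=2))
# ([3, 6, 9], [6, 9])
```

With the defaults (event_freq=1, max_to_keep=1) every event triggers a save and only the last one is retained.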

_cleanup()

Remove older saved files so that at most max_to_keep of them are kept.
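Assuming _cleanup's job is to enforce max_to_keep on disk, a minimal standalone sketch could look like the following. The one-directory-per-step layout and the cleanup_old_saves helper are assumptions for illustration, not ashpy's actual storage scheme.

```python
import shutil
import tempfile
from pathlib import Path
from typing import List


def cleanup_old_saves(save_root: Path, max_to_keep: int) -> List[Path]:
    """Hypothetical cleanup: delete the oldest step directories beyond max_to_keep.

    Assumes one subdirectory per save, named after its integer step.
    Returns the directories that were removed.
    """
    if max_to_keep < 1:
        raise ValueError("max_to_keep must be at least 1")
    step_dirs = sorted(
        (p for p in save_root.iterdir() if p.is_dir() and p.name.isdigit()),
        key=lambda p: int(p.name),  # numeric order, so "10" comes after "9"
    )
    to_remove = step_dirs[:-max_to_keep]  # everything except the newest max_to_keep
    for path in to_remove:
        shutil.rmtree(path)
    return to_remove


# Usage: three saves on disk, keep the two most recent.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    for step in (2, 9, 10):
        (root / str(step)).mkdir()
    removed = cleanup_old_saves(root, max_to_keep=2)
    print([p.name for p in removed])  # ['2']
    print(sorted(int(p.name) for p in root.iterdir()))  # [9, 10]
```

Sorting by the integer step rather than by directory name matters here: lexicographically "10" sorts before "2" and "9", so a name-based sort would delete the newest save.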

_save_weights_fn(step)

Save weights.

Parameters:
  • step (int) – current step.

save_weights_fn(context)

Save weights and clean up if needed.