SaveCallback
class ashpy.callbacks.save_callback.SaveCallback(save_dir, models, event=<Event.ON_EPOCH_END: 'ON_EPOCH_END'>, event_freq=1, max_to_keep=1, save_format=<SaveFormat.MODEL|WEIGHTS: 3>, save_sub_format=<SaveSubFormat.TF: 'tf'>, verbose=0, name='SaveCallback')

Bases: ashpy.callbacks.counter_callback.CounterCallback
Save Callback implementation.
Examples
```python
from ashpy import callbacks, models

generator = models.gans.ConvGenerator(
    layer_spec_input_res=(7, 7),
    layer_spec_target_res=(28, 28),
    kernel_size=(5, 5),
    initial_filters=32,
    filters_cap=16,
    channels=1,
)

discriminator = models.gans.ConvDiscriminator(
    layer_spec_input_res=(28, 28),
    layer_spec_target_res=(7, 7),
    kernel_size=(5, 5),
    initial_filters=16,
    filters_cap=32,
    output_shape=1,
)

# Save only the weights of both models, using the TensorFlow ("tf") sub-format.
# The list is passed inline so the `models` module is not shadowed.
save_callback = callbacks.SaveCallback(
    save_dir="testlog/savedir",
    models=[generator, discriminator],
    save_format=callbacks.SaveFormat.WEIGHTS,
    save_sub_format=callbacks.SaveSubFormat.TF,
)

# initialize trainer passing the save_callback
```
Methods

- __init__(save_dir, models[, event, …]) – Build a Save Callback.
- save_weights_fn(context) – Save weights and clean up if needed.

Attributes
- name – Return the name of the callback.
- name_scope – Returns a tf.name_scope instance for this class.
- submodules – Sequence of all sub-modules.
- trainable_variables – Sequence of trainable variables owned by this module and its submodules.
- variables – Sequence of variables owned by this module and its submodules.
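The "clean up if needed" step of save_weights_fn, together with the max_to_keep parameter of __init__, amounts to keeping a rolling window of the most recent saves. The sketch below illustrates only that retention rule in plain Python. RotatingSaves and record are hypothetical names, and the class merely computes which paths should be evicted without touching the filesystem, so it is not ashpy's actual implementation.

```python
from collections import deque
from typing import Deque, List


class RotatingSaves:
    """Hypothetical helper: remember save paths, keep only the newest max_to_keep."""

    def __init__(self, max_to_keep: int = 1) -> None:
        if max_to_keep < 1:
            raise ValueError("max_to_keep must be >= 1")
        self.max_to_keep = max_to_keep
        self._kept: Deque[str] = deque()

    def record(self, path: str) -> List[str]:
        """Register a new save; return older paths that should now be deleted."""
        self._kept.append(path)
        evicted = []
        # Drop the oldest entries until at most max_to_keep remain.
        while len(self._kept) > self.max_to_keep:
            evicted.append(self._kept.popleft())
        return evicted

    @property
    def kept(self) -> List[str]:
        """Paths currently retained, oldest first."""
        return list(self._kept)
```

With max_to_keep=2, recording "ckpt-1", "ckpt-2", and "ckpt-3" evicts "ckpt-1" on the third call and retains the two newest paths.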
__init__(save_dir, models, event=<Event.ON_EPOCH_END: 'ON_EPOCH_END'>, event_freq=1, max_to_keep=1, save_format=<SaveFormat.MODEL|WEIGHTS: 3>, save_sub_format=<SaveSubFormat.TF: 'tf'>, verbose=0, name='SaveCallback')

Build a Save Callback.
Save callbacks save models when a given event occurs. Two save formats are available, weights and model, and they can be combined: the default, SaveFormat.MODEL|WEIGHTS, saves both. Independently, two save sub-formats are available: tf or h5. Saved files are written to save_dir, under a weights or model subdirectory depending on the format.
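The signature's default repr, <SaveFormat.MODEL|WEIGHTS: 3>, suggests that SaveFormat behaves like a bit flag whose members can be OR-ed together. The sketch below uses a hypothetical stand-in built on Python's enum.Flag (not ashpy's own class) and a hypothetical target_dirs helper to show how a combined format maps to the weights and model subdirectories. Any layout below those two directories (for example per-model folders) is not documented here and is left out.

```python
import os
from enum import Flag, auto
from typing import List


class SaveFormat(Flag):
    # Hypothetical stand-in: values chosen so that MODEL | WEIGHTS == 3,
    # matching the default shown in the documented signature.
    WEIGHTS = auto()  # 1
    MODEL = auto()    # 2


def target_dirs(save_dir: str, save_format: SaveFormat) -> List[str]:
    """Return the subdirectories of save_dir that save_format would write to."""
    dirs = []
    if SaveFormat.WEIGHTS in save_format:
        dirs.append(os.path.join(save_dir, "weights"))
    if SaveFormat.MODEL in save_format:
        dirs.append(os.path.join(save_dir, "model"))
    return dirs
```

For example, target_dirs("testlog/savedir", SaveFormat.MODEL | SaveFormat.WEIGHTS) lists both the weights and the model directory, while SaveFormat.WEIGHTS alone yields only the weights directory.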
Parameters:
- save_dir (str) – directory in which to save the weights or the model.
- models (List[tf.keras.models.Model]) – list of models to save.
- event (ashpy.callbacks.events.Event) – event on which to trigger the saving operation.
- event_freq (int) – frequency of the saving operation, counted in occurrences of event.
- name (str) – name of the callback.
- verbose (int) – verbosity of the callback (0 or 1).
- max_to_keep (int) – maximum number of saved files to keep. Only the max_to_keep most recent files are kept; with max_to_keep == 1, only the latest one is kept.
- save_format (ashpy.callbacks.save_callback.SaveFormat) – weights, model, or both.
- save_sub_format (ashpy.callbacks.save_callback.SaveSubFormat) – sub-format of the saving (tf or h5).
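Because SaveCallback derives from CounterCallback, event_freq is read here as "save once every event_freq occurrences of event". The sketch below illustrates that counting rule with a hypothetical EveryNEvents class and epochs_with_saves helper. Whether ashpy fires on the event_freq-th occurrence (as assumed here) or on the first one is an assumption, not something this page states.

```python
from typing import List


class EveryNEvents:
    """Hypothetical counter: fire on every event_freq-th occurrence of an event."""

    def __init__(self, event_freq: int = 1) -> None:
        if event_freq < 1:
            raise ValueError("event_freq must be >= 1")
        self.event_freq = event_freq
        self.counter = 0

    def should_fire(self) -> bool:
        # Count this occurrence, then fire when the count is a multiple of event_freq.
        self.counter += 1
        return self.counter % self.event_freq == 0


def epochs_with_saves(num_epochs: int, event_freq: int) -> List[int]:
    """Return the 1-based epoch numbers on which a save would be triggered."""
    trigger = EveryNEvents(event_freq)
    return [epoch for epoch in range(1, num_epochs + 1) if trigger.should_fire()]
```

Under this assumption, with event=ON_EPOCH_END and event_freq=3, a 7-epoch run would save after epochs 3 and 6; the default event_freq=1 saves after every epoch.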