Source code for ashpy.restorers.restorer

# Copyright 2020 Zuru Tech HK Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Primitive Restorer, can be used standalone."""

import json
from pathlib import Path
from typing import Dict, List, Optional, Union

import ashpy
import tensorflow as tf

# Public API of this module. Python only honours the lowercase ``__all__``,
# and each exported name must be its own string entry.
__all__ = ["Restorer", "ModelNotConstructedError"]


class ModelNotConstructedError(Exception):
    """
    Raised when a sub-classed Model is restored before it has been called on data.

    Warning:
        When restoring a :class:`tf.keras.Model` object from checkpoint, make sure
        the model has been correctly built and instantiated by first calling it on
        some sample inputs. For a model built with either the Sequential or the
        Functional API an exception will be raised; for a model built with the
        Chainer API restoration fails silently: it looks "successful", but no
        values are actually restored, because the model has not been built yet
        and therefore exposes no valid placeholders.

    """
class Restorer:
    r"""
    :class:`Restorer` provides a way to restore objects from :class:`tf.train.Checkpoint`.

    Can be used standalone.
    """

    def __init__(
        self,
        logdir: Optional[Union[Path, str]] = None,
        ckpts_dir: str = "ckpts",
        expect_partial: bool = True,
    ) -> None:
        """
        Initialize the Restorer.

        Args:
            logdir (:py:class:`pathlib.Path` or str): Path to the directory with the
                logs. When ``None`` (the default), ``<current working dir>/log`` is
                used, resolved when the Restorer is created.
            ckpts_dir (str): Name of the directory, inside ``logdir``, with the
                checkpoints to restore.
            expect_partial (bool): Whether to expect partial restoring or not.
                Default to true. For more information see the docs for
                :py:func:`tf.train.Checkpoint.restore()`.

        Raises:
            FileNotFoundError: If ``logdir / ckpts_dir`` does not exist.

        """
        if logdir is None:
            # Resolved lazily: a default computed in the signature would freeze
            # the working directory at import time.
            logdir = Path.cwd() / "log"
        self._ckpts_dir = Path(logdir) / ckpts_dir
        if not self._ckpts_dir.exists():
            raise FileNotFoundError(f"{self._ckpts_dir} does not exist.")
        # Stored so that restore_object actually honours the constructor flag.
        self._expect_partial = expect_partial
        self._restored_log_msg = "Restored {} from checkpoint {}."
        try:
            human_map: Optional[Dict[str, str]] = self._read_human_checkpoint_map()
        except FileNotFoundError:
            # The human-readable map is optional; restoring works without it.
            human_map = None
        self._human_checkpoint_map = human_map

    @property
    def checkpoint_map(self) -> Optional[Dict[str, str]]:
        """
        Get the map of the ids in the checkpoint.

        Map is a Dict where keys are the `ids` in the checkpoint and the values
        are the string representation of the types.

        Returns:
            Dict if the map is found, else None.

        """
        return self._human_checkpoint_map

    def _restore_checkpoint(self, checkpoint, partial: bool = True):
        """
        Restore ``checkpoint`` from the latest checkpoint in the checkpoints dir.

        Args:
            checkpoint (:py:class:`tf.train.Checkpoint`): Checkpoint wrapping the
                placeholders to fill with the restored values.
            partial (bool): If True, call ``expect_partial()`` on the load status
                so unused values in the checkpoint file are not reported.

        Returns:
            The load status returned by :py:meth:`tf.train.Checkpoint.restore`.

        Raises:
            FileNotFoundError: If no checkpoint is found in the checkpoints dir.
            AssertionError: Raised by ``assert_existing_objects_matched`` when the
                objects in ``checkpoint`` are not matched by the saved values.

        """
        latest = tf.train.latest_checkpoint(str(self._ckpts_dir))
        if not latest:
            raise FileNotFoundError(
                f"Could not find any checkpoint in {self._ckpts_dir}."
            )
        status = checkpoint.restore(latest)
        if partial:
            status = status.expect_partial()
        status.assert_existing_objects_matched()
        return status

    @staticmethod
    def _validate_placeholder(placeholder: object, placeholder_type: type) -> None:
        """
        Check that ``placeholder`` is an instance of ``placeholder_type``.

        This preliminary check exists because the error thrown by TF on a type
        mismatch can be hard to parse.

        Raises:
            TypeError: If ``placeholder`` is not an instance of ``placeholder_type``.

        """
        # Explicit check instead of ``assert``: assertions are stripped under
        # ``python -O``, which would silently disable the validation.
        if not isinstance(placeholder, placeholder_type):
            raise TypeError(
                f"Object {placeholder} should be of type: {placeholder_type}"
            )

    @staticmethod
    def _check_model_construction(restored_model: tf.keras.Model) -> bool:
        """
        Optimistically check that ``restored_model.weights`` is not empty.

        The underlying assumption is that Models created via the sub-classing API,
        when restored without being properly constructed (AKA called on some
        input), will have empty lists as ``weights``.

        Returns:
            True if the check passes.

        Raises:
            ModelNotConstructedError: If the model has no weights, or if reading
                ``weights`` raises :py:class:`AttributeError`.

        """
        # TODO: add test case for the Sequential without input shape
        try:
            weights = restored_model.weights
        except AttributeError as err:
            # A Sequential() built without specifying the input shape can be treated
            # as a sub-classed model for restoration purposes.
            raise ModelNotConstructedError(
                f"Cannot read the weights of {restored_model}: call the model on "
                "some sample inputs before restoring it."
            ) from err
        if not weights:
            raise ModelNotConstructedError(
                f"Model {restored_model} has no weights: call it on some sample "
                "inputs before restoring it."
            )
        return True

    def restore_object(self, placeholder, object_ckpt_id: str):
        """
        Restore a placeholder from a checkpoint using the specified id.

        Warning:
            When restoring a :class:`tf.keras.Model` object from checkpoint assure
            that the model has been correctly built and instantiated by firstly
            calling it on some sample inputs. In the case of a model built with
            either the Sequential or Functional API an exception will be raised;
            for a model built with the Chainer API it will fail silently,
            restoration will be "successful" but no values will actually be
            restored since there are no valid placeholders as the model has not
            been built yet.

        Args:
            placeholder: Object to fill with the restored values; it is passed to
                :class:`tf.train.Checkpoint` as a keyword argument.
            object_ckpt_id (str): Id under which the object is stored in the
                checkpoint.

        Returns:
            The load status returned by :py:meth:`tf.train.Checkpoint.restore`.

        Raises:
            FileNotFoundError: If no checkpoint is found in the checkpoints dir.
            ModelNotConstructedError: If ``placeholder`` is a
                :class:`tf.keras.Model` without weights.

        """
        # TODO: Example
        checkpoint = tf.train.Checkpoint(**{object_ckpt_id: placeholder})
        status = self._restore_checkpoint(checkpoint, partial=self._expect_partial)
        if isinstance(placeholder, tf.keras.Model):
            # Called directly (not inside ``assert``) so it also runs under -O.
            self._check_model_construction(placeholder)
        print(self._restored_log_msg.format(object_ckpt_id, self._ckpts_dir))
        return status

    # The following methods are provided as convenience since these objects are
    # stored in the Checkpoint by the Trainer.

    def _restore_int64_variable(self, name: str, ckpt_id: str) -> tf.Variable:
        """Create a non-trainable int64 variable set to -1, restore it and return it."""
        placeholder = tf.Variable(-1, name=name, trainable=False, dtype=tf.int64)
        self.restore_object(placeholder, ckpt_id)
        return placeholder

    def get_global_step(self) -> tf.Variable:
        """Return the restored global_step."""
        return self._restore_int64_variable(
            "global_step", ashpy.trainers.Trainer.ckpt_id_global_step
        )

    def get_steps_per_epoch(self) -> tf.Variable:
        """Return the restored steps_per_epoch."""
        return self._restore_int64_variable(
            "steps_per_epoch", ashpy.trainers.Trainer.ckpt_id_steps_per_epoch
        )

    def restore_callback(
        self, callback: ashpy.callbacks.Callback, callback_ckpt_id: str
    ) -> ashpy.callbacks.Callback:
        """
        Restore ``callback`` in place and return it.

        Args:
            callback (:py:class:`ashpy.callbacks.Callback`): Callback to fill with
                the restored values.
            callback_ckpt_id (str): Id under which the callback is stored in the
                checkpoint.

        Returns:
            The same ``callback`` object, restored.

        Raises:
            TypeError: If ``callback`` is not an :py:class:`ashpy.callbacks.Callback`.

        """
        self._validate_placeholder(callback, ashpy.callbacks.Callback)
        self.restore_object(callback, callback_ckpt_id)
        return callback

    def _read_human_checkpoint_map(self) -> Dict[str, str]:
        """
        Load ``checkpoint_map.json`` from the checkpoints directory.

        Raises:
            FileNotFoundError: If the file does not exist.

        """
        with open(self._ckpts_dir / "checkpoint_map.json", encoding="utf-8") as fp:
            return json.load(fp)