Source code for ashpy.metrics.metric

# Copyright 2019 Zuru Tech HK Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Metric is the abstract class that every ash metric must implement."""

from __future__ import annotations

import errno
import json
import operator
import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union

import numpy as np
import tensorflow as tf  # pylint: disable=import-error

if TYPE_CHECKING:
    from ashpy.contexts import Context

__ALL__ = ["Metric"]


class Metric(ABC):
    """
    Metric is the abstract class that every AshPy Metric must implement.

    AshPy Metrics wrap and extend :class:`tf.keras.metrics.Metric`.
    """

    def __init__(
        self,
        name: str,
        metric: tf.keras.metrics.Metric,
        model_selection_operator: Callable = None,
        logdir: Union[Path, str] = Path.cwd() / "log",
    ) -> None:
        """
        Initialize the Metric object.

        Args:
            name (str): Name of the metric.
            metric (:py:class:`tf.keras.metrics.Metric`): The Keras metric to use.
            model_selection_operator (:py:obj:`typing.Callable`): The operation used
                during `model_selection` to compare the current metric value against
                the previous best one. Any :py:obj:`typing.Callable` behaving like an
                :py:mod:`operator` is accepted.

                .. note::
                    Model selection is performed ONLY if a `model_selection_operator`
                    is specified here.

            logdir (:py:class:`pathlib.Path` or str): Path to the log directory,
                defaults to a `log` folder in the current working directory.

        """
        self._distribute_strategy = tf.distribute.get_strategy()
        self._name = name
        self._metric = metric
        self._model_selection_operator = model_selection_operator
        self._logdir = Path(logdir) if not isinstance(logdir, Path) else logdir

    def model_selection(
        self, checkpoint: tf.train.Checkpoint, global_step: tf.Variable
    ) -> Optional[Path]:
        """
        Perform model selection.

        Args:
            checkpoint (:py:class:`tf.train.Checkpoint`): Checkpoint object that
                contains the model status.
            global_step (:py:class:`tf.Variable`): Current training step.

        Returns:
            :py:class:`pathlib.Path`: Path to the saved checkpoint, or None if no
            model selection was performed.

        """
        current_value = self.result()
        previous_value = float(
            self.json_read(self.best_model_sel_file)[self.sanitized_name]
        )
        # Model selection is done ONLY if an operator was passed at __init__
        if (
            self._model_selection_operator
            and self._model_selection_operator(current_value, previous_value)
            and not np.isclose(current_value, previous_value)
        ):
            print(
                f"{self.sanitized_name}: validation value: {previous_value} -> {current_value}"
            )
            self.json_write(
                self.best_model_sel_file,
                {
                    self.sanitized_name: str(current_value),
                    "step": int(global_step.numpy()),
                },
            )
            manager = tf.train.CheckpointManager(
                checkpoint, self.best_folder / "ckpts", max_to_keep=1
            )
            return Path(manager.save())
        return None

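    # Worked example (hypothetical values, for illustration only): with
    # model_selection_operator=operator.gt, a previous best of 0.80 stored in the
    # JSON file and a current result() of 0.85, operator.gt(0.85, 0.80) is True and
    # np.isclose(0.85, 0.80) is False, so the JSON record is updated and a new
    # checkpoint is saved under best_folder / "ckpts". With operator.lt the metric
    # is instead treated as "lower is better".
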
    def _update_logdir(self):
        if not self._model_selection_operator:
            pass
        # write the initial value of the best metric
        if not self.best_model_sel_file.exists():
            self.best_model_sel_file.parent.mkdir(parents=True)
            initial_value = (
                np.inf if self._model_selection_operator is operator.lt else -np.inf
            )
            self.json_write(
                self.best_model_sel_file,
                {self.sanitized_name: str(initial_value), "step": 0},
            )

    @property
    def name(self) -> str:
        """Retrieve the metric name."""
        return self._name

    @property
    def sanitized_name(self) -> str:
        """
        Retrieve the sanitized name: all / are replaced with _.

        This is done since prefixing a metric name with a / enables automatic
        grouping in TensorBoard. When we are not working with TensorBoard we want
        to replace all / with _.
        """
        return self._name.replace("/", "_")

    @property
    def metric(self) -> tf.keras.metrics.Metric:
        """Retrieve the :py:class:`tf.keras.metrics.Metric` object."""
        return self._metric

    @property
    def model_selection_operator(self) -> Optional[Callable]:
        """Retrieve the operator used for model selection."""
        return self._model_selection_operator

    @property
    def logdir(self) -> Path:
        """Retrieve the log directory."""
        return self._logdir

    @logdir.setter
    def logdir(self, logdir) -> None:
        """Set the logdir, updating also the dependent properties."""
        self._logdir = logdir
        self._update_logdir()

    @property
    def best_folder(self) -> Path:
        """Retrieve the folder used to save the best model when doing model selection."""
        return self.logdir / "best" / self.sanitized_name

    @property
    def best_model_sel_file(self) -> Path:
        """Retrieve the path to the JSON file containing the measured performance of the best model."""
        return self.best_folder / (self.sanitized_name + ".json")

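    # Layout example (hypothetical metric name, for illustration only): with
    # name="loss/validation" and the default logdir, sanitized_name is
    # "loss_validation", best_folder is log/best/loss_validation, and
    # best_model_sel_file is log/best/loss_validation/loss_validation.json,
    # initialized by _update_logdir with {"loss_validation": "inf" or "-inf", "step": 0}.
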
    @staticmethod
    def json_read(filename: Path) -> Dict[str, Any]:
        """
        Read a JSON file.

        Args:
            filename (:py:class:`pathlib.Path`): The path to the JSON file to read.

        Returns:
            :py:obj:`typing.Dict`: Dictionary containing the content of the JSON file.

        """
        if not filename.exists():
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), filename)

        data: Dict[str, Union[str, int, float]] = {}
        with open(filename, "r") as fp:
            data = json.load(fp)
        return data

    @staticmethod
    def json_write(filename: Path, what_to_write: Dict) -> None:
        """
        Write the given dictionary to the specified JSON file, merging it with any existing content.

        Args:
            filename (:py:class:`pathlib.Path`): Path to the JSON file to write.
            what_to_write (dict): A dictionary containing what to write.

        """
        if filename.exists():
            data = Metric.json_read(filename)
            for key in what_to_write:
                data[key] = str(what_to_write[key])
        else:
            data = what_to_write

        if not filename.parent.exists():
            filename.parent.mkdir()

        with open(filename, "w+") as fp:
            json.dump(data, fp, indent=4)

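    # Round-trip example (hypothetical file, for illustration only):
    #   Metric.json_write(Path("log/best/loss_validation/loss_validation.json"),
    #                     {"loss_validation": 0.25, "step": 100})
    #   Metric.json_read(Path("log/best/loss_validation/loss_validation.json"))
    # returns {"loss_validation": 0.25, "step": 100} after the first write; when the
    # file already exists, json_write stringifies the merged values, so subsequent
    # reads yield strings (model_selection converts them back with float()).
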
    @abstractmethod
    def update_state(self, context: Context) -> None:
        """
        Update the internal state of the metric, using the information from the context object.

        Args:
            context (:py:class:`ashpy.contexts.Context`): An AshPy Context holding
                all the information the Metric needs.

        """

    def result(self):
        """
        Get the result of the metric.

        Returns:
            :py:class:`numpy.ndarray`: The current value of the metric.

        """
        return self._metric.result().numpy()

    def log(self, step: int) -> None:
        """
        Log the metric.

        Args:
            step (int): Global step of training.

        """
        tf.summary.scalar(self.name, self.result(), step=step)

    def reset_states(self) -> None:
        """Reset the state of the metric."""
        self._metric.reset_states()
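

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, hypothetical
# concrete Metric and the typical validation-loop call sequence. The class name
# `ExampleMeanMetric` and the `context.loss_value` attribute are illustrative
# assumptions; real subclasses read whatever they need from the ashpy Context.
# ---------------------------------------------------------------------------
class ExampleMeanMetric(Metric):
    """Hypothetical metric tracking the running mean of a scalar, lower is better."""

    def __init__(self, logdir: Union[Path, str] = Path.cwd() / "log") -> None:
        super().__init__(
            name="loss/validation",
            metric=tf.keras.metrics.Mean(name="loss/validation"),
            model_selection_operator=operator.lt,  # lower values are better
            logdir=logdir,
        )

    def update_state(self, context) -> None:
        # Assumption: the context exposes the scalar to track; the real attribute
        # depends on the concrete ashpy Context subclass in use.
        self._metric.update_state(context.loss_value)


# Typical validation-loop usage (sketch):
#   metric.reset_states()
#   metric.update_state(context)
#   metric.log(step=global_step)
#   metric.model_selection(checkpoint, global_step)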