# Source code for ashpy.metrics.classifier

# Copyright 2019 Zuru Tech HK Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The classification metrics."""

from __future__ import annotations

import os
from typing import TYPE_CHECKING, Any, Callable, Dict, Union

import tensorflow as tf  # pylint: disable=import-error

from ashpy.metrics.metric import Metric
from ashpy.modes import LogEvalMode

# The names bound below exist only while a static type checker analyses the
# module; at runtime this branch is skipped. This works for annotations
# because ``from __future__ import annotations`` keeps them unevaluated.
if TYPE_CHECKING:
    from ashpy.contexts import ClassifierContext  # pylint: disable=ungrouped-imports

    # Mapping from string keys to either a callable or a dict of keyword
    # arguments. NOTE(review): presumably describes the
    # ``processing_predictions`` argument of ``ClassifierMetric`` -- confirm.
    TPRocessingPredictions = Dict[str, Union[Callable, Dict[str, Any]]]


class ClassifierLoss(Metric):
    """Metric that tracks the mean classification loss over a dataset."""

    def __init__(
        self,
        model_selection_operator: Callable = None,
        logdir: str = os.path.join(os.getcwd(), "log"),
    ) -> None:
        """
        Set up the metric around a float32 running mean called ``loss``.

        Args:
            model_selection_operator (:py:obj:`typing.Callable`): Comparison
                used when `model_selection` is triggered to decide between
                metric values. Any :py:obj:`typing.Callable` that behaves like
                a function from :py:mod:`operator` is accepted.

                .. note::
                    Model selection happens ONLY when an operator is given.

            logdir (str): Path to the log dir; by default a `log` folder inside
                the current working directory.

        """
        running_mean = tf.metrics.Mean(name="loss", dtype=tf.float32)
        super().__init__(
            name="loss",
            metric=running_mean,
            model_selection_operator=model_selection_operator,
            logdir=logdir,
        )

    def update_state(self, context: ClassifierContext) -> None:
        """
        Accumulate the loss of every batch of the context's dataset.

        Args:
            context (:py:class:`ashpy.contexts.ClassifierContext`): An AshPy
                Context providing the dataset, the loss and the current
                log/eval mode.

        """

        def make_update(value):
            """Return a zero-argument callable that feeds ``value`` to the metric."""

            def apply_update():
                return self._metric.update_state(value)

            return apply_update

        for features, labels in context.dataset:
            # The loss is evaluated in training mode only while logging on
            # the training split.
            is_training = context.log_eval_mode == LogEvalMode.TRAIN
            batch_loss = context.loss(
                context, features=features, labels=labels, training=is_training
            )
            self._distribute_strategy.experimental_run_v2(make_update(batch_loss))
class ClassifierMetric(Metric):
    """Wrap a metric using `argmax` to extract predictions out of a classifier's output."""

    def __init__(
        self,
        metric: tf.keras.metrics.Metric,
        model_selection_operator: Callable = None,
        logdir: str = os.path.join(os.getcwd(), "log"),
        processing_predictions=None,
    ) -> None:
        """
        Initialize the Metric.

        Args:
            metric (:py:class:`tf.keras.metrics.Metric`): The Keras Metric to use with
                the classifier (e.g.: Accuracy()). Its ``name`` becomes the name of
                this metric.
            model_selection_operator (:py:obj:`typing.Callable`): The operation that will
                be used when `model_selection` is triggered to compare the metrics,
                used by the `update_state`.
                Any :py:obj:`typing.Callable` behaving like an :py:mod:`operator` is accepted.

                .. note::
                    Model selection is done ONLY if a `model_selection_operator`
                    is specified here.

            logdir (str): Path to the log dir, defaults to a `log` folder in the current
                directory.
            processing_predictions (:py:obj:`typing.Dict`): A `dict` in the form of
                `{"fn": tf.argmax, "kwargs": {"axis": -1}}` with a function `"fn"`
                to be used for predictions processing purposes and its `"kwargs"` as
                its keyword-arguments. Defaults to
                `{"fn": tf.argmax, "kwargs": {"axis": -1}}`.

        """
        super().__init__(
            name=metric.name,
            metric=metric,
            model_selection_operator=model_selection_operator,
            logdir=logdir,
        )
        # Build the default here rather than in the signature so that no
        # mutable dict is shared between instances.
        if processing_predictions is None:
            processing_predictions = {"fn": tf.argmax, "kwargs": {"axis": -1}}
        self._processing_predictions = processing_predictions

    def update_state(self, context: ClassifierContext) -> None:
        """
        Update the internal state of the metric, using the information from the context object.

        For every ``(features, labels)`` batch of the context's dataset the
        classifier output is post-processed with the configured ``"fn"`` and
        ``"kwargs"``, and the result is passed to the wrapped metric together
        with the labels.

        Args:
            context (:py:class:`ashpy.contexts.ClassifierContext`): An AshPy Context
                holding all the information the Metric needs.

        """
        for features, labels in context.dataset:
            predictions = context.classifier_model(
                features, training=context.log_eval_mode == LogEvalMode.TRAIN
            )

            # Use ``experimental_run_v2`` like ClassifierLoss does: it accepts a
            # zero-argument callable, whereas the older ``experimental_run``
            # expects an input iterator and is not available on TF2 strategies.
            self._distribute_strategy.experimental_run_v2(
                lambda: self._metric.update_state(
                    labels,
                    self._processing_predictions["fn"](
                        predictions, **self._processing_predictions["kwargs"]
                    ),
                )
            )