Source code for ashpy.metrics.classifier

# Copyright 2019 Zuru Tech HK Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The classification metrics."""

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, Union

import tensorflow as tf  # pylint: disable=import-error
from ashpy.metrics.metric import Metric
from ashpy.modes import LogEvalMode

# Names below are only needed for annotations, which are evaluated lazily
# thanks to ``from __future__ import annotations``.
if TYPE_CHECKING:
    from ashpy.contexts import ClassifierContext  # pylint: disable=ungrouped-imports

    # Presumably the type of ClassifierMetric's ``processing_predictions``
    # argument: {"fn": <callable>, "kwargs": {<keyword arguments>}}.
    TPRocessingPredictions = Dict[str, Union[Callable, Dict[str, Any]]]

# Public API of this module. Python's star-import only looks at the
# lower-case ``__all__``; the upper-case ``__ALL__`` has no effect on it.
__all__ = ["ClassifierLoss", "ClassifierMetric"]
# Kept as an alias so any code that reads ``__ALL__`` keeps working.
__ALL__ = __all__

class ClassifierLoss(Metric):
    """Track the mean of the classification loss computed over a context's dataset."""

    def __init__(
        self,
        name: str = "loss",
        model_selection_operator: Callable = None,
        logdir: Union[Path, str] = Path().cwd() / "log",
    ) -> None:
        """
        Build the metric on top of a float32 :py:class:`tf.metrics.Mean`.

        Args:
            name (str): Name of the metric. It is also given to the underlying
                ``tf.metrics.Mean``.
            model_selection_operator (:py:obj:`typing.Callable`): Operation used to
                compare metric values when `model_selection` is triggered, as used
                by `update_state`. Any :py:obj:`typing.Callable` behaving like an
                :py:mod:`operator` is accepted.

                .. note::
                    Model selection is done ONLY if an operator is specified here.

            logdir (:py:obj:`typing.Union` [:py:class:`pathlib.Path`, str]): Path to
                the log dir. Defaults to a `log` folder in the current working
                directory, as resolved when this module is imported.

        """
        mean_loss = tf.metrics.Mean(name=name, dtype=tf.float32)
        super().__init__(
            name=name,
            metric=mean_loss,
            model_selection_operator=model_selection_operator,
            logdir=logdir,
        )

    def update_state(self, context: ClassifierContext) -> None:
        """
        Accumulate the loss of every batch of ``context.dataset``.

        For each ``(features, labels)`` pair, ``context.loss`` is evaluated (in
        training mode only when ``context.log_eval_mode`` is
        ``LogEvalMode.TRAIN``), and the resulting value is fed to the running
        mean through the metric's distribution strategy.

        Args:
            context (:py:class:`ashpy.contexts.ClassifierContext`): An AshPy Context
                holding all the information the Metric needs.

        """

        def _bind_update(loss_value):
            # Capture ``loss_value`` now, so the replica function does not depend
            # on late binding of the loop variable.
            def _apply():
                return self._metric.update_state(loss_value)

            return _apply

        for batch_features, batch_labels in context.dataset:
            batch_loss = context.loss(
                context,
                features=batch_features,
                labels=batch_labels,
                training=context.log_eval_mode == LogEvalMode.TRAIN,
            )
            self._distribute_strategy.experimental_run_v2(_bind_update(batch_loss))
class ClassifierMetric(Metric):
    """
    Wrap a Keras metric, post-processing a classifier's output before updating it.

    By default, predictions are reduced with ``tf.argmax`` over the last axis
    before being handed to the wrapped metric.
    """

    def __init__(
        self,
        metric: tf.keras.metrics.Metric,
        model_selection_operator: Callable = None,
        logdir: Union[Path, str] = Path().cwd() / "log",
        processing_predictions=None,
    ) -> None:
        """
        Initialize the Metric.

        Args:
            metric (:py:class:`tf.keras.metrics.Metric`): The Keras Metric to use with
                the classifier (e.g.: Accuracy()). Its ``name`` is reused as the
                name of this metric.
            model_selection_operator (:py:obj:`typing.Callable`): The operation that
                will be used when `model_selection` is triggered to compare the
                metrics, used by the `update_state`. Any :py:obj:`typing.Callable`
                behaving like an :py:mod:`operator` is accepted.

                .. note::
                    Model selection is done ONLY if an `model_selection_operator`
                    is specified here.

            logdir (:py:obj:`typing.Union` [:py:class:`pathlib.Path`, str]): Path to
                the log dir. Defaults to a `log` folder in the current working
                directory, as resolved when this module is imported.
            processing_predictions (:py:obj:`typing.Dict`): A `dict` in the form of
                `{"fn": tf.argmax, "kwargs": {"axis": -1}}` with a function `"fn"`
                to be used for predictions processing purposes and its `"kwargs"`
                as its keyword-arguments. The `"kwargs"` entry may be omitted (or
                set to ``None``) when `"fn"` needs no keyword arguments.
                Defaults to {"fn": tf.argmax, "kwargs": {"axis": -1}}.

        """
        # This wrapper has no ``name`` parameter of its own, and the base class
        # requires one, so reuse the wrapped Keras metric's name.
        super().__init__(
            name=metric.name,
            metric=metric,
            model_selection_operator=model_selection_operator,
            logdir=logdir,
        )
        if processing_predictions is None:
            # Built per instance, so no dict is shared between metrics.
            processing_predictions = {"fn": tf.argmax, "kwargs": {"axis": -1}}
        self._processing_predictions = processing_predictions

    def update_state(self, context: ClassifierContext) -> None:
        """
        Update the internal state of the metric, using the information from the context object.

        For every ``(features, labels)`` batch of ``context.dataset``, the
        classifier is run (in training mode only when ``context.log_eval_mode``
        is ``LogEvalMode.TRAIN``). Its output is passed through
        ``processing_predictions["fn"]``, and the result is fed to the wrapped
        metric together with the labels.

        Args:
            context (:py:class:`ashpy.contexts.ClassifierContext`): An AshPy Context
                holding all the information the Metric needs.

        """
        # Loop-invariant: resolve the post-processing function and its kwargs once.
        process_fn = self._processing_predictions["fn"]
        process_kwargs = self._processing_predictions.get("kwargs") or {}

        def _bind_update(batch_labels, batch_predictions):
            # Bind the current batch explicitly, so the replica function never
            # relies on late binding of the loop variables.
            def _apply():
                return self._metric.update_state(
                    batch_labels, process_fn(batch_predictions, **process_kwargs)
                )

            return _apply

        for features, labels in context.dataset:
            predictions = context.classifier_model(
                features, training=context.log_eval_mode == LogEvalMode.TRAIN
            )
            # ``experimental_run_v2`` is the TF2 ``tf.distribute.Strategy`` entry
            # point, also used by ClassifierLoss. ``experimental_run`` is the
            # TF1-style API and may be missing on TF2 strategies.
            self._distribute_strategy.experimental_run_v2(
                _bind_update(labels, predictions)
            )