Advanced AshPy

Custom Metrics

AshPy Trainers can accept metrics that they will use for both logging and automatic model selection.

Implementing a custom Metric in AshPy can be done via two approaches:

  1. Your metric is already available as a tf.keras.metrics.Metric and you want to use it as is.
  2. You need to write the implementation of the Metric from scratch or you need to alter the default behavior we provide for AshPy Metrics.

Wrapping Keras Metrics

In case (1), pick one of the Metrics provided by AshPy and use it as a wrapper around the metric you wish to use.


Passing an operator function to the AshPy Metric enables model selection based on the metric value.
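To make the role of the operator concrete, here is a minimal, framework-free sketch of how an operator can drive model selection: each new metric value is compared with the best one seen so far, and the model counts as improved when `op(new, best)` is true. `BestValueTracker` is a hypothetical helper written for illustration only; it is not AshPy's implementation.

```python
import operator
from typing import Callable, Optional


class BestValueTracker:
    """Hypothetical helper: remembers the best metric value seen so far.

    `op(new, best)` must return True when `new` is better than `best`,
    e.g. operator.gt for precision/accuracy, operator.lt for a loss.
    """

    def __init__(self, op: Optional[Callable[[float, float], bool]] = None) -> None:
        self.op = op
        self.best: Optional[float] = None

    def is_improvement(self, value: float) -> bool:
        # Without an operator there is no model selection at all.
        if self.op is None:
            return False
        # The first observed value is always the best so far.
        if self.best is None or self.op(value, self.best):
            self.best = value
            return True
        return False


tracker = BestValueTracker(operator.gt)
print([tracker.is_improvement(v) for v in (0.60, 0.72, 0.70, 0.81)])
# [True, True, False, True]
```

Swapping in `operator.lt` flips the criterion, which is what you would want for a loss.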

The example below shows how to implement the Precision metric for a ClassifierTrainer.

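import operator
from pathlib import Path

from ashpy.metrics import ClassifierMetric
from ashpy.trainers import ClassifierTrainer
from tensorflow.keras.metrics import Precision

# Wrap the Keras Precision metric; operator.gt tells AshPy that a higher
# precision is better, which enables model selection on this metric.
precision = ClassifierMetric(
    metric=Precision(),
    model_selection_operator=operator.gt,
    logdir=Path().cwd() / "log",
)

trainer = ClassifierTrainer(
    # ... model, optimizer, loss, epochs and the other Trainer arguments ...
    metrics=[precision],
)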

You can apply this technique to any object derived from, and behaving as, a tf.keras.metrics.Metric (e.g. the Metrics provided by TensorFlow Addons).
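The technique works because a wrapper only relies on the metric interface, not on a concrete class. The plain-Python sketch below illustrates that idea: `WrappedMetric` extracts values from a context and delegates all accumulation to the inner object, so any object exposing `update_state()` and `result()` can be dropped in. All three classes are hypothetical stand-ins (not AshPy's ClassifierMetric), and the dict used as a context is an assumption made for the example.

```python
from typing import Any, Callable, Iterable


class RunningMean:
    """Toy Keras-like metric: averages every value it receives."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def update_state(self, value: float) -> None:
        self.total += value
        self.count += 1

    def result(self) -> float:
        return self.total / self.count if self.count else 0.0


class RunningMax:
    """Another toy Keras-like metric: keeps the largest value seen."""

    def __init__(self) -> None:
        self.value = float("-inf")

    def update_state(self, value: float) -> None:
        self.value = max(self.value, value)

    def result(self) -> float:
        return self.value


class WrappedMetric:
    """Hypothetical wrapper: extracts values from a context and delegates
    all accumulation to the wrapped metric object."""

    def __init__(self, metric: Any, extract: Callable[[Any], Iterable[float]]) -> None:
        self._metric = metric  # anything exposing update_state() and result()
        self._extract = extract  # maps a context to the values to accumulate

    def update_state(self, context: Any) -> None:
        for value in self._extract(context):
            self._metric.update_state(value)

    def result(self) -> float:
        return self._metric.result()


context = {"losses": [1.0, 2.0, 3.0]}
for inner in (RunningMean(), RunningMax()):
    wrapped = WrappedMetric(inner, extract=lambda ctx: ctx["losses"])
    wrapped.update_state(context)
    print(type(inner).__name__, wrapped.result())
# RunningMean 2.0
# RunningMax 3.0
```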

Creating your own Metric

As an example of a custom Metric, we analyze ashpy.metrics.classifier.ClassifierLoss.

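from pathlib import Path
from typing import Callable, Union

import tensorflow as tf

from ashpy.contexts import ClassifierContext
from ashpy.metrics import Metric
from ashpy.modes import LogEvalMode


class ClassifierLoss(Metric):
    """A handy way to measure the classification loss."""

    def __init__(
        self,
        name: str = "loss",
        model_selection_operator: Callable = None,
        logdir: Union[Path, str] = Path().cwd() / "log",
    ) -> None:
        """
        Initialize the Metric.

        Args:
            name (str): Name of the metric.
            model_selection_operator (:py:obj:`typing.Callable`): The operation that will
                be used when `model_selection` is triggered to compare the metrics,
                used by the `update_state`.
                Any :py:obj:`typing.Callable` behaving like an :py:mod:`operator` is accepted.

                .. note::
                    Model selection is done ONLY if an operator is specified here.

            logdir (str): Path to the log dir, defaults to a `log` folder in the current
                directory.
        """
        super().__init__(
            name=name,
            metric=tf.metrics.Mean(name=name, dtype=tf.float32),
            model_selection_operator=model_selection_operator,
            logdir=logdir,
        )

    def update_state(self, context: ClassifierContext) -> None:
        """
        Update the internal state of the metric, using the information from the context object.

        Args:
            context (:py:class:`ashpy.contexts.ClassifierContext`): An AshPy Context
                holding all the information the Metric needs.
        """
        # Bind a loss value now, return a zero-argument callable for the strategy.
        updater = lambda value: lambda: self._metric.update_state(value)
        for features, labels in context.dataset:
            loss = context.loss(
                context,
                features=features,
                labels=labels,
                training=context.log_eval_mode == LogEvalMode.TRAIN,
            )
            # Run the update under the distribution strategy fetched in super().__init__().
            self._distribute_strategy.experimental_run_v2(updater(loss))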

The name argument of the ashpy.metrics.metric.Metric.__init__() is a str identifier which should be unique across all the metrics used by your Trainer.
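Since a clash between names is easy to introduce when assembling a long list of metrics, a small check like the one below can catch it before the Trainer is built. `check_unique_names` is a hypothetical helper, not part of AshPy; feed it the name strings you pass to your metrics.

```python
from collections import Counter
from typing import Iterable


def check_unique_names(names: Iterable[str]) -> None:
    """Hypothetical pre-flight check: raise if any identifier appears twice."""
    counts = Counter(names)
    duplicates = sorted(name for name, n in counts.items() if n > 1)
    if duplicates:
        raise ValueError(f"Duplicate metric names: {duplicates}")


check_unique_names(["loss", "precision", "recall"])  # passes silently

try:
    check_unique_names(["loss", "precision", "loss"])
except ValueError as err:
    print(err)  # Duplicate metric names: ['loss']
```

The same check applies to callback names, which must be unique as well.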

Custom Computation inside Metric.update_state()

  • This method is invoked during training and receives a Context.
  • In this example, since we are working under the ClassifierTrainer, we are using a ClassifierContext. For more information on the Context family of objects see AshPy Internals.
  • Inside this update_state() we won’t be doing any fancy computation: for each batch we just compute the loss through the ClassifierContext and then run the updater lambda through the fetched distribution strategy.
  • The active distribution strategy is automatically retrieved during the super() call; this guarantees that every object derived from an ashpy.metrics.Metric also works in a distributed environment.
  • ashpy.metrics.metric.Metric.metric (here referenced as self._metric) is the primitive tf.keras.metrics.Metric whose update_state() method we use to simplify our operations.
  • Custom computation will almost always be done via iteration over the data offered by the Context.
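The double lambda in update_state() can look odd, so the sketch below reproduces the same pattern in plain Python. It assumes a hypothetical single-replica `LocalStrategy` whose `run(fn)` just calls `fn()` (it mimics the shape of `tf.distribute.Strategy.run`, named `experimental_run_v2` in TensorFlow 2.0 and 2.1, without any per-replica semantics) and a toy `RunningMean` standing in for `tf.metrics.Mean`. The outer lambda binds the loss value; the inner, argument-free lambda is what the strategy executes.

```python
from typing import Callable, Iterable, List, Tuple


class LocalStrategy:
    """Hypothetical single-replica strategy: run(fn) just calls fn()."""

    def run(self, fn: Callable[[], None]) -> None:
        fn()


class RunningMean:
    """Toy stand-in for tf.metrics.Mean."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def update_state(self, value: float) -> None:
        self.total += value
        self.count += 1

    def result(self) -> float:
        return self.total / self.count if self.count else 0.0


Batch = Tuple[List[float], List[float]]


def update_state(
    dataset: Iterable[Batch],
    loss_fn: Callable[[List[float], List[float]], float],
    metric: RunningMean,
    strategy: LocalStrategy,
) -> None:
    # Outer lambda binds the loss value; the inner zero-argument lambda
    # is what the strategy executes.
    updater = lambda value: lambda: metric.update_state(value)  # noqa: E731
    for features, labels in dataset:
        loss = loss_fn(features, labels)
        strategy.run(updater(loss))


def mean_absolute_error(features: List[float], labels: List[float]) -> float:
    """Toy per-batch loss used only for this example."""
    return sum(abs(f - y) for f, y in zip(features, labels)) / len(labels)


dataset = [([1.0, 2.0], [1.0, 0.0]), ([3.0], [1.0])]
metric = RunningMean()
update_state(dataset, mean_absolute_error, metric, LocalStrategy())
print(metric.result())  # 1.5
```

The two batch losses are 1.0 and 2.0, so the running mean ends at 1.5.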

For a much more complex (and more exhaustive) example, have a look at the source code of ashpy.metrics.SlicedWassersteinDistance.

Custom Callbacks

Our Callback is built on the same base structure as tf.keras.callbacks.Callback: it exposes methods acting as hooks for the equivalent training events.

  • on_train_start
  • on_epoch_start
  • on_batch_start
  • on_batch_end
  • on_epoch_end
  • on_train_end
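A rough sketch of how hook methods like these can be dispatched: the base class provides a no-op method per event, and subclasses override only the hooks they care about. The `Event` member names follow the `ON_EPOCH_END` style used later on this page; the `on_event` dispatcher and the rest of the code are illustrative assumptions, not AshPy's implementation.

```python
from enum import Enum, auto
from typing import Any, List


class Event(Enum):
    """Hypothetical mirror of the hooks listed above."""

    ON_TRAIN_START = auto()
    ON_EPOCH_START = auto()
    ON_BATCH_START = auto()
    ON_BATCH_END = auto()
    ON_EPOCH_END = auto()
    ON_TRAIN_END = auto()


class Callback:
    """Minimal sketch: one no-op method per hook plus a dispatcher."""

    def on_event(self, event: Event, context: Any) -> None:
        # Event.ON_EPOCH_END -> self.on_epoch_end(context), and so on.
        getattr(self, event.name.lower())(context)

    def on_train_start(self, context: Any) -> None:
        pass

    def on_epoch_start(self, context: Any) -> None:
        pass

    def on_batch_start(self, context: Any) -> None:
        pass

    def on_batch_end(self, context: Any) -> None:
        pass

    def on_epoch_end(self, context: Any) -> None:
        pass

    def on_train_end(self, context: Any) -> None:
        pass


class EpochRecorder(Callback):
    """Overrides a single hook; every other event stays a no-op."""

    def __init__(self) -> None:
        self.seen: List[int] = []

    def on_epoch_end(self, context: Any) -> None:
        self.seen.append(context["epoch"])


recorder = EpochRecorder()
for epoch in range(3):
    recorder.on_event(Event.ON_EPOCH_START, {"epoch": epoch})  # no-op
    recorder.on_event(Event.ON_EPOCH_END, {"epoch": epoch})
print(recorder.seen)  # [0, 1, 2]
```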

Inside the ashpy.callbacks module we offer two primitive Callback classes to inherit from.

  1. ashpy.callbacks.Callback: the most basic form of callback and the building block for all the others.
  2. CounterCallback: derived from ashpy.callbacks.Callback, it contains built-in logic for triggering the callback at a desired event frequency.
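The frequency logic of a counter-based callback boils down to a few lines: count the occurrences of the configured event and call `fn` on every `event_freq`-th one. The sketch below is a simplified, hypothetical re-implementation for illustration; AshPy's CounterCallback may differ in its details.

```python
from enum import Enum, auto
from typing import Any, Callable


class Event(Enum):
    ON_BATCH_END = auto()
    ON_EPOCH_END = auto()


class CounterCallback:
    """Sketch of frequency-based triggering (hypothetical, not AshPy's code).

    fn(context) runs on every `event_freq`-th occurrence of `event`.
    """

    def __init__(
        self, event: Event, fn: Callable[[Any], None], name: str, event_freq: int = 1
    ) -> None:
        if event_freq < 1:
            raise ValueError("event_freq must be a positive integer")
        self._event = event
        self._fn = fn
        self.name = name
        self._event_freq = event_freq
        self._event_counter = 0

    def on_event(self, event: Event, context: Any) -> None:
        # Ignore events this callback was not configured for.
        if event != self._event:
            return
        self._event_counter += 1
        if self._event_counter % self._event_freq == 0:
            self._fn(context)


calls = []
cb = CounterCallback(Event.ON_BATCH_END, fn=calls.append, name="every_third_batch", event_freq=3)
for step in range(1, 8):
    cb.on_event(Event.ON_BATCH_END, step)
print(calls)  # [3, 6]
```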

Let’s take a look at the following example, the callback used to log GAN outputs to TensorBoard: ashpy.callbacks.gan.LogImageGANCallback.

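from ashpy.callbacks import CounterCallback, Event
from ashpy.contexts import GANContext
from ashpy.modes import LogEvalMode
from ashpy.utils.utils import log


class LogImageGANCallback(CounterCallback):
    def __init__(
        self,
        event: Event = Event.ON_EPOCH_END,
        name: str = "log_image_gan_callback",
        event_freq: int = 1,
    ) -> None:
        """
        Initialize the LogImageGANCallback.

        Args:
            event (:py:class:`ashpy.callbacks.events.Event`): event to consider.
            event_freq (int): frequency of logging.
            name (str): name of the callback.
        """
        super(LogImageGANCallback, self).__init__(
            event=event, fn=self._log_fn, name=name, event_freq=event_freq
        )

    def _log_fn(self, context: GANContext) -> None:
        """
        Log output of the generator to Tensorboard.

        Args:
            context (:py:class:`ashpy.contexts.gan.GANContext`): current context.
        """
        if context.log_eval_mode == LogEvalMode.TEST:
            # Evaluation: run the generator in inference mode on the stored inputs.
            out = context.generator_model(context.generator_inputs, training=False)
        elif context.log_eval_mode == LogEvalMode.TRAIN:
            # Training: reuse the samples generated during the training step.
            out = context.fake_samples
        else:
            raise ValueError("Invalid LogEvalMode")

        # Write the generated images to TensorBoard at the current step.
        log("generator", out, context.global_step)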

Let’s start with the __init__() function. As with a custom ashpy.metrics.Metric, when inheriting from either Callback or CounterCallback you should respect the common part of the signature:

  • event: In AshPy we use an Enum - ashpy.callbacks.Event - to choose the event type you want the Callback to be triggered on.
  • name: Unique str identifier for the Callback.
  • event_freq: An int specifying how often, in occurrences of event, the Callback is triggered.
  • fn: A callable, the function that gets triggered. Inside AshPy we converged on using a private method called _log_fn() in each of our derived Callbacks. Whatever approach you choose, the function fed to fn should take a Context as input. For more information on the Context family of objects see AshPy Internals.
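Putting these conventions together, here is a hypothetical callback that follows the same signature and passes a private `_log_fn` as `fn`, branching on the evaluation mode like LogImageGANCallback does. To stay self-contained it inherits from a stub base (`StubCounterCallback`) that calls `fn` whenever its event fires and ignores `event_freq`; the context is a `SimpleNamespace`, and its `train_loss` and `eval_loss` attributes are invented for the example.

```python
from enum import Enum, auto
from types import SimpleNamespace
from typing import Any, Callable, List, Tuple


class Event(Enum):
    ON_EPOCH_END = auto()


class LogEvalMode(Enum):
    TRAIN = auto()
    TEST = auto()


class StubCounterCallback:
    """Hypothetical stand-in base accepting the common signature.

    Frequency handling is left out to keep the focus on `fn`.
    """

    def __init__(
        self, event: Event, fn: Callable[[Any], None], name: str, event_freq: int = 1
    ) -> None:
        self._event, self._fn, self.name, self._event_freq = event, fn, name, event_freq

    def on_event(self, event: Event, context: Any) -> None:
        if event == self._event:
            self._fn(context)


class LogLossCallback(StubCounterCallback):
    """Hypothetical callback recording a loss, following the `_log_fn` convention."""

    def __init__(
        self,
        event: Event = Event.ON_EPOCH_END,
        name: str = "log_loss_callback",
        event_freq: int = 1,
    ) -> None:
        self.logged: List[Tuple[int, float]] = []
        super().__init__(event=event, fn=self._log_fn, name=name, event_freq=event_freq)

    def _log_fn(self, context: Any) -> None:
        # Same dispatch on the evaluation mode as LogImageGANCallback._log_fn.
        if context.log_eval_mode == LogEvalMode.TEST:
            value = context.eval_loss
        elif context.log_eval_mode == LogEvalMode.TRAIN:
            value = context.train_loss
        else:
            raise ValueError("Invalid LogEvalMode")
        self.logged.append((context.global_step, value))


cb = LogLossCallback()
ctx = SimpleNamespace(
    log_eval_mode=LogEvalMode.TRAIN, train_loss=0.42, eval_loss=0.50, global_step=10
)
cb.on_event(Event.ON_EPOCH_END, ctx)
ctx.log_eval_mode = LogEvalMode.TEST
cb.on_event(Event.ON_EPOCH_END, ctx)
print(cb.logged)  # [(10, 0.42), (10, 0.5)]
```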


The name argument of the ashpy.callbacks.callback.Callback.__init__() is a str identifier which should be unique across all the callbacks used by your Trainer.