Advanced AshPy¶
Custom Metrics¶
AshPy Trainers can accept metrics that they will use for both logging and automatic model selection.
Implementing a custom Metric in AshPy can be done via two approaches:
1. Your metric is already available as a tf.keras.metrics.Metric and you want to use it as is.
2. You need to write the implementation of the Metric from scratch, or you need to alter the default behavior we provide for AshPy Metrics.
Wrapping Keras Metrics¶
In case (1), search for one of the Metrics provided by AshPy and use it as a wrapper around the metric you wish to use.
Note
Passing an operator function to the AshPy Metric will enable model selection using the metric value.
The example below shows how to wrap the Precision metric for a ClassifierTrainer.
import operator
from pathlib import Path

from ashpy.metrics import ClassifierMetric
from ashpy.trainers import ClassifierTrainer
from tensorflow.keras.metrics import Precision

# Wrap the Keras metric; passing an operator enables model selection
precision = ClassifierMetric(
    metric=Precision(),
    model_selection_operator=operator.gt,
    logdir=Path().cwd() / "log",
)

trainer = ClassifierTrainer(
    ...
    metrics=[precision],
    ...
)
You can apply this technique to any object derived from and behaving as a tf.keras.metrics.Metric (e.g. the metrics present in TensorFlow Addons).
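For instance, here is a sketch wrapping a metric from TensorFlow Addons; the tensorflow_addons dependency and the num_classes value are assumptions made for illustration:

import operator
from pathlib import Path

import tensorflow_addons as tfa
from ashpy.metrics import ClassifierMetric

# Hypothetical sketch: wrap a TensorFlow Addons metric exactly like a Keras one.
f1 = ClassifierMetric(
    metric=tfa.metrics.F1Score(num_classes=10, average="macro"),
    model_selection_operator=operator.gt,  # higher F1 means a better model
    logdir=Path().cwd() / "log",
)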
Creating your own Metric¶
As an example of a custom Metric, we present an analysis of ashpy.metrics.classifier.ClassifierLoss.
class ClassifierLoss(Metric):
    """A handy way to measure the classification loss."""

    def __init__(
        self,
        name: str = "loss",
        model_selection_operator: Callable = None,
        logdir: Union[Path, str] = Path().cwd() / "log",
    ) -> None:
        """
        Initialize the Metric.

        Args:
            name (str): Name of the metric.
            model_selection_operator (:py:obj:`typing.Callable`): The operation that will
                be used when `model_selection` is triggered to compare the metrics,
                used by the `update_state`.
                Any :py:obj:`typing.Callable` behaving like an :py:mod:`operator` is accepted.

                .. note::
                    Model selection is done ONLY if an operator is specified here.

            logdir (str): Path to the log dir, defaults to a `log` folder in the current
                directory.
        """
        super().__init__(
            name=name,
            metric=tf.metrics.Mean(name=name, dtype=tf.float32),
            model_selection_operator=model_selection_operator,
            logdir=logdir,
        )

    def update_state(self, context: ClassifierContext) -> None:
        """
        Update the internal state of the metric, using the information from the context object.

        Args:
            context (:py:class:`ashpy.contexts.ClassifierContext`): An AshPy Context
                holding all the information the Metric needs.
        """
        updater = lambda value: lambda: self._metric.update_state(value)
        for features, labels in context.dataset:
            loss = context.loss(
                context,
                features=features,
                labels=labels,
                training=context.log_eval_mode == LogEvalMode.TRAIN,
            )
            self._distribute_strategy.experimental_run_v2(updater(loss))
- Each custom Metric should always inherit from ashpy.metrics.Metric.
- We advise that each custom Metric respects the signature of the base ashpy.metrics.metric.Metric.__init__().
- Inside the super() call, be sure to provide one of the tf.keras.metrics primitive metrics (e.g. tf.keras.metrics.Mean, tf.keras.metrics.Sum).
Warning
The name argument of ashpy.metrics.metric.Metric.__init__() is a str identifier which should be unique across all the metrics used by your Trainer.
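For example, when instantiating the same Metric class more than once, give each instance its own name (the names below are arbitrary):

# Two instances of the same Metric class must not share a name.
training_loss = ClassifierLoss(name="loss_training")
validation_loss = ClassifierLoss(name="loss_validation")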
Custom Computation inside Metric.update_state()¶
- This method is invoked during the training and receives a Context.
- In this example, since we are working under the ClassifierTrainer, we are using a ClassifierContext. For more information on the Context family of objects see AshPy Internals.
- Inside this update_state() we won’t be doing any fancy computation; we just retrieve the loss value from the ClassifierContext and then run the updater lambda through the fetched distribution strategy.
- The active distribution strategy is automatically retrieved during the super() call; this guarantees that every object derived from an ashpy.metrics.Metric will work flawlessly even in a distributed environment.
- ashpy.metrics.metric.Metric.metric (here referenced as self._metric) is the primitive tf.keras.metrics.Metric whose update_state() method we use to simplify our operations.
- Custom computation will almost always be done via iteration over the data offered by the Context; the sketch after this list illustrates the pattern.
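To make the pattern concrete, below is a minimal sketch of a custom Metric whose update_state() iterates over the Context’s dataset to accumulate classification accuracy. MeanAccuracy is a hypothetical example (not part of AshPy), and we assume the ClassifierContext exposes the model under training as classifier_model; for brevity the sketch also calls update_state() directly instead of dispatching it through self._distribute_strategy as ClassifierLoss does.

from pathlib import Path
from typing import Callable, Union

import tensorflow as tf
from ashpy.contexts import ClassifierContext
from ashpy.metrics import Metric


class MeanAccuracy(Metric):
    """Hypothetical Metric accumulating accuracy over the whole dataset."""

    def __init__(
        self,
        name: str = "mean_accuracy",
        model_selection_operator: Callable = None,
        logdir: Union[Path, str] = Path().cwd() / "log",
    ) -> None:
        # Provide a primitive Keras metric to the base class, as advised above.
        super().__init__(
            name=name,
            metric=tf.keras.metrics.Accuracy(name=name, dtype=tf.float32),
            model_selection_operator=model_selection_operator,
            logdir=logdir,
        )

    def update_state(self, context: ClassifierContext) -> None:
        # Iterate over the data offered by the Context (assumption:
        # `context.classifier_model` holds the model being trained).
        for features, labels in context.dataset:
            predictions = context.classifier_model(features, training=False)
            self._metric.update_state(labels, tf.argmax(predictions, axis=-1))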
For a much more complex (but probably exhaustive) example, have a look at the source code of ashpy.metrics.SlicedWassersteinDistance.
Custom Callbacks¶
Our Callback is built on the same base structure as a tf.keras.callbacks.Callback, exposing methods acting as hooks for the same events:
- on_train_start
- on_epoch_start
- on_batch_start
- on_batch_end
- on_epoch_end
- on_train_end
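If all you need is to react to one of these events, you can inherit from ashpy.callbacks.Callback and override the corresponding hook. A speculative sketch follows; we assume each hook receives the current Context, in line with the context-driven design described in AshPy Internals:

from ashpy.callbacks import Callback
from ashpy.contexts import Context


class EpochEndPrinter(Callback):
    """Hypothetical Callback overriding a single hook."""

    def __init__(self, name: str = "epoch_end_printer") -> None:
        super().__init__(name=name)

    def on_epoch_end(self, context: Context) -> None:
        # Assumption: hooks receive the current Context object.
        print("Epoch finished at global step:", context.global_step.numpy())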
Inside the ashpy.callbacks module we offer two primitive Callback classes to inherit from:
- ashpy.callbacks.Callback: the most basic form of callback and the building block for all the others.
- CounterCallback: derived from ashpy.callbacks.Callback; it contains built-in logic for triggering an event with a desired frequency.
Let’s take a look at the following example, which is the callback used to log GAN output to TensorBoard: ashpy.callbacks.gan.LogImageGANCallback.
class LogImageGANCallback(CounterCallback):
    def __init__(
        self,
        event: Event = Event.ON_EPOCH_END,
        name: str = "log_image_gan_callback",
        event_freq: int = 1,
    ) -> None:
        """
        Initialize the LogImageGANCallback.

        Args:
            event (:py:class:`ashpy.callbacks.events.Event`): event to consider.
            event_freq (int): frequency of logging.
            name (str): name of the callback.
        """
        super(LogImageGANCallback, self).__init__(
            event=event, fn=self._log_fn, name=name, event_freq=event_freq
        )

    def _log_fn(self, context: GANContext) -> None:
        """
        Log output of the generator to Tensorboard.

        Args:
            context (:py:class:`ashpy.contexts.gan.GANContext`): current context.
        """
        if context.log_eval_mode == LogEvalMode.TEST:
            out = context.generator_model(context.generator_inputs, training=False)
        elif context.log_eval_mode == LogEvalMode.TRAIN:
            out = context.fake_samples
        else:
            raise ValueError("Invalid LogEvalMode")

        log("generator", out, context.global_step)
Let’s start with the __init__() function: as for the custom ashpy.metrics.Metric, when inheriting from either Callback or CounterCallback, respect the common part of the signature:
- event: in AshPy we use an Enum - ashpy.callbacks.Event - to choose the type of event you want the Callback to be triggered on.
- name: unique str identifier for the Callback.
- event_freq: simple int specifying the frequency.
- fn: a callable(); this is the function that gets triggered. Inside AshPy we converged on using a private method called _log_fn() in each of our derived Callbacks. Whatever approach you choose, the function fed to fn should take a Context as input. For more information on the Context family of objects see AshPy Internals. A minimal custom CounterCallback following this pattern is sketched below.
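Putting it all together, here is a minimal sketch of a custom CounterCallback respecting this signature. PrintStepCallback is hypothetical and only assumes that the Context exposes a global_step, as in the example above:

from ashpy.callbacks import CounterCallback, Event
from ashpy.contexts import Context


class PrintStepCallback(CounterCallback):
    """Hypothetical Callback printing the global step every `event_freq` batches."""

    def __init__(
        self,
        event: Event = Event.ON_BATCH_END,
        name: str = "print_step_callback",
        event_freq: int = 100,
    ) -> None:
        # Follow the AshPy convention of feeding a private _log_fn to `fn`.
        super().__init__(event=event, fn=self._log_fn, name=name, event_freq=event_freq)

    def _log_fn(self, context: Context) -> None:
        print("Reached global step:", context.global_step.numpy())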
Warning
The name argument of ashpy.callbacks.callback.Callback.__init__() is a str identifier which should be unique across all the callbacks used by your Trainer.