# Copyright 2019 Zuru Tech HK Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The Executor.
An object that, given an :py:class:`ashpy.contexts.Context`, carries a
function and the way of executing it.
"""
from __future__ import annotations
import abc
from typing import Callable, List, Union
import tensorflow as tf


class Executor:
    """Carry a function and the way of executing it, given a context."""

    def __init__(self, fn: tf.keras.losses.Loss = None) -> None:
        """
        Initialize the Executor.

        Args:
            fn (:py:class:`tf.keras.losses.Loss`): A Keras Loss to execute.

        Returns:
            :py:obj:`None`
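
        Examples:
            A minimal sketch; ``BinaryCrossentropy`` here stands in for any
            :py:class:`tf.keras.losses.Loss`:

            .. code-block:: python

                executor = Executor(tf.keras.losses.BinaryCrossentropy())
                # Reduction is disabled: the Executor reduces the loss itself
                # with respect to the global batch size.
                assert executor.fn.reduction == tf.keras.losses.Reduction.NONE
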
"""
if fn is not None:
assert isinstance(fn, tf.keras.losses.Loss)
self._fn = fn
# We always work as in a strategy context
self._fn.reduction = tf.keras.losses.Reduction.NONE
self._distribute_strategy = tf.distribute.get_strategy()
self._global_batch_size = -1
self._weight = lambda _: 1.0

    @property
    def weight(self) -> Callable[..., float]:
        """
        Return the loss weight.

        This weight is multiplied by the loss value; it is useful when
        working with multiple losses.

        Returns:
            :py:obj:`typing.Callable`: Callable returning the weight (:py:obj:`float`).
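
        Examples:
            A minimal sketch; by default the weight is constantly ``1.0``,
            and multiplying the Executor by a scalar updates it:

            .. code-block:: python

                executor = Executor(tf.keras.losses.MeanAbsoluteError())
                assert executor.weight(0) == 1.0
                executor = 2.0 * executor
                assert executor.weight(0) == 2.0
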
"""
return self._weight

    @property
    def fn(self) -> tf.keras.losses.Loss:  # pylint: disable=invalid-name
        """
        Return the Keras loss function to execute.

        Returns:
            :py:obj:`tf.keras.losses.Loss`: Keras Loss.

        """
        return self._fn

    @staticmethod
    def reduce_loss(call_fn: Callable) -> Callable:
        """
        Create a decorator to reduce losses.

        The decorated function applies a ``reduce sum`` operation to the
        per-example loss values and divides the result by the global
        batch size.

        Args:
            call_fn (:py:obj:`typing.Callable`): The executor call method.

        Returns:
            :py:obj:`typing.Callable`: The decorated function.
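
        Examples:
            A hedged sketch of a concrete Executor (the ``context``
            attributes are hypothetical); ``call`` returns one loss value
            per example and the decorator averages them over the global
            batch:

            .. code-block:: python

                class MyLoss(Executor):
                    @Executor.reduce_loss
                    def call(self, context, **kwargs):
                        # One loss value per example, shape (batch_size,).
                        return self._fn(context.targets, context.predictions)
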
"""
# decorator definition
def _reduce(self, *args, **kwargs):
return tf.nn.compute_average_loss(
call_fn(self, *args, **kwargs),
global_batch_size=self._global_batch_size, # pylint: disable=protected-access
)
return _reduce

    @property
    def global_batch_size(self) -> int:
        """
        Return the global batch size.

        This is the batch size used across all replicas, computed as
        ``batch_size_per_replica * number_of_replicas``; e.g., with 4
        replicas and a per-replica batch size of 16, the global batch
        size is 64.

        Returns:
            :py:obj:`int`: Global batch size value.
"""
return self._global_batch_size

    @global_batch_size.setter
    def global_batch_size(self, global_batch_size: int) -> None:
        r"""
        Set the ``_global_batch_size`` property.

        Args:
            global_batch_size (int): Global batch size. In a distributed
                setup this is ``batch_size_per_replica * number_of_replicas``.

        Returns:
            :py:obj:`None`
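
        Examples:
            A minimal sketch; with 2 replicas each processing batches of 32
            examples, the global batch size is 64:

            .. code-block:: python

                executor = Executor(tf.keras.losses.MeanSquaredError())
                executor.global_batch_size = 2 * 32
                assert executor.global_batch_size == 64
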
"""
assert global_batch_size > 0
self._global_batch_size = global_batch_size

    @abc.abstractmethod
    def call(self, context, **kwargs) -> tf.Tensor:
        r"""
        Execute the function, using the information provided by the context.

        Args:
            context (:py:class:`ashpy.contexts.Context`): The function
                execution Context.

        Returns:
            :py:obj:`tf.Tensor`: Output Tensor.

        """

    def __call__(self, context, **kwargs) -> tf.Tensor:
        r"""
        Invoke the function using the Context.

        The result of ``call`` is scaled by the loss weight evaluated at
        the current global step.

        Args:
            context (:py:class:`ashpy.contexts.Context`): The function
                execution Context.

        Returns:
            :py:obj:`tf.Tensor`: Output Tensor.

        """
        return self._weight(context.global_step) * self.call(context, **kwargs)

    def __add__(self, other) -> SumExecutor:
        """
        Concatenate Executors together into a SumExecutor.
        """
        if isinstance(other, SumExecutor):
            other_executors = other.executors
        else:
            other_executors = [other]

        all_executors = [self] + other_executors
        return SumExecutor(all_executors)

    def __mul__(self, other: Union[Callable[..., float], float, int, tf.Tensor]):
        """
        Multiply the current weight stored inside the Executor by ``other``.

        Args:
            other (either a :py:obj:`typing.Callable` or :obj:`float`,
                :obj:`int`, :py:class:`tf.Tensor`): The value (or function
                returning it) to use in the multiplication.
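
        Examples:
            A minimal sketch; a callable weight can implement a schedule
            over the global step (the decay below is arbitrary):

            .. code-block:: python

                executor = Executor(tf.keras.losses.BinaryCrossentropy())
                # Constant scaling.
                executor = executor * 0.5
                # Step-dependent scaling (a simple decay schedule).
                executor = executor * (lambda step: 1.0 / (1.0 + tf.cast(step, tf.float32)))
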
"""
assert isinstance(other, (float, int, tf.Tensor)) or callable(other)
weight = self._weight
if isinstance(other, (int, float, tf.Tensor)):
_other: Union[int, float, tf.Tensor] = other
self._weight = lambda step: weight(step) * _other
else:
__other: Callable[..., float] = other
self._weight = lambda step: weight(step) * __other(step)
return self
def __rmul__(self, other):
"""See `__mul__` method."""
return self * other


class SumExecutor(Executor):
    """
    The sum executor. Executes the call of each fn and weights the losses.

    Each Executor gets called (thus reducing its carried function); the
    results are then summed together.
    """

    def __init__(self, executors) -> None:
        """
        Initialize the SumExecutor.

        Args:
            executors (:py:obj:`list` of [:py:class:`ashpy.executors.Executor`]):
                List of :py:obj:`ashpy.executors.Executor` to evaluate and
                sum together.

        Returns:
            :py:obj:`None`
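
        Examples:
            A minimal sketch; the same SumExecutor can also be obtained by
            adding Executors with ``+``:

            .. code-block:: python

                bce = Executor(tf.keras.losses.BinaryCrossentropy())
                mae = Executor(tf.keras.losses.MeanAbsoluteError())
                summed = SumExecutor([bce, mae])
                summed.global_batch_size = 64  # propagated to each Executor
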
"""
super().__init__()
self._executors = executors
self._global_batch_size = 1

    @property
    def executors(self) -> List[Executor]:
        """Return the list of Executors."""
        return self._executors

    @Executor.global_batch_size.setter  # pylint: disable=no-member
    def global_batch_size(self, global_batch_size: int) -> None:
        """Set the global batch size property, propagating it to each Executor."""
        assert global_batch_size > 0
        self._global_batch_size = global_batch_size
        for executor in self._executors:
            executor.global_batch_size = global_batch_size

    def call(self, *args, **kwargs) -> tf.Tensor:
        """
        Evaluate the Executors and sum the results together.

        Returns:
            :py:class:`tf.Tensor`: Output Tensor.

        """
        result = tf.add_n([executor(*args, **kwargs) for executor in self._executors])
        return result

    def __add__(self, other: Union[SumExecutor, Executor]):
        """Concatenate Executors together into a new SumExecutor."""
        if isinstance(other, SumExecutor):
            executors = other.executors
        else:
            executors = [other]

        all_executors = self.executors + executors
        return SumExecutor(all_executors)