Have a decorator to wrap universal functions ?

Open • eserie opened this issue 4 years ago • 6 comments

In order to simplify the writing of universal functions, it would be great to have a decorator that hides the technical part of the code (conversion of the inputs and outputs of the wrapped function/method). For example, the code:

```python
import eagerpy as ep


def my_universal_function(a, b, c):
    # Convert all inputs to EagerPy tensors
    a, b, c = ep.astensors(a, b, c)

    # perform some computations
    result = (a + b * c).square()

    # and return a native tensor
    return result.raw
```

would become:

```python
@eager_function
def my_universal_function(a, b, c):
    return (a + b * c).square()
```

In addition, we could add a feature so that if the input tensors are already EagerPy tensors, the output tensors are not converted back to raw format.

I wrote a prototype of such a decorator. It will not work on arbitrary argument types, so using it would require the wrapped function to have a rather "simple" signature (with args and kwargs made up of tensors, or of nested containers with tensors at the leaves: dict, list, tuple or namedtuple-like containers).
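The recursive traversal such a decorator needs can be sketched in plain Python. The helper below, tree_map, is a hypothetical name (it is not part of eagerpy): it applies a function to every leaf of nested dict, list, tuple or namedtuple containers and rebuilds each container with its original type, so a namedtuple stays a namedtuple and a defaultdict keeps its default_factory. In the decorator, the leaf function would be something like ep.astensor on the way in and .raw on the way out.

```python
from collections import defaultdict, namedtuple
from typing import Any, Callable


def tree_map(fn: Callable[[Any], Any], data: Any) -> Any:
    """Hypothetical helper: apply fn to every leaf of a nested container."""
    if isinstance(data, dict):
        mapped = {k: tree_map(fn, v) for k, v in data.items()}
        if isinstance(data, defaultdict):
            # defaultdict needs its factory as the first constructor argument
            return type(data)(data.default_factory, mapped)
        return type(data)(mapped)
    if isinstance(data, tuple) and hasattr(data, "_fields"):
        # namedtuple: fields must be passed positionally
        return type(data)(*(tree_map(fn, v) for v in data))
    if isinstance(data, (list, tuple)):
        return type(data)(tree_map(fn, v) for v in data)
    # anything else is treated as a leaf
    return fn(data)


Pair = namedtuple("Pair", ["x", "y"])
print(tree_map(lambda v: v * 10, {"a": [1, 2], "b": Pair(3, (4, 5))}))
# {'a': [10, 20], 'b': Pair(x=30, y=(40, 50))}
```

Checking for _fields before the generic tuple branch matters: calling a namedtuple type with a single iterable argument would either fail or silently put the whole iterable into the first field.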

Would you consider adding this feature to eagerpy?

eserie • Apr 18 '21 06:04

> Would you consider adding this feature to eagerpy?

Yes, a nice generic decorator that can handle an arbitrary number of arguments (and return values) would be great. I am pretty sure I have thought about this before, but I cannot recall why I didn't do it.

> In addition, we could add a feature so that if the input tensors are already EagerPy tensors, the output tensors are not converted back to raw format.

Have you seen the ep.astensor_ and ep.astensors_ functions (with the underscore)? They already do exactly that: https://eagerpy.jonasrauber.de/guide/generic-functions.html (see the examples at the end).

jonasrauber • Apr 19 '21 08:04

Thanks for your response! No, unfortunately I hadn't seen the functions astensor_ and astensors_ (only astensor). They should definitely be a good starting point! I think I have a POC of a version of that function that could handle more general input/output formats. If you are OK with it, I can try to integrate it into eagerpy and propose a PR in the coming days.

eserie • Apr 19 '21 13:04

Below is a first POC I wrote for the wrapper function eager_function (it works). At this stage, it is not entirely obvious to me how to integrate it into the code base. I would appreciate some first feedback on this code, to know whether we can go further in this direction.

```python
import numbers
from collections import defaultdict
from functools import wraps
from typing import Any, Optional, Tuple

import numpy as np  # needed for the np.datetime64 check below

import eagerpy as ep


def _tuple_as(template, data):
    """Create a list/tuple/namedtuple of the same type as template."""
    data = list(data)
    if isinstance(template, tuple) and hasattr(template, "_fields"):
        # namedtuple case: fields are passed positionally
        return type(template)(*data)
    # list, tuple case
    return type(template)(data)


def _dict_as(template, data):
    """Create a dictionary-like data structure from a template object.

    Parameters
    ----------
    template
        object used as template
    data
        data used to fill the created object.
    """
    if isinstance(template, defaultdict):
        return type(template)(template.default_factory, data)
    return type(template)(data)


def as_eager_tensors(data: Any) -> Any:
    return as_eager_tensors_(data)[0]


def as_eager_tensors_(data: Any) -> Tuple[Any, Optional[bool]]:
    """Convert to eagerpy tensors.

    Parameters
    ----------
    data : (tuple, list, dict, namedtuple, defaultdict)
        data structure to convert

    Returns
    -------
    converted : Any
        same structure as data, with native tensors replaced by
        eagerpy tensors.
    unwrap : bool or None
        True if at least one native tensor has been converted to an
        eagerpy tensor; None for an empty container.
    """
    if isinstance(data, dict):
        # dict, defaultdict
        if not data:
            return data, None
        keys, res_values, unwrap_values = zip(
            *[(dim,) + as_eager_tensors_(var) for dim, var in data.items()]
        )
        unwrap = True in unwrap_values
        return _dict_as(data, dict(zip(keys, res_values))), unwrap
    elif isinstance(data, (list, tuple)):
        if not data:
            return data, None
        res_values, unwrap_values = zip(*[as_eager_tensors_(var) for var in data])
        unwrap = True in unwrap_values
        return _tuple_as(data, res_values), unwrap
    elif isinstance(data, ep.Tensor):
        return data, False
    elif isinstance(data, np.datetime64):
        # datetime not managed by ep.tensors
        return data, False
    elif isinstance(data, numbers.Number):
        return data, False
    return ep.astensor(data), True


def as_raw_tensors(data):
    """Convert from eager tensors to raw tensors.

    Parameters
    ----------
    data
        data to convert
    """
    if isinstance(data, dict):
        return _dict_as(data, {dim: as_raw_tensors(var) for dim, var in data.items()})
    elif isinstance(data, (list, tuple)):
        return _tuple_as(data, (as_raw_tensors(var) for var in data))

    if isinstance(data, ep.Tensor):
        return data.raw
    else:
        return data


def restore_tensor_type(data: Any, unwrap: Optional[bool]) -> Any:
    if unwrap:
        return as_raw_tensors(data)
    else:
        return data


def eager_function(func):
    # Heuristic: a dotted __qualname__ (e.g. "MyClass.method") is taken to
    # mean func is a method, so its first positional argument (self) is kept
    # out of the conversion. The "<locals>." prefix of nested functions is
    # ignored; staticmethods are still misdetected as methods.
    is_method = "." in func.__qualname__.rsplit("<locals>.", 1)[-1]

    @wraps(func)
    def eager_func(*args, **kwargs):
        args = list(args)
        self_args = [args.pop(0)] if is_method else []
        args, args_unwrap = as_eager_tensors_(args)
        kwargs, kwargs_unwrap = as_eager_tensors_(kwargs)
        unwrap = args_unwrap or kwargs_unwrap
        # self is passed back untouched (no truthiness test on it)
        result = func(*self_args, *args, **kwargs)
        return restore_tensor_type(result, unwrap)

    return eager_func
```

eserie • Apr 20 '21 22:04

Another possibility would be to use the pytrees implemented in JAX. This would allow handling more data structures and also relying on the existing astensors_ implementation, applied to flattened versions of the inputs and outputs. However, this would make JAX a hard dependency of eagerpy, whereas it is currently (I believe) optional.
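The pytree idea is to flatten a nested structure into a flat list of leaves plus a description of the structure, operate on the flat list, and then rebuild the structure. JAX exposes this as jax.tree_util.tree_flatten and jax.tree_util.tree_unflatten. To avoid the JAX dependency, here is a hypothetical pure-Python stand-in (tree_flatten, tree_unflatten and _LEAF are illustrative names; only dict, list, tuple and namedtuple are handled, and dict subclasses come back as plain dicts). With it, the flat leaves could be passed straight to the existing ep.astensors_(*leaves).

```python
from collections import namedtuple
from typing import Any, Iterator, List, Tuple

_LEAF = object()  # hypothetical sentinel marking a leaf position


def tree_flatten(tree: Any) -> Tuple[List[Any], Any]:
    """Return (leaves, treedef); treedef is tree with each leaf replaced by _LEAF."""
    leaves: List[Any] = []

    def strip(node: Any) -> Any:
        if isinstance(node, dict):
            # dict subclasses are rebuilt as plain dict in this sketch
            return {k: strip(v) for k, v in node.items()}
        if isinstance(node, tuple) and hasattr(node, "_fields"):
            return type(node)(*(strip(v) for v in node))
        if isinstance(node, (list, tuple)):
            return type(node)(strip(v) for v in node)
        leaves.append(node)  # depth-first, left-to-right order
        return _LEAF

    return leaves, strip(tree)


def tree_unflatten(treedef: Any, leaves: List[Any]) -> Any:
    """Rebuild a tree shaped like treedef, filling leaves in flatten order."""
    it: Iterator[Any] = iter(leaves)

    def fill(node: Any) -> Any:
        if node is _LEAF:
            return next(it)
        if isinstance(node, dict):
            return {k: fill(v) for k, v in node.items()}
        if isinstance(node, tuple) and hasattr(node, "_fields"):
            return type(node)(*(fill(v) for v in node))
        if isinstance(node, (list, tuple)):
            return type(node)(fill(v) for v in node)
        return node

    return fill(treedef)


Pair = namedtuple("Pair", ["x", "y"])
leaves, treedef = tree_flatten({"a": [1, 2], "b": Pair(3, 4)})
print(leaves)  # [1, 2, 3, 4]
print(tree_unflatten(treedef, [v * 10 for v in leaves]))
# {'a': [10, 20], 'b': Pair(x=30, y=40)}
```

Because conversion happens on the flat list, the container-specific logic lives in one place instead of being repeated in every conversion function.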

eserie • Apr 22 '21 07:04

I propose an implementation based on pytrees in https://github.com/jonasrauber/eagerpy/pull/41. This approach implies a few changes, such as no longer registering JAXTensor as a pytree data structure and instead using JAX's pytree utilities for more general data-structure manipulation in eagerpy. The newly introduced data-structure conversion functions also make it possible to factor out part of the method JAXTensor._value_and_grad_fn (for which the original registration of JAXTensor was tailored).

eserie • Apr 23 '21 23:04

In fact, I now think it is not a good idea to stop registering JAXTensor as a pytree: doing so would likely break compatibility with JAX functionality. I will restore the registration in an update to the PR under review.

eserie • Apr 25 '21 11:04