# Source code for idrlnet.geo_utils.sympy_np

"""Convert sympy expressions to NumPy functions.

TODO: converge with torch_util.
"""

import numpy as np
from sympy import lambdify
from typing import Iterable
from functools import reduce
import collections
from sympy import Max, Min, Mul

# Public API of this module: only ``lambdify_np`` is exported by a star import.
__all__ = ["lambdify_np"]


class WrapSympy:
    """Mixin marking an operator node as symbolic or as NumPy-evaluated.

    ``is_sympy`` is ``True`` by default; subclasses set it to ``False`` on
    instances that hold operands which have to be evaluated numerically.
    """

    is_sympy = True

    @staticmethod
    def _wrapper_guide(args):
        """Decide whether a binary node must be evaluated numerically.

        :param args: sequence whose first two items are the operands; any
            further items are ignored.
        :return: ``(flag, first, second)``. ``flag`` is ``True`` when at least
            one operand is a ``WrapSympy`` with ``is_sympy`` set to ``False``
            or is a callable that is not a ``WrapSympy``.
        :raises IndexError: if ``args`` holds fewer than two items.
        """
        func_1 = args[0]
        func_2 = args[1]

        def _needs_numpy(operand):
            if isinstance(operand, WrapSympy):
                return not operand.is_sympy
            # ``collections.Callable`` no longer exists on Python >= 3.10;
            # the builtin performs the same check on every version.
            return callable(operand)

        return _needs_numpy(func_1) or _needs_numpy(func_2), func_1, func_2


class WrapMax(Max, WrapSympy):
    """``Max`` node whose operands may also be numerically evaluated callables.

    When ``_wrapper_guide`` reports that an operand needs numerical
    evaluation, the instance only stores the first two operands as ``f`` and
    ``g``; otherwise construction is delegated to ``Max``.
    """

    def __new__(cls, *args, **kwargs):
        numeric, lhs, rhs = WrapMax._wrapper_guide(args)
        if not numeric:
            node = Max.__new__(cls, *args, **kwargs)
            # Only tag the result when ``Max.__new__`` handed back a WrapSympy.
            if isinstance(node, WrapSympy):
                node.is_sympy = True
            return node

        # Bypass the sympy constructor: the operands are kept as given.
        node = object.__new__(cls)
        node.f = lhs
        node.g = rhs
        node.is_sympy = False
        return node

    def __call__(self, **x):
        """Evaluate the node with the keyword arguments in ``x``."""
        names = x.keys()
        if self.is_sympy:
            compiled = lambdify_np(self, names)
            return compiled(**x)

        lhs = lambdify_np(self.f, names)
        rhs = lambdify_np(self.g, names)
        return np.maximum(lhs(**x), rhs(**x))


class WrapMul(Mul, WrapSympy):
    """``Mul`` node whose operands may also be numerically evaluated callables.

    When ``_wrapper_guide`` reports that an operand needs numerical
    evaluation, the instance only stores the first two operands as ``f`` and
    ``g``; otherwise construction is delegated to ``Mul``.
    """

    def __new__(cls, *args, **kwargs):
        numeric, lhs, rhs = WrapMul._wrapper_guide(args)
        if not numeric:
            node = Mul.__new__(cls, *args, **kwargs)
            # Only tag the result when ``Mul.__new__`` handed back a WrapSympy.
            if isinstance(node, WrapSympy):
                node.is_sympy = True
            return node

        # Bypass the sympy constructor: the operands are kept as given.
        node = object.__new__(cls)
        node.f = lhs
        node.g = rhs
        node.is_sympy = False
        return node

    def __call__(self, **x):
        """Evaluate the node with the keyword arguments in ``x``."""
        names = x.keys()
        if self.is_sympy:
            compiled = lambdify_np(self, names)
            return compiled(**x)

        lhs = lambdify_np(self.f, names)
        rhs = lambdify_np(self.g, names)
        return lhs(**x) * rhs(**x)


class WrapMin(Min, WrapSympy):
    """``Min`` node whose operands may also be numerically evaluated callables.

    When ``_wrapper_guide`` reports that an operand needs numerical
    evaluation, the instance only stores the first two operands as ``f`` and
    ``g``; otherwise construction is delegated to ``Min``.
    """

    def __new__(cls, *args, **kwargs):
        numeric, lhs, rhs = WrapMin._wrapper_guide(args)
        if not numeric:
            node = Min.__new__(cls, *args, **kwargs)
            # Only tag the result when ``Min.__new__`` handed back a WrapSympy.
            if isinstance(node, WrapSympy):
                node.is_sympy = True
            return node

        # Bypass the sympy constructor: the operands are kept as given.
        node = object.__new__(cls)
        node.f = lhs
        node.g = rhs
        node.is_sympy = False
        return node

    def __call__(self, **x):
        """Evaluate the node with the keyword arguments in ``x``."""
        names = x.keys()
        if self.is_sympy:
            compiled = lambdify_np(self, names)
            return compiled(**x)

        lhs = lambdify_np(self.f, names)
        rhs = lambdify_np(self.g, names)
        return np.minimum(lhs(**x), rhs(**x))


def _try_float(fn):
    """Return ``float(fn)`` when the conversion works, else ``fn`` unchanged.

    Only ``ValueError`` and ``TypeError`` from the conversion are absorbed.
    """
    try:
        return float(fn)
    except (ValueError, TypeError):
        return fn


def _constant_bool(boolean: bool):
    """Build a keyword-only function that returns a constant boolean array.

    The returned function fills an array shaped like its first keyword
    argument with ``True`` when ``boolean`` is truthy and ``False`` otherwise.
    """

    def fn(**inputs):
        # The first keyword value serves purely as a shape template.
        template = next(iter(inputs.values()))
        fill = np.ones_like if boolean else np.zeros_like
        return fill(template, dtype=bool)

    return fn


def _constant_float(f):
    """Build a keyword-only function that returns an array filled with ``f``.

    The result is an array of ones shaped like the first keyword argument,
    multiplied by ``f``.
    """

    def fn(**inputs):
        # The first keyword value serves purely as a shape template.
        template = next(iter(inputs.values()))
        return np.ones_like(template) * f

    return fn


def lambdify_np(f, r: Iterable):
    """Turn ``f`` into a callable suitable for NumPy evaluation.

    ``f`` is handled according to the first rule that matches:

    * a ``WrapSympy`` whose ``is_sympy`` flag is true is compiled with
      ``sympy.lambdify``;
    * a ``WrapSympy`` whose ``is_sympy`` flag is false, or any other
      callable, is returned unchanged;
    * a ``bool`` becomes a function returning a constant boolean array;
    * a value accepted by ``float()`` becomes a function returning a
      constant array;
    * anything else is compiled with ``sympy.lambdify``.

    :param f: the expression, constant or callable to convert.
    :param r: names of the variables of ``f``; a ``dict`` contributes its keys.
    :return: a callable. Functions compiled by ``sympy.lambdify`` carry an
        extra ``input_keys`` attribute listing the variable names.
    """
    if isinstance(r, dict):
        r = r.keys()

    if isinstance(f, WrapSympy):
        if not f.is_sympy:
            return f
    elif callable(f):
        # ``collections.Callable`` no longer exists on Python >= 3.10.
        # NOTE(review): this also matches sympy objects that define
        # ``__call__`` -- confirm that returning them as-is is intended.
        return f
    elif isinstance(f, bool):
        # Must precede the float conversion because ``float(True)`` succeeds.
        return _constant_bool(f)
    else:
        f = _try_float(f)
        if isinstance(f, float):
            return _constant_float(f)

    # Materialise once so a one-shot iterable feeds both the compiled
    # signature and ``input_keys``.
    keys = list(r)
    lambdify_f = lambdify(keys, f, [PLACEHOLDER, "numpy"])
    lambdify_f.input_keys = list(keys)
    return lambdify_f
# Name -> NumPy implementation table that ``lambdify_np`` passes to
# ``sympy.lambdify`` ahead of the stock "numpy" module.
PLACEHOLDER = {
    # Element-wise reductions over a sequence of operands.
    "amin": lambda x: reduce(np.minimum, x),
    "amax": lambda x: reduce(np.maximum, x),
    # Same reductions, with the operands given as separate arguments.
    "Min": lambda *x: reduce(np.minimum, x),
    "Max": lambda *x: reduce(np.maximum, x),
    # Step function that takes the value 0 at x == 0.
    "Heaviside": lambda x: np.heaviside(x, 0),
    # Equality is tested with a tolerance rather than exactly.
    "equal": lambda x, y: np.isclose(x, y),
    "Xor": np.logical_xor,
    # Elementary functions.
    "cos": np.cos,
    "sin": np.sin,
    "tan": np.tan,
    "exp": np.exp,
    "sqrt": np.sqrt,
    "log": np.log,
    "sinh": np.sinh,
    "cosh": np.cosh,
    "tanh": np.tanh,
    "asin": np.arcsin,
    "acos": np.arccos,
    "atan": np.arctan,
    "Abs": np.abs,
    # Evaluated as zero everywhere.
    "DiracDelta": np.zeros_like,
}