# Source code for idrlnet.torch_util

"""
Conversion utilities between sympy expressions and torch functions.
TODO: replace the sampling method in GEOMETRY.
"""

from sympy import lambdify, Symbol, Derivative, Function, Basic
from sympy.utilities.lambdify import implemented_function
from sympy.printing.str import StrPrinter
import torch
from idrlnet.header import DIFF_SYMBOL
from functools import reduce

__all__ = ["integral", "torch_lambdify"]


def integral_fun(x):
    """Broadcast the sum over dimension 0 back to the shape of ``x``.

    :param x: A ``torch.Tensor``, or any other object.
    :return: For a tensor, a tensor shaped like ``x`` in which every entry
        along dimension 0 holds the sum taken over that dimension. Any
        non-tensor argument is returned unchanged.
    """
    if not isinstance(x, torch.Tensor):
        return x
    # Sum over dim 0 (kept as a size-1 axis), then multiply by ones so the
    # total is repeated along that axis and the result has x's shape.
    total_over_dim0 = x.sum(dim=0, keepdim=True)
    return total_over_dim0 * torch.ones_like(x)


# Sympy function named "integral" whose attached implementation forwards its
# single argument to ``integral_fun``. The lambda looks ``integral_fun`` up at
# call time rather than binding it when this statement runs.
integral = implemented_function("integral", lambda x: integral_fun(x))


def torch_lambdify(r, f, *args, **kwargs):
    """Build a callable that evaluates ``f`` on torch tensors.

    :param r: Iterable whose items become the argument list handed to
        ``lambdify``.
    :param f: Either a constant (anything ``float()`` accepts) or an
        expression to be compiled by ``lambdify``.
    :param args: Extra positional arguments forwarded to ``lambdify``.
    :param kwargs: Extra keyword arguments forwarded to ``lambdify``.
    :return: If ``f`` is a constant, a function that takes tensors as keyword
        arguments (at least one is required) and returns a tensor shaped like
        the first of them with every entry equal to the constant. Otherwise
        the function produced by ``lambdify`` with ``TORCH_SYMPY_PRINTER`` as
        its module mapping.
    """
    try:
        f = float(f)
    except Exception:
        # Best-effort probe: any ordinary failure just means ``f`` is not a
        # plain number and must go through ``lambdify``. Unlike a bare
        # ``except``, this lets KeyboardInterrupt / SystemExit propagate.
        pass

    if isinstance(f, (float, int, bool)):

        def constant_function(**x):
            # The first keyword tensor only supplies shape, dtype and device.
            return torch.zeros_like(next(iter(x.values()))) + f

        return constant_function

    return lambdify(list(r), f, [TORCH_SYMPY_PRINTER], *args, **kwargs)
def _heaviside(x):
    """Heaviside step of ``x`` that evaluates to 0 where ``x`` is zero."""
    # ``torch.heaviside`` needs ``values`` to match the input's dtype (and
    # device), so derive it from ``x`` instead of a fixed float32 CPU tensor.
    return torch.heaviside(x, x.new_zeros(1))


# TODO: more functions
# Maps function names appearing in lambdified expressions to torch callables.
TORCH_SYMPY_PRINTER = {
    "sin": torch.sin,
    "cos": torch.cos,
    "tan": torch.tan,
    "exp": torch.exp,
    "sqrt": torch.sqrt,
    "Abs": torch.abs,
    "tanh": torch.tanh,
    "DiracDelta": torch.zeros_like,
    "Heaviside": _heaviside,
    # amin / amax take ONE iterable of tensors; Min / Max take the tensors
    # as separate positional arguments.
    "amin": lambda x: reduce(torch.minimum, x),
    "amax": lambda x: reduce(torch.maximum, x),
    "Min": lambda *x: reduce(torch.minimum, x),
    "Max": lambda *x: reduce(torch.maximum, x),
    "equal": lambda x, y: torch.isclose(x, y),
    "Xor": torch.logical_xor,
    "log": torch.log,
    "sinh": torch.sinh,
    "cosh": torch.cosh,
    "asin": torch.arcsin,
    "acos": torch.arccos,
    "atan": torch.arctan,
}


def _reduce_sum(x: torch.Tensor):
    """Return the sum of ``x`` over dimension 0, kept as a size-1 axis."""
    return torch.sum(x, dim=0, keepdim=True)


def _replace_derivatives(expr):
    """Rewrite derivatives and selected function applications as plain symbols.

    First, every ``Derivative`` in ``expr`` is substituted by an applied
    ``Function`` whose name is ``str(deriv)`` and whose arguments are the
    derivative's free symbols. Then every ``Function`` atom whose
    ``class_key()[1]`` is 0 and whose class name is not ``"integral"`` is
    substituted by a ``Symbol`` named after its string form.

    ``str()`` here goes through the ``Basic.__str__`` override installed at
    the bottom of this module.

    :param expr: Expression to rewrite.
    :return: The rewritten expression.
    """
    while len(expr.atoms(Derivative)) > 0:
        deriv = expr.atoms(Derivative).pop()
        expr = expr.subs(deriv, Function(str(deriv))(*deriv.free_symbols))

    while True:
        # NOTE(review): class_key()[1] == 0 presumably singles out
        # user-defined (undefined) functions -- confirm against sympy.
        replaceable = {
            _fun
            for _fun in expr.atoms(Function)
            if _fun.class_key()[1] == 0 and _fun.class_key()[2] != "integral"
        }
        if not replaceable:
            break
        custom_fun = replaceable.pop()
        expr = expr.subs(custom_fun, Symbol(str(custom_fun)))
    return expr


class UnderlineDerivativePrinter(StrPrinter):
    """String printer that drops function arguments and flattens derivatives."""

    def _print_Function(self, expr):
        # Print only the function's name, without its argument list.
        return expr.func.__name__

    def _print_Derivative(self, expr):
        # Function name followed by DIFF_SYMBOL + variable, repeated once per
        # differentiation order of that variable.
        return "".join(
            [str(expr.args[0].func)]
            + [order * (DIFF_SYMBOL + str(key)) for key, order in expr.args[1:]]
        )


def sstr(expr, **settings):
    """Return ``expr`` printed by ``UnderlineDerivativePrinter`` with ``settings``."""
    p = UnderlineDerivativePrinter(settings)
    s = p.doprint(expr)
    return s


# Module-level side effect: from import time onwards, ``str()`` of every sympy
# ``Basic`` object uses the printer above.
Basic.__str__ = lambda self: sstr(self, order=None)