# Source code for idrlnet.node

"""Define Basic Node"""
from typing import Callable, List

from idrlnet.torch_util import torch_lambdify
from idrlnet.variable import Variables
from idrlnet.header import DIFF_SYMBOL

# Public API of this module: only ``Node`` is exported by ``from ... import *``.
__all__ = ["Node"]


class Node(object):
    """Basic node described by a name, inputs, derivatives, outputs and an
    ``evaluate`` callable.

    Every attribute is kept in an underscore-prefixed instance attribute that
    is created lazily: reading ``inputs``, ``outputs``, ``derivatives`` or
    ``name`` before it has been assigned stores a default value on the
    instance and returns it.  ``evaluate`` has no default and raises
    ``AttributeError`` until it has been assigned.
    """

    @property
    def inputs(self) -> List[str]:
        """Input names; an empty tuple is stored and returned if unset."""
        try:
            return self._inputs
        except AttributeError:
            # Only a missing attribute triggers the lazy default; any other
            # error (e.g. raised by a subclass attribute hook) propagates.
            self._inputs = tuple()
            return self._inputs

    @inputs.setter
    def inputs(self, inputs: List[str]):
        self._inputs = inputs

    @property
    def outputs(self) -> List[str]:
        """Output names; an empty tuple is stored and returned if unset."""
        try:
            return self._outputs
        except AttributeError:
            self._outputs = tuple()
            return self._outputs

    @outputs.setter
    def outputs(self, outputs: List[str]):
        self._outputs = outputs

    @property
    def derivatives(self) -> List[str]:
        """Derivative names; an empty list is stored and returned if unset."""
        try:
            return self._derivatives
        except AttributeError:
            self._derivatives = []
            return self._derivatives

    @derivatives.setter
    def derivatives(self, derivatives: List[str]):
        self._derivatives = derivatives

    @property
    def evaluate(self) -> Callable:
        """Callable attached to this node.

        :raises AttributeError: if no callable has been assigned yet.
        """
        return self._evaluate

    @evaluate.setter
    def evaluate(self, evaluate: Callable):
        self._evaluate = evaluate

    @property
    def name(self) -> str:
        """Node name; defaults to ``"Node"`` followed by ``id(self)``."""
        try:
            return self._name
        except AttributeError:
            self._name = "Node" + str(id(self))
            return self._name

    @name.setter
    def name(self, name: str):
        self._name = name

    @classmethod
    def new_node(
        cls,
        name: str = None,
        tf_eq: Callable = None,
        free_symbols: List[str] = None,
        *args,
        **kwargs
    ) -> "Node":
        """Build a node whose ``evaluate`` is a ``LambdaTorchFun``.

        :param name: Name of the node; also used as its single output name.
        :param tf_eq: Expression handed to ``LambdaTorchFun`` together with
            ``free_symbols``.
        :param free_symbols: Symbol names. Those containing ``DIFF_SYMBOL``
            become ``derivatives``; the remaining ones become ``inputs``.
            Despite the ``None`` default, this value is iterated
            unconditionally, so an iterable of strings must be supplied.
        :param args: Accepted but unused.
        :param kwargs: Accepted but unused.
        :return: The newly created node.
        """
        node = cls()
        node.evaluate = LambdaTorchFun(free_symbols, tf_eq, name)
        node.inputs = [x for x in free_symbols if DIFF_SYMBOL not in x]
        node.derivatives = [x for x in free_symbols if DIFF_SYMBOL in x]
        node.outputs = [name]
        node.name = name
        return node

    def __str__(self):
        """Return a multi-line summary of name, inputs, derivatives, outputs."""
        str_list = [
            "Basic properties:\n",
            "name: {}\n".format(self.name),
            "inputs: {}\n".format(self.inputs),
            "derivatives: {}\n".format(self.derivatives),
            "outputs: {}\n".format(self.outputs),
        ]
        return "".join(str_list)
class LambdaTorchFun:
    """Callable that evaluates a lambdified expression under a fixed name.

    The expression ``tf_eq`` is converted once, at construction time, by
    ``torch_lambdify`` over ``free_symbols``; calling the instance feeds the
    given variables to that converted function as keyword arguments.
    """

    def __init__(self, free_symbols, tf_eq, name):
        """Store the arguments and lambdify ``tf_eq`` over ``free_symbols``.

        :param free_symbols: Symbols passed to ``torch_lambdify``.
        :param tf_eq: Expression passed to ``torch_lambdify``.
        :param name: Key under which ``__call__`` returns its result.
        """
        self.lambda_tf_eq = torch_lambdify(free_symbols, tf_eq)
        self.tf_eq = tf_eq
        self.name = name
        self.free_symbols = free_symbols

    def __call__(self, var: Variables):
        """Evaluate the lambdified expression on ``var``.

        :param var: Mapping whose items are forwarded as keyword arguments.
        :return: A one-entry dict mapping ``self.name`` to the result.
        """
        # Copy the items into a plain dict before unpacking them as kwargs.
        keyword_args = {symbol: value for symbol, value in var.items()}
        result = self.lambda_tf_eq(**keyword_args)
        return {self.name: result}