# Source code for idrlnet.graph

"""Define Computational graph"""

import sympy as sp
from typing import List, Dict, Union
from copy import copy
from collections import defaultdict
import networkx as nx
import matplotlib.pyplot as plt
import math
from idrlnet.variable import Variables
from idrlnet.node import Node
from idrlnet.header import logger, DIFF_SYMBOL
from idrlnet.pde import PdeNode
from idrlnet.net import NetNode

__all__ = ["ComputableNodeList", "Vertex", "VertexTaskPipeline"]
# Module-level sympy symbols; not referenced elsewhere in this module's
# visible code (kept in case other modules rely on them).
x, y = sp.symbols("x y")
# Type alias for the nodes a VertexTaskPipeline is built from. It used to be
# wrapped in a list literal (``[List[...]]``), which produced a list object
# rather than a type and made the annotation meaningless to type checkers.
ComputableNodeList = List[Union[PdeNode, NetNode]]


class Vertex(Node):
    """A vertex of the computation graph, built from an existing ``Node``.

    The wrapped node's attribute dictionary is shallow-copied into the
    vertex. Every vertex receives a unique, increasing ``index`` taken from a
    class-level counter; equality and hashing are based on that index only.

    :param pre: set of predecessor vertices; a new empty set if ``None``.
    :param next: set of successor vertices; a new empty set if ``None``.
    :param node: node whose attributes are copied; a fresh ``Node()`` if
        ``None``.
    :param ntype: vertex type, one of ``"d"``, ``"c"`` or ``"r"``.
        ``VertexTaskPipeline`` uses ``"r"`` for requested outputs, ``"d"``
        for input variables and the default ``"c"`` for wrapped nodes.
    """

    counter = 0

    def __init__(self, pre=None, next=None, node=None, ntype="c"):
        node = Node() if node is None else node
        # Copy (not share) the attribute dict so mutating the vertex does not
        # mutate the wrapped node.
        self.__dict__ = node.__dict__.copy()
        self.index = type(self).counter
        type(self).counter += 1
        self.pre = pre if pre is not None else set()
        # Previously this tested ``pre`` instead of ``next``, which dropped an
        # explicit ``next`` or left ``self.next`` as None.
        self.next = next if next is not None else set()
        self.ntype = ntype
        assert self.ntype in ("d", "c", "r")

    def __eq__(self, other):
        # Only vertices are comparable; let Python fall back for anything else
        # instead of raising AttributeError on ``other.index``.
        if not isinstance(other, Vertex):
            return NotImplemented
        return self.index == other.index

    def __hash__(self):
        return self.index

    def __str__(self):
        info = (
            f"index: {self.index}\n"
            + f"pre: {[node.index for node in self.pre]}\n"
            + f"next: {[node.index for node in self.next]}\n"
        )
        return super().__str__() + info
class VertexTaskPipeline:
    """Resolve requested variables into an ordered list of vertices to evaluate.

    The constructor walks backwards from the requested names. For each
    requirement it picks the graph node output or input variable whose
    ``DIFF_SYMBOL``-separated parts form the longest prefix of the
    requirement. Graph nodes win ties; input variables need a strictly longer
    prefix. The resulting order is stored in :attr:`evaluation_order_list`,
    and a ``networkx.DiGraph`` of the dependencies is kept in ``self.G``.
    """

    # Upper bound on the number of vertices pushed while resolving
    # requirements; guards against runaway resolution.
    MAX_STACK_ALLOWED = 100000

    @property
    def evaluation_order_list(self):
        """Vertices in the order :meth:`operation_order` evaluates them."""
        return self._evaluation_order_list

    @evaluation_order_list.setter
    def evaluation_order_list(self, evaluation_order_list):
        self._evaluation_order_list = evaluation_order_list

    def __init__(
        self, nodes: ComputableNodeList, invar: Variables, req_names: List[str]
    ):
        """Build the graph needed to compute ``req_names`` from ``invar``.

        :param nodes: candidate computation nodes.
        :param invar: available input variables; only its keys are used here.
        :param req_names: names of the variables to compute. A name containing
            ``DIFF_SYMBOL`` is registered as a derivative requirement.
        :raises ValueError: if resolution pushes ``MAX_STACK_ALLOWED`` or more
            vertices.
        :raises Exception: if a requirement matches neither a node output nor
            an input variable.
        """
        self.nodes = nodes
        self.req_names = req_names
        self.computable = set(invar.keys())
        graph_nodes = set(Vertex(node=node) for node in nodes)
        # requirement name -> vertices that need it
        req_name_dict: Dict[str, List[Vertex]] = defaultdict(list)
        self.G = nx.DiGraph()
        # (source name, target name) -> variable names carried by that edge.
        # Attribute name (sic) kept unchanged since it is instance state.
        self.egde_data = defaultdict(set)
        required_stack = []

        # One sink vertex ("r") per requested name.
        for req_name in req_names:
            final_graph_node = Vertex()
            if DIFF_SYMBOL in req_name:
                final_graph_node.derivatives = (req_name,)
                final_graph_node.inputs = tuple()
            else:
                final_graph_node.inputs = [req_name]
                final_graph_node.derivatives = tuple()
            final_graph_node.outputs = tuple()
            final_graph_node.name = f"<{req_name}>"
            final_graph_node.ntype = "r"
            graph_nodes.add(final_graph_node)
            req_name_dict[req_name].append(final_graph_node)
            required_stack.append(final_graph_node)
            # Sink vertices pass their gathered inputs through unchanged.
            final_graph_node.evaluate = lambda x: x

        logger.info("Constructing computation graph...")
        while len(req_name_dict) > 0:
            to_be_removed = set()
            to_be_added = defaultdict(list)
            if len(required_stack) >= self.MAX_STACK_ALLOWED:
                raise ValueError(
                    "Computation graph construction exceeded "
                    f"MAX_STACK_ALLOWED={self.MAX_STACK_ALLOWED} vertices; "
                    "the requirements may be cyclic."
                )
            for req_key, current_gn in req_name_dict.items():
                req_parts = tuple(req_key.split(DIFF_SYMBOL))
                match_score = -1
                match_gn = None
                # Graph node whose output is the longest prefix of the
                # requirement (never one of the requiring vertices itself).
                for gn in graph_nodes:
                    if gn in current_gn:
                        continue
                    for output in gn.outputs:
                        out_parts = tuple(output.split(DIFF_SYMBOL))
                        if (
                            len(out_parts) <= len(req_parts)
                            and req_parts[: len(out_parts)] == out_parts
                            and len(out_parts) > match_score
                        ):
                            match_score = len(out_parts)
                            match_gn = gn
                # An input variable wins only with a strictly longer prefix.
                # Track the best one so that only its edges are drawn
                # (previously every improving candidate added an edge).
                best_invar = None
                for p_in in invar.keys():
                    in_parts = tuple(p_in.split(DIFF_SYMBOL))
                    if (
                        len(in_parts) <= len(req_parts)
                        and req_parts[: len(in_parts)] == in_parts
                        and len(in_parts) > match_score
                    ):
                        match_score = len(in_parts)
                        match_gn = None
                        best_invar = DIFF_SYMBOL.join(in_parts)
                if best_invar is not None:
                    for sub_gn in current_gn:
                        self.G.add_edge(best_invar, sub_gn.name)

                if match_score <= 0:
                    raise Exception("Can't be computed: " + req_key)
                elif match_gn is not None:
                    for sub_gn in current_gn:
                        logger.info(f"{sub_gn.name}.{req_key} <---- {match_gn.name}")
                        match_gn.next.add(sub_gn)
                        sub_gn.pre.add(match_gn)
                        self.egde_data[(match_gn.name, sub_gn.name)].add(req_key)
                        self.G.add_edge(match_gn.name, sub_gn.name)
                    required_stack.append(match_gn)
                    # The matched node's own inputs become new requirements.
                    for p in match_gn.inputs:
                        to_be_added[p].append(match_gn)
                    for p in match_gn.derivatives:
                        to_be_added[p].append(match_gn)
                to_be_removed.add(req_key)
            if len(to_be_removed) == 0 and len(req_name_dict) > 0:
                raise Exception("Can't be computed")
            for p in to_be_removed:
                req_name_dict.pop(p)
                self.computable.add(p)
            for k, v in to_be_added.items():
                req_name_dict[k].extend(v)

        # Requirements were pushed consumer-first, so walking the stack from
        # the top visits producers before the vertices that need them; keep
        # each vertex's topmost occurrence only.
        evaluation_order = []
        seen = set()
        for gn in reversed(required_stack):
            if gn in seen:
                continue
            seen.add(gn)
            evaluation_order.append(gn)
            self.computable.update(gn.outputs)
        self.evaluation_order_list = evaluation_order

        # name -> vertex lookup used by display(); input variables get "d"
        # vertices so every node of self.G can be resolved.
        self._graph_node_table = {node.name: node for node in graph_nodes}
        for key in invar:
            node = Vertex()
            node.name = key
            node.outputs = (key,)
            node.inputs = tuple()
            node.ntype = "d"
            self._graph_node_table[key] = node
        logger.info("Computation graph constructed.")

    def operation_order(self, invar: Variables):
        """Evaluate :attr:`evaluation_order_list` in place on ``invar``.

        Derivatives a vertex needs that are not yet present are first computed
        with ``invar.differentiate_``; each vertex's results are then merged
        into ``invar`` with ``update``.
        """
        for node in self.evaluation_order_list:
            if not set(node.derivatives).issubset(invar.keys()):
                invar.differentiate_(
                    independent_var=invar, required_derivatives=node.derivatives
                )
            invar.update(
                node.evaluate(
                    {**invar.subset(node.inputs), **invar.subset(node.derivatives)}
                )
            )

    def forward_pipeline(
        self, invar: Variables, req_names: List[str] = None
    ) -> Variables:
        """Compute the requested variables from ``invar``.

        If ``req_names`` is None, or every name in it is already in
        ``self.computable``, this graph is evaluated on a shallow copy of
        ``invar``. Otherwise a temporary pipeline is built for ``req_names``.

        :param invar: input variables; not modified by this graph's evaluation.
        :param req_names: names to return; defaults to the construction-time
            ``req_names``.
        :return: the subset of computed variables named by ``req_names``.
        """
        if req_names is None or set(req_names).issubset(set(self.computable)):
            outvar = copy(invar)
            self.operation_order(outvar)
            return outvar.subset(self.req_names if req_names is None else req_names)
        logger.info("The existing graph fails. Construct a temporary graph...")
        return VertexTaskPipeline(self.nodes, invar, req_names).forward_pipeline(
            invar
        )

    def to_json(self):
        """Placeholder for JSON export; currently does nothing and returns None."""
        pass

    def display(self, filename: str = None):
        """Draw the dependency graph ``self.G``.

        Vertices are coloured by ``ntype`` (pink ``"c"``, green ``"r"``, blue
        ``"d"``). Edges are labelled with the variable names recorded in
        ``egde_data``.

        :param filename: if given, save the figure to this path; otherwise
            show it interactively.
        """
        _, ax = plt.subplots(1, 1, figsize=(8, 8))
        ax.axis("off")
        # Desired node spacing shrinks as the graph grows.
        pos = nx.spring_layout(self.G, k=10 / (math.sqrt(self.G.order()) + 0.1))
        # (ntype, colour, alpha), drawn in this order.
        node_styles = (("c", "pink", 0.5), ("r", "green", 0.3), ("d", "blue", 0.3))
        for ntype, color, alpha in node_styles:
            nx.draw_networkx_nodes(
                self.G,
                pos,
                nodelist=[
                    node
                    for node in self.G.nodes
                    if self._graph_node_table[node].ntype == ntype
                ],
                cmap=plt.get_cmap("jet"),
                node_size=1300,
                node_color=color,
                alpha=alpha,
            )
        nx.draw_networkx_edges(
            self.G, pos, edge_color="r", arrows=True, arrowsize=30, arrowstyle="-|>"
        )
        nx.draw_networkx_labels(self.G, pos)
        nx.draw_networkx_edge_labels(
            self.G,
            pos,
            edge_labels={k: ", ".join(v) for k, v in self.egde_data.items()},
            font_size=10,
        )
        if filename is None:
            plt.show()
        else:
            plt.savefig(filename)
        plt.close()