Source code for finam_graph.diagram

"""Main module for graph diagram drawer"""
import math

import matplotlib.pyplot as plt
import numpy as np
from finam.interfaces import IComponent, ITimeComponent
from matplotlib import patches
from matplotlib.axes import Axes
from matplotlib.backend_bases import MouseButton
from matplotlib.path import Path

from finam_graph.graph import Graph


[docs] class GraphSizes: """Graph sizing properties Parameters ---------- grid_size : (int, int) Size of grid cells for alignment component_size : (int, int) Size of component boxes adapter_size : (int, int) Size of adapter boxes margin : int Margin around all boxes comp_slot_size : (int, int) Input and output slot size for components adap_slot_size : (int, int) Input and output slot size for adapters curve_size : int Connection curve "radius" (control point distance) """ def __init__( self, grid_size=(160, 100), component_size=(80, 60), adapter_size=(80, 30), margin=50, comp_slot_size=(30, 14), adap_slot_size=(10, 10), curve_size=20, ): self.grid_size = grid_size self.component_size = component_size self.adapter_size = adapter_size self.margin = margin self.comp_slot_size = comp_slot_size self.adap_slot_size = adap_slot_size self.curve_size = curve_size
[docs] class GraphColors: """Graph coloring properties Parameters ---------- comp_color : str :class:`finam.Component` color time_comp_color : str :class:`finam.TimeComponent` color selected_comp_color : str Component color for selection adapter_color : str :class:`finam.Adapter` color selected_adapter_color : str Adapter color for selection """ def __init__( self, comp_color="lightgreen", time_comp_color="lightblue", selected_comp_color="blue", adapter_color="orange", selected_adapter_color="red", ): self.comp_color = comp_color self.time_comp_color = time_comp_color self.selected_comp_color = selected_comp_color self.adapter_color = adapter_color self.selected_adapter_color = selected_adapter_color
[docs] class GraphDiagram: """Diagram drawer. Examples -------- .. code-block:: Python composition = Composition([comp_a, comp_b]) composition.initialize() comp_a.outputs["Out"] >> comp_b.inputs["In"] diagram = GraphDiagram() diagram.draw(composition, save_path="graph.svg") Parameters ---------- sizes : GraphSizes Graph sizing properties object :class:`.GraphSizes`. colors : GraphColors Graph coloring properties object :class:`.GraphColors`. corner_radius : int Radius for rounded corners max_label_length : int Maximum number of characters in component and adapter labels max_slot_label_length : int Maximum number of characters in input and output slot labels """ def __init__( self, sizes=GraphSizes(), colors=GraphColors(), corner_radius=5, max_label_length=12, max_slot_label_length=6, ): self.sizes = sizes self.colors = colors self.corner_radius = corner_radius self.max_label_length = max_label_length self.max_slot_label_length = max_slot_label_length self.component_offset = (sizes.grid_size[0] - sizes.component_size[0]) / 2, ( sizes.grid_size[1] - sizes.component_size[1] ) / 2 self.adapter_offset = (sizes.grid_size[0] - sizes.adapter_size[0]) / 2, ( sizes.grid_size[1] - sizes.adapter_size[1] ) / 2 self.selected_cell = None self.show_grid = False
[docs] def draw( self, composition, details=2, excluded=None, positions=None, labels=None, colors=None, show=True, block=True, save_path=None, max_iterations=25000, seed=None, ): """ Draw a graph diagram. Parameters ---------- composition : Composition The :class:`finam.Composition` to draw a graph diagram for excluded : list or set, optional List of excluded components. Default: None details : int, optional Level of details of the graph plot. * 0: Simple graph without slots and adapters * 1: Detailed graph, with collapsed adapters * 2: Full detailed graph, with adapters Defaults to 2. positions : dict, optional Dictionary of grid cell position tuples per component/adapter. Default: None (optimized) labels : dict, optional Dictionary of label overrides for components, adapters and input/output slots. Default: None colors : dict, optional Dictionary of component/adapter color overrides. Default: None show : bool, optional Whether to show the diagram. Default: True block : bool, optional Should the diagram be shown in blocking mode? Default: True save_path : pathlike, optional Path to save image file. Default: None (i.e. don't save) max_iterations : int, optional Maximum iterations for optimizing node placement. Default: 25000 seed : int, optional Random seed for the optimizer. Default: None """ colors = colors or {} labels = labels or {} excluded = set(excluded) if excluded is not None else set() show_adapters = details > 1 simple = details < 1 rng = ( np.random.default_rng() if seed is None else np.random.default_rng(seed=seed) ) graph = Graph(composition, excluded) if positions is None: positions = _optimize_positions( graph, rng, simple, show_adapters, max_iterations ) figure, ax = plt.subplots(figsize=(12, 6)) if figure.canvas.manager is not None: figure.canvas.manager.set_window_title( "Graph - SPACE for grid, click to re-arrange" ) ax.axis("off") ax.set_aspect("equal") figure.subplots_adjust(left=0, right=1, top=1, bottom=0) self._repaint(graph, positions, labels, colors, simple, show_adapters, ax) if save_path is not None: plt.savefig(save_path) if show: self._show( graph, positions, labels, colors, simple, show_adapters, ax, figure, block, )
def _show( self, graph, positions, labels, colors, simple, show_adapters, ax, figure, block ): def onclick(event): if event.xdata is None: return if event.button == MouseButton.RIGHT: self.selected_cell = None self._repaint( graph, positions, labels, colors, simple, show_adapters, ax ) return xdata, ydata = event.xdata, event.ydata cell = int(math.floor(xdata / self.sizes.grid_size[0])), int( math.floor(ydata / self.sizes.grid_size[1]) ) if self.selected_cell is None: for k, v in positions.items(): if v == cell: self.selected_cell = k self._repaint( graph, positions, labels, colors, simple, show_adapters, ax, ) break else: positions[self.selected_cell] = cell self.selected_cell = None self._repaint( graph, positions, labels, colors, simple, show_adapters, ax ) def on_press(event): if event.key == " ": self.show_grid = not self.show_grid self._repaint( graph, positions, labels, colors, simple, show_adapters, ax ) def on_close(_event): plt.close(figure) plt.ioff() _cid = figure.canvas.mpl_connect("button_press_event", onclick) _cid = figure.canvas.mpl_connect("key_press_event", on_press) _cid = figure.canvas.mpl_connect("close_event", on_close) plt.ion() plt.show(block=block) def _repaint( self, graph, positions, labels, colors, simple: bool, show_adapters: bool, axes: Axes, ): while bool(axes.patches): axes.patches[0].remove() while bool(axes.texts): axes.texts[0].remove() x_bounds, y_bounds = _calc_bounds(positions) x_lim, y_lim = self._calc_limits(x_bounds, y_bounds) axes.set_xlim(*x_lim) axes.set_ylim(*y_lim) if self.show_grid: self._draw_grid(x_bounds, y_bounds, axes) comp_patches = {} for comp in graph.components: comp_patches[comp] = self._draw_component( comp, positions[comp], labels.get(comp), colors.get(comp), simple, labels, axes, ) if show_adapters: for ad in graph.adapters: self._draw_adapter( ad, positions[ad], labels.get(ad), colors.get(ad), axes ) if simple: self._draw_edges_simple(graph.simple_edges, positions, comp_patches, axes) return edges = graph.edges if show_adapters else graph.direct_edges for edge in edges: self._draw_edge(edge, positions, show_adapters, axes) def _calc_limits(self, x_min_max, y_min_max): x_lim = ( x_min_max[0] * self.sizes.grid_size[0] - self.sizes.margin, (x_min_max[1] + 1) * self.sizes.grid_size[0] + self.sizes.margin, ) y_lim = ( y_min_max[0] * self.sizes.grid_size[1] - self.sizes.margin, (y_min_max[1] + 1) * self.sizes.grid_size[1] + self.sizes.margin, ) return x_lim, y_lim def _draw_grid(self, x_bounds, y_bounds, axes: Axes): for i in range(x_bounds[0] - 1, x_bounds[1] + 2): for j in range(y_bounds[0] - 1, y_bounds[1] + 2): rect = patches.Rectangle( (i * self.sizes.grid_size[0], j * self.sizes.grid_size[1]), *self.sizes.grid_size, linewidth=1, edgecolor="lightgrey", facecolor="none", ) axes.add_patch(rect) def _draw_edges_simple(self, simple_edges, positions, comp_patches, axes: Axes): drawn = set() for source, target in simple_edges: if (source, target) in drawn: continue bidir = (target, source) in simple_edges if bidir: drawn.add((target, source)) src_pos = self._comp_pos(source, positions[source]) trg_pos = self._comp_pos(target, positions[target]) src_pos = ( src_pos[0] + self.sizes.component_size[0] / 2, src_pos[1] + self.sizes.component_size[1] / 2, ) trg_pos = ( trg_pos[0] + self.sizes.component_size[0] / 2, trg_pos[1] + self.sizes.component_size[1] / 2, ) style = "<|-|>" if bidir else "-|>" arr = patches.ConnectionPatch( src_pos, trg_pos, "data", "data", patchA=comp_patches[source], patchB=comp_patches[target], arrowstyle=style, mutation_scale=20, fc="w", ) axes.add_patch(arr) def _draw_edge(self, edge, positions, show_adapters: bool, axes: Axes): src_pos = self._comp_pos(edge.source, positions[edge.source]) trg_pos = self._comp_pos(edge.target, positions[edge.target]) if isinstance(edge.source, IComponent): out_idx = list(edge.source.outputs.keys()).index(edge.out_name) out_size = self.sizes.comp_slot_size else: out_idx = 0 out_size = self.sizes.adap_slot_size if isinstance(edge.target, IComponent): in_idx = list(edge.target.inputs.keys()).index(edge.in_name) in_size = self.sizes.comp_slot_size else: in_idx = 0 in_size = self.sizes.adap_slot_size out_off = self._output_pos(edge.source, out_idx) in_off = self._input_pos(edge.target, in_idx) p1 = ( src_pos[0] + out_off[0] + out_size[0], src_pos[1] + out_off[1] + out_size[1] / 2, ) p4 = trg_pos[0] + in_off[0], trg_pos[1] + in_off[1] + in_size[1] / 2 dx = abs(p4[0] - p1[0]) curve_sz = max(self.sizes.curve_size, dx / 2) p2 = p1[0] + curve_sz, p1[1] p3 = p4[0] - curve_sz, p4[1] axes.add_patch( patches.PathPatch( Path( [p1, p2, p3, p4], [Path.MOVETO, Path.CURVE4, Path.CURVE4, Path.CURVE4], ), fc="none", ) ) if edge.num_adapters > 0 and not show_adapters: pc = (p1[0] + p4[0]) / 2, (p1[1] + p4[1]) / 2 axes.add_patch( patches.Rectangle( (pc[0] - 4, pc[1] - 4), 8, 8, linewidth=1, edgecolor="k", facecolor=self.colors.adapter_color, ) ) axes.text( *pc, str(edge.num_adapters), ha="center", va="center", size=6, ) def _draw_component(self, comp, position, label, color, simple, labels, axes: Axes): name = label or comp.name xll, yll = self._comp_pos(comp, position) rect = patches.FancyBboxPatch( (xll, yll), *self.sizes.component_size, boxstyle=f"round,rounding_size={self.corner_radius}", linewidth=1, edgecolor="k", facecolor=self.colors.selected_comp_color if self.selected_cell == comp else color or ( self.colors.time_comp_color if isinstance(comp, ITimeComponent) else self.colors.comp_color ), ) axes.add_patch(rect) if not simple: self._draw_slots(comp, labels, xll, yll, axes) axes.text( xll + self.sizes.component_size[0] / 2, yll + self.sizes.component_size[1] / 2, _shorten_str(name.replace("Component", "Co"), self.max_label_length), ha="center", va="center", size=8, ) return rect def _draw_slots(self, comp, labels, xll, yll, axes): if len(comp.inputs) > 0: for i, (n, inp) in enumerate(comp.inputs.items()): in_name = labels.get(inp, n) xlli, ylli = self._input_pos(comp, i) inp_rect = patches.Rectangle( (xll + xlli, yll + ylli), *self.sizes.comp_slot_size, linewidth=1, edgecolor="k", facecolor="lightgrey", ) axes.add_patch(inp_rect) axes.text( xll + xlli + 2, yll + ylli + self.sizes.comp_slot_size[1] / 2, _shorten_str(in_name, self.max_slot_label_length), ha="left", va="center", size=7, ) if len(comp.outputs) > 0: for i, (n, out) in enumerate(comp.outputs.items()): out_name = labels.get(out, n) xllo, yllo = self._output_pos(comp, i) out_rect = patches.Rectangle( (xll + xllo, yll + yllo), *self.sizes.comp_slot_size, linewidth=1, edgecolor="k", facecolor="white", ) axes.add_patch(out_rect) axes.text( xll + xllo + 2, yll + yllo + self.sizes.comp_slot_size[1] / 2, _shorten_str(out_name, self.max_slot_label_length), ha="left", va="center", size=7, ) def _draw_adapter(self, comp, position, label, color, axes: Axes): name = label or comp.name xll, yll = self._comp_pos(comp, position) rect = patches.FancyBboxPatch( (xll, yll), *self.sizes.adapter_size, boxstyle=f"round, pad=0, rounding_size={self.corner_radius}", linewidth=1, edgecolor="k", facecolor=self.colors.selected_adapter_color if self.selected_cell == comp else color or self.colors.adapter_color, ) xlli, ylli = self._input_pos(comp, 0) inp = patches.Rectangle( (xll + xlli, yll + ylli), *self.sizes.adap_slot_size, linewidth=1, edgecolor="k", facecolor="lightgrey", ) xllo, yllo = self._output_pos(comp, 0) out = patches.Rectangle( (xll + xllo, yll + yllo), *self.sizes.adap_slot_size, linewidth=1, edgecolor="k", facecolor="white", ) axes.add_patch(rect) axes.add_patch(inp) axes.add_patch(out) axes.text( xll + self.sizes.adapter_size[0] / 2, yll + self.sizes.adapter_size[1] / 2, _shorten_str(name.replace("Adapter", "Ad."), self.max_label_length), ha="center", va="center", size=8, ) def _comp_pos(self, comp_or_ada, pos): if isinstance(comp_or_ada, IComponent): return ( pos[0] * self.sizes.grid_size[0] + self.component_offset[0], pos[1] * self.sizes.grid_size[1] + self.component_offset[1], ) return ( pos[0] * self.sizes.grid_size[0] + self.adapter_offset[0], pos[1] * self.sizes.grid_size[1] + self.adapter_offset[1], ) def _input_pos(self, comp_or_ada, idx): if isinstance(comp_or_ada, IComponent): cnt = len(comp_or_ada.inputs) inv_idx = cnt - 1 - idx in_sp = self.sizes.component_size[1] / cnt return ( -self.sizes.comp_slot_size[0], in_sp / 2 + in_sp * inv_idx - self.sizes.comp_slot_size[1] / 2, ) return ( -self.sizes.adap_slot_size[0], self.sizes.adapter_size[1] / 2 - self.sizes.adap_slot_size[1] / 2, ) def _output_pos(self, comp_or_ada, idx): if isinstance(comp_or_ada, IComponent): cnt = len(comp_or_ada.outputs) inv_idx = cnt - 1 - idx out_sp = self.sizes.component_size[1] / cnt return ( self.sizes.component_size[0], out_sp / 2 + out_sp * inv_idx - self.sizes.comp_slot_size[1] / 2, ) return ( self.sizes.adapter_size[0], self.sizes.adapter_size[1] / 2 - self.sizes.adap_slot_size[1] / 2, )
def _calc_bounds(positions): x_min, y_min = 99999, 99999 x_max, y_max = -99999, -99999 for _c, pos in positions.items(): if pos[0] < x_min: x_min = pos[0] if pos[1] < y_min: y_min = pos[1] if pos[0] > x_max: x_max = pos[0] if pos[1] > y_max: y_max = pos[1] return (x_min, x_max), (y_min, y_max) def _shorten_str(s, max_length): if len(s) > max_length: return s[0 : (max(1, max_length - 1))] return s def _optimize_positions( graph: Graph, rng, simple: bool, show_adapters: bool, max_iterations: int ): length = len(graph.components) if show_adapters: length += len(graph.adapters) size = math.ceil(math.sqrt(length)) * 3 grid = np.ndarray((size, size), dtype=object) all_mods = ( set.union(graph.components, graph.adapters) if show_adapters else graph.components ) pos = _random_initial_positions(all_mods, grid, size, rng) nodes = list(pos.keys()) nodes.sort(key=lambda co: co.__class__.__name__) return _do_optimize_positions( graph, nodes, simple, show_adapters, pos, grid, size, max_iterations, rng ) def _do_optimize_positions( graph, nodes, simple, show_adapters, pos, grid, size, max_iterations, rng ): print("Optimizing graph layout...") score = _rate_positions( pos, graph.edges if show_adapters else graph.direct_edges, simple ) last_improvement = 0 i = -1 for i in range(max_iterations): pos_new = dict(pos) grid_new = grid.copy() for _j in range(rng.integers(1, 5, 1)[0]): node = rng.choice(nodes) x, y = rng.integers(0, size, 2) node_here = grid_new[x, y] if node_here == node: continue if node_here is None: grid_new[pos_new[node]] = None grid_new[x, y] = node pos_new[node] = (x, y) else: grid_new[pos_new[node]] = node_here grid_new[x, y] = node pos_new[node_here] = pos_new[node] pos_new[node] = (x, y) score_new = _rate_positions( pos_new, graph.edges if show_adapters else graph.direct_edges, simple ) if score_new <= score: if score_new < score: last_improvement = i pos = pos_new grid = grid_new score = score_new if i > 2500 and i > 4 * last_improvement: break print(f"Done ({i + 1} iterations, score {score})") return pos def _random_initial_positions(all_mods, grid, size, rng): pos = {} for c in all_mods: while True: x, y = rng.integers(0, size, 2) if grid[x, y] is None: grid[x, y] = c break pos[c] = x, y return pos def _rate_positions(pos, edges, simple: bool): score = 0.0 if simple: for e in edges: p1 = pos[e.source] p2 = pos[e.target] dist = abs(p2[0] - p1[0]) + abs(p2[1] - p1[1]) score += dist else: for e in edges: p1 = pos[e.source] p2 = pos[e.target] dx = p2[0] - (p1[0] + 1) sc_rev_same_row = 0 sc_x = dx if dx < 0: if p2[1] == p1[1]: sc_rev_same_row = 5 if dx < -1: sc_x *= 2 dist = abs(sc_x) + max(0, abs(p2[1] - p1[1]) - 0.5) + sc_rev_same_row score += dist return score**2