提交 16f98e4f authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Change aesara.graph.basic.Node to pydot.Node in aesara.printing

上级 4fa5269e
...@@ -17,14 +17,7 @@ import numpy as np ...@@ -17,14 +17,7 @@ import numpy as np
from aesara.compile import Function, SharedVariable, debugmode from aesara.compile import Function, SharedVariable, debugmode
from aesara.compile.io import In, Out from aesara.compile.io import In, Out
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import ( from aesara.graph.basic import Apply, Constant, Variable, graph_inputs, io_toposort
Apply,
Constant,
Node,
Variable,
graph_inputs,
io_toposort,
)
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.graph.utils import Scratchpad from aesara.graph.utils import Scratchpad
...@@ -980,13 +973,13 @@ def pydotprint( ...@@ -980,13 +973,13 @@ def pydotprint(
use_color = color use_color = color
if use_color is None: if use_color is None:
nw_node = Node(aid, label=astr, shape=apply_shape) nw_node = pd.Node(aid, label=astr, shape=apply_shape)
elif high_contrast: elif high_contrast:
nw_node = Node( nw_node = pd.Node(
aid, label=astr, style="filled", fillcolor=use_color, shape=apply_shape aid, label=astr, style="filled", fillcolor=use_color, shape=apply_shape
) )
else: else:
nw_node = Node(aid, label=astr, color=use_color, shape=apply_shape) nw_node = pd.Node(aid, label=astr, color=use_color, shape=apply_shape)
g.add_node(nw_node) g.add_node(nw_node)
if cond_highlight: if cond_highlight:
if node in middle: if node in middle:
...@@ -1020,7 +1013,7 @@ def pydotprint( ...@@ -1020,7 +1013,7 @@ def pydotprint(
color = "cyan" color = "cyan"
if high_contrast: if high_contrast:
g.add_node( g.add_node(
Node( pd.Node(
varid, varid,
style="filled", style="filled",
fillcolor=color, fillcolor=color,
...@@ -1029,7 +1022,9 @@ def pydotprint( ...@@ -1029,7 +1022,9 @@ def pydotprint(
) )
) )
else: else:
g.add_node(Node(varid, color=color, label=varstr, shape=var_shape)) g.add_node(
pd.Node(varid, color=color, label=varstr, shape=var_shape)
)
g.add_edge(pd.Edge(varid, aid, **param)) g.add_edge(pd.Edge(varid, aid, **param))
elif var.name or not compact or var in outputs: elif var.name or not compact or var in outputs:
g.add_edge(pd.Edge(varid, aid, **param)) g.add_edge(pd.Edge(varid, aid, **param))
...@@ -1058,7 +1053,7 @@ def pydotprint( ...@@ -1058,7 +1053,7 @@ def pydotprint(
g.add_edge(pd.Edge(aid, varid, **param)) g.add_edge(pd.Edge(aid, varid, **param))
if high_contrast: if high_contrast:
g.add_node( g.add_node(
Node( pd.Node(
varid, varid,
style="filled", style="filled",
label=varstr, label=varstr,
...@@ -1068,7 +1063,7 @@ def pydotprint( ...@@ -1068,7 +1063,7 @@ def pydotprint(
) )
else: else:
g.add_node( g.add_node(
Node( pd.Node(
varid, varid,
color=colorCodes["Output"], color=colorCodes["Output"],
label=varstr, label=varstr,
...@@ -1080,7 +1075,7 @@ def pydotprint( ...@@ -1080,7 +1075,7 @@ def pydotprint(
# grey mean that output var isn't used # grey mean that output var isn't used
if high_contrast: if high_contrast:
g.add_node( g.add_node(
Node( pd.Node(
varid, varid,
style="filled", style="filled",
label=varstr, label=varstr,
...@@ -1089,7 +1084,9 @@ def pydotprint( ...@@ -1089,7 +1084,9 @@ def pydotprint(
) )
) )
else: else:
g.add_node(Node(varid, label=varstr, color="grey", shape=var_shape)) g.add_node(
pd.Node(varid, label=varstr, color="grey", shape=var_shape)
)
elif var.name or not compact: elif var.name or not compact:
if not (not compact): if not (not compact):
if label: if label:
...@@ -1099,7 +1096,7 @@ def pydotprint( ...@@ -1099,7 +1096,7 @@ def pydotprint(
label = label[: max_label_size - 3] + "..." label = label[: max_label_size - 3] + "..."
param["label"] = label param["label"] = label
g.add_edge(pd.Edge(aid, varid, **param)) g.add_edge(pd.Edge(aid, varid, **param))
g.add_node(Node(varid, shape=var_shape, label=varstr)) g.add_node(pd.Node(varid, shape=var_shape, label=varstr))
# else: # else:
# don't add egde here as it is already added from the inputs. # don't add egde here as it is already added from the inputs.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论