提交 16f98e4f authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Change aesara.graph.basic.Node to pydot.Node in aesara.printing

上级 4fa5269e
......@@ -17,14 +17,7 @@ import numpy as np
from aesara.compile import Function, SharedVariable, debugmode
from aesara.compile.io import In, Out
from aesara.configdefaults import config
from aesara.graph.basic import (
Apply,
Constant,
Node,
Variable,
graph_inputs,
io_toposort,
)
from aesara.graph.basic import Apply, Constant, Variable, graph_inputs, io_toposort
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op
from aesara.graph.utils import Scratchpad
......@@ -980,13 +973,13 @@ def pydotprint(
use_color = color
if use_color is None:
nw_node = Node(aid, label=astr, shape=apply_shape)
nw_node = pd.Node(aid, label=astr, shape=apply_shape)
elif high_contrast:
nw_node = Node(
nw_node = pd.Node(
aid, label=astr, style="filled", fillcolor=use_color, shape=apply_shape
)
else:
nw_node = Node(aid, label=astr, color=use_color, shape=apply_shape)
nw_node = pd.Node(aid, label=astr, color=use_color, shape=apply_shape)
g.add_node(nw_node)
if cond_highlight:
if node in middle:
......@@ -1020,7 +1013,7 @@ def pydotprint(
color = "cyan"
if high_contrast:
g.add_node(
Node(
pd.Node(
varid,
style="filled",
fillcolor=color,
......@@ -1029,7 +1022,9 @@ def pydotprint(
)
)
else:
g.add_node(Node(varid, color=color, label=varstr, shape=var_shape))
g.add_node(
pd.Node(varid, color=color, label=varstr, shape=var_shape)
)
g.add_edge(pd.Edge(varid, aid, **param))
elif var.name or not compact or var in outputs:
g.add_edge(pd.Edge(varid, aid, **param))
......@@ -1058,7 +1053,7 @@ def pydotprint(
g.add_edge(pd.Edge(aid, varid, **param))
if high_contrast:
g.add_node(
Node(
pd.Node(
varid,
style="filled",
label=varstr,
......@@ -1068,7 +1063,7 @@ def pydotprint(
)
else:
g.add_node(
Node(
pd.Node(
varid,
color=colorCodes["Output"],
label=varstr,
......@@ -1080,7 +1075,7 @@ def pydotprint(
# grey mean that output var isn't used
if high_contrast:
g.add_node(
Node(
pd.Node(
varid,
style="filled",
label=varstr,
......@@ -1089,7 +1084,9 @@ def pydotprint(
)
)
else:
g.add_node(Node(varid, label=varstr, color="grey", shape=var_shape))
g.add_node(
pd.Node(varid, label=varstr, color="grey", shape=var_shape)
)
elif var.name or not compact:
if not (not compact):
if label:
......@@ -1099,7 +1096,7 @@ def pydotprint(
label = label[: max_label_size - 3] + "..."
param["label"] = label
g.add_edge(pd.Edge(aid, varid, **param))
g.add_node(Node(varid, shape=var_shape, label=varstr))
g.add_node(pd.Node(varid, shape=var_shape, label=varstr))
# else:
# don't add egde here as it is already added from the inputs.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论