提交 6c89a79f authored 作者: Frederic's avatar Frederic

use pydotprint label on variable to allow multiple node with same visible string

上级 b6dc9a8a
......@@ -741,11 +741,12 @@ def pydotprint(fct, outfile=None,
right = list(right)
var_str = {}
var_id = {}
all_strings = set()
def var_name(var):
if var in var_str:
return var_str[var]
return var_str[var], var_id[var]
if var.name is not None:
if var_with_name_simple:
......@@ -768,26 +769,14 @@ def pydotprint(fct, outfile=None,
# a var id is needed as otherwise var with the same type will be
# merged in the graph.
varstr = str(var.type)
if (varstr in all_strings) or with_ids:
idx = ' id=' + str(len(var_str))
if len(varstr) + len(idx) > max_label_size:
varstr = varstr[:max_label_size - 3 - len(idx)] + idx + '...'
else:
varstr = varstr + idx
elif len(varstr) > max_label_size:
if len(varstr) > max_label_size:
varstr = varstr[:max_label_size - 3] + '...'
idx = 1
while varstr in all_strings:
idx += 1
suffix = ' id=' + str(idx)
varstr = (varstr[:max_label_size - 3 - len(suffix)] +
'...' +
suffix)
var_str[var] = varstr
var_id[var] = str(id(var))
all_strings.add(varstr)
return varstr
return varstr, var_id[var]
apply_name_cache = {}
def apply_name(node):
......@@ -879,7 +868,7 @@ def pydotprint(fct, outfile=None,
c2.add_node(nw_node)
for idx, var in enumerate(node.inputs):
varstr = var_name(var)
varstr, varid = var_name(var)
label = ""
if len(node.inputs) > 1:
label = str(idx)
......@@ -899,15 +888,19 @@ def pydotprint(fct, outfile=None,
# Mixing blue and green give cyan! (input and output var)
color = "cyan"
if high_contrast:
g.add_node(pd.Node(varstr,
g.add_node(pd.Node(varid,
style='filled',
fillcolor=color,
label=varstr,
shape=var_shape))
else:
g.add_node(pd.Node(varstr, color=color, shape=var_shape))
g.add_edge(pd.Edge(varstr, astr, **param))
g.add_node(pd.Node(varid,
color=color,
label=varstr,
shape=var_shape))
g.add_edge(pd.Edge(varid, astr, **param))
elif var.name or not compact or var in outputs:
g.add_edge(pd.Edge(varstr, astr, **param))
g.add_edge(pd.Edge(varid, astr, **param))
else:
# no name, so we don't make a var ellipse
if label:
......@@ -919,7 +912,7 @@ def pydotprint(fct, outfile=None,
g.add_edge(pd.Edge(apply_name(var.owner), astr, **param))
for idx, var in enumerate(node.outputs):
varstr = var_name(var)
varstr, varid = var_name(var)
out = var in outputs
label = ""
if len(node.outputs) > 1:
......@@ -930,20 +923,25 @@ def pydotprint(fct, outfile=None,
if label:
param['label'] = label
if out or var in input_update:
g.add_edge(pd.Edge(astr, varstr, **param))
g.add_edge(pd.Edge(astr, varid, **param))
if high_contrast:
g.add_node(pd.Node(varstr, style='filled',
g.add_node(pd.Node(varid, style='filled',
label=varstr,
fillcolor='blue', shape=var_shape))
else:
g.add_node(pd.Node(varstr, color='blue', shape=var_shape))
g.add_node(pd.Node(varid, color='blue',
label=varstr,
shape=var_shape))
elif len(var.clients) == 0:
g.add_edge(pd.Edge(astr, varstr, **param))
g.add_edge(pd.Edge(astr, varid, **param))
# grey mean that output var isn't used
if high_contrast:
g.add_node(pd.Node(varstr, style='filled',
g.add_node(pd.Node(varid, style='filled',
label=varstr,
fillcolor='grey', shape=var_shape))
else:
g.add_node(pd.Node(varstr, color='grey', shape=var_shape))
g.add_node(pd.Node(varid, label=varstr,
color='grey', shape=var_shape))
elif var.name or not compact:
if not(not compact):
if label:
......@@ -952,16 +950,16 @@ def pydotprint(fct, outfile=None,
if len(label) > max_label_size:
label = label[:max_label_size - 3] + '...'
param['label'] = label
g.add_edge(pd.Edge(astr, varstr, **param))
g.add_node(pd.Node(varstr, shape=var_shape))
g.add_edge(pd.Edge(astr, varid, **param))
g.add_node(pd.Node(varid, shape=var_shape, label=varstr))
# else:
# don't add egde here as it is already added from the inputs.
# The var that represent updates, must be linked to the input var.
for sha, up in input_update.items():
shastr = var_name(sha)
upstr = var_name(up)
g.add_edge(pd.Edge(shastr, upstr, label="UPDATE", color="blue"))
_, shaid = var_name(sha)
_, upid = var_name(up)
g.add_edge(pd.Edge(shaid, upid, label="UPDATE", color="blue"))
if cond_highlight:
g.add_subgraph(c1)
......@@ -1069,7 +1067,7 @@ def pydotprint_variables(vars,
if len(varstr) > max_label_size:
varstr = varstr[:max_label_size - 3] + '...'
var_str[var] = varstr
return varstr
return varstr, varlabel
def apply_name(node):
name = str(node.op).replace(':', '_')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论