提交 6c89a79f authored 作者: Frederic's avatar Frederic

use pydotprint label on variable to allow multiple node with same visible string

上级 b6dc9a8a
...@@ -741,11 +741,12 @@ def pydotprint(fct, outfile=None, ...@@ -741,11 +741,12 @@ def pydotprint(fct, outfile=None,
right = list(right) right = list(right)
var_str = {} var_str = {}
var_id = {}
all_strings = set() all_strings = set()
def var_name(var): def var_name(var):
if var in var_str: if var in var_str:
return var_str[var] return var_str[var], var_id[var]
if var.name is not None: if var.name is not None:
if var_with_name_simple: if var_with_name_simple:
...@@ -768,26 +769,14 @@ def pydotprint(fct, outfile=None, ...@@ -768,26 +769,14 @@ def pydotprint(fct, outfile=None,
# a var id is needed as otherwise var with the same type will be # a var id is needed as otherwise var with the same type will be
# merged in the graph. # merged in the graph.
varstr = str(var.type) varstr = str(var.type)
if (varstr in all_strings) or with_ids: if len(varstr) > max_label_size:
idx = ' id=' + str(len(var_str))
if len(varstr) + len(idx) > max_label_size:
varstr = varstr[:max_label_size - 3 - len(idx)] + idx + '...'
else:
varstr = varstr + idx
elif len(varstr) > max_label_size:
varstr = varstr[:max_label_size - 3] + '...' varstr = varstr[:max_label_size - 3] + '...'
idx = 1
while varstr in all_strings:
idx += 1
suffix = ' id=' + str(idx)
varstr = (varstr[:max_label_size - 3 - len(suffix)] +
'...' +
suffix)
var_str[var] = varstr var_str[var] = varstr
var_id[var] = str(id(var))
all_strings.add(varstr) all_strings.add(varstr)
return varstr return varstr, var_id[var]
apply_name_cache = {} apply_name_cache = {}
def apply_name(node): def apply_name(node):
...@@ -879,7 +868,7 @@ def pydotprint(fct, outfile=None, ...@@ -879,7 +868,7 @@ def pydotprint(fct, outfile=None,
c2.add_node(nw_node) c2.add_node(nw_node)
for idx, var in enumerate(node.inputs): for idx, var in enumerate(node.inputs):
varstr = var_name(var) varstr, varid = var_name(var)
label = "" label = ""
if len(node.inputs) > 1: if len(node.inputs) > 1:
label = str(idx) label = str(idx)
...@@ -899,15 +888,19 @@ def pydotprint(fct, outfile=None, ...@@ -899,15 +888,19 @@ def pydotprint(fct, outfile=None,
# Mixing blue and green give cyan! (input and output var) # Mixing blue and green give cyan! (input and output var)
color = "cyan" color = "cyan"
if high_contrast: if high_contrast:
g.add_node(pd.Node(varstr, g.add_node(pd.Node(varid,
style='filled', style='filled',
fillcolor=color, fillcolor=color,
label=varstr,
shape=var_shape)) shape=var_shape))
else: else:
g.add_node(pd.Node(varstr, color=color, shape=var_shape)) g.add_node(pd.Node(varid,
g.add_edge(pd.Edge(varstr, astr, **param)) color=color,
label=varstr,
shape=var_shape))
g.add_edge(pd.Edge(varid, astr, **param))
elif var.name or not compact or var in outputs: elif var.name or not compact or var in outputs:
g.add_edge(pd.Edge(varstr, astr, **param)) g.add_edge(pd.Edge(varid, astr, **param))
else: else:
# no name, so we don't make a var ellipse # no name, so we don't make a var ellipse
if label: if label:
...@@ -919,7 +912,7 @@ def pydotprint(fct, outfile=None, ...@@ -919,7 +912,7 @@ def pydotprint(fct, outfile=None,
g.add_edge(pd.Edge(apply_name(var.owner), astr, **param)) g.add_edge(pd.Edge(apply_name(var.owner), astr, **param))
for idx, var in enumerate(node.outputs): for idx, var in enumerate(node.outputs):
varstr = var_name(var) varstr, varid = var_name(var)
out = var in outputs out = var in outputs
label = "" label = ""
if len(node.outputs) > 1: if len(node.outputs) > 1:
...@@ -930,20 +923,25 @@ def pydotprint(fct, outfile=None, ...@@ -930,20 +923,25 @@ def pydotprint(fct, outfile=None,
if label: if label:
param['label'] = label param['label'] = label
if out or var in input_update: if out or var in input_update:
g.add_edge(pd.Edge(astr, varstr, **param)) g.add_edge(pd.Edge(astr, varid, **param))
if high_contrast: if high_contrast:
g.add_node(pd.Node(varstr, style='filled', g.add_node(pd.Node(varid, style='filled',
label=varstr,
fillcolor='blue', shape=var_shape)) fillcolor='blue', shape=var_shape))
else: else:
g.add_node(pd.Node(varstr, color='blue', shape=var_shape)) g.add_node(pd.Node(varid, color='blue',
label=varstr,
shape=var_shape))
elif len(var.clients) == 0: elif len(var.clients) == 0:
g.add_edge(pd.Edge(astr, varstr, **param)) g.add_edge(pd.Edge(astr, varid, **param))
# grey mean that output var isn't used # grey mean that output var isn't used
if high_contrast: if high_contrast:
g.add_node(pd.Node(varstr, style='filled', g.add_node(pd.Node(varid, style='filled',
label=varstr,
fillcolor='grey', shape=var_shape)) fillcolor='grey', shape=var_shape))
else: else:
g.add_node(pd.Node(varstr, color='grey', shape=var_shape)) g.add_node(pd.Node(varid, label=varstr,
color='grey', shape=var_shape))
elif var.name or not compact: elif var.name or not compact:
if not(not compact): if not(not compact):
if label: if label:
...@@ -952,16 +950,16 @@ def pydotprint(fct, outfile=None, ...@@ -952,16 +950,16 @@ def pydotprint(fct, outfile=None,
if len(label) > max_label_size: if len(label) > max_label_size:
label = label[:max_label_size - 3] + '...' label = label[:max_label_size - 3] + '...'
param['label'] = label param['label'] = label
g.add_edge(pd.Edge(astr, varstr, **param)) g.add_edge(pd.Edge(astr, varid, **param))
g.add_node(pd.Node(varstr, shape=var_shape)) g.add_node(pd.Node(varid, shape=var_shape, label=varstr))
# else: # else:
# don't add egde here as it is already added from the inputs. # don't add egde here as it is already added from the inputs.
# The var that represent updates, must be linked to the input var. # The var that represent updates, must be linked to the input var.
for sha, up in input_update.items(): for sha, up in input_update.items():
shastr = var_name(sha) _, shaid = var_name(sha)
upstr = var_name(up) _, upid = var_name(up)
g.add_edge(pd.Edge(shastr, upstr, label="UPDATE", color="blue")) g.add_edge(pd.Edge(shaid, upid, label="UPDATE", color="blue"))
if cond_highlight: if cond_highlight:
g.add_subgraph(c1) g.add_subgraph(c1)
...@@ -1069,7 +1067,7 @@ def pydotprint_variables(vars, ...@@ -1069,7 +1067,7 @@ def pydotprint_variables(vars,
if len(varstr) > max_label_size: if len(varstr) > max_label_size:
varstr = varstr[:max_label_size - 3] + '...' varstr = varstr[:max_label_size - 3] + '...'
var_str[var] = varstr var_str[var] = varstr
return varstr return varstr, varlabel
def apply_name(node): def apply_name(node):
name = str(node.op).replace(':', '_') name = str(node.op).replace(':', '_')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论