提交 f7e91849 authored 作者: Frederic Bastien's avatar Frederic Bastien

put a max size of the label in pydotprint*

上级 b2641d3b
......@@ -391,8 +391,9 @@ default_colorCodes = {'GpuFromHost' : 'red',
def pydotprint(fct, outfile=None,
compact=True, format='png', with_ids=False,
high_contrast=False, cond_highlight = None, colorCodes = None):
compact=True, format='png', with_ids=False,
high_contrast=False, cond_highlight = None, colorCodes = None,
max_label_size=50):
"""
print to a file in png format the graph of op of a compile theano fct.
......@@ -491,8 +492,6 @@ def pydotprint(fct, outfile=None,
dstr = 'val='+str(numpy.asarray(var.data))
if '\n' in dstr:
dstr = dstr[:dstr.index('\n')]
if len(dstr) > 30:
dstr = dstr[:27]+'...'
varstr = '%s [%s]'% (dstr, str(var.type))
elif var in input_update and input_update[var].variable.name is not None:
varstr = input_update[var].variable.name+" "+str(var.type)
......@@ -501,6 +500,8 @@ def pydotprint(fct, outfile=None,
varstr = str(var.type)
if (varstr in all_strings) or with_ids:
varstr += ' id=' + str(len(var_str))
if len(varstr) > max_label_size:
varstr = varstr[:max_label_size-3]+'...'
var_str[var]=varstr
all_strings.add(varstr)
......@@ -522,6 +523,8 @@ def pydotprint(fct, outfile=None,
else: pf = time*100/mode.fct_call_time[fct]
prof_str=' (%.3fs,%.3f%%,%.3f%%)'%(time,pt,pf)
applystr = str(node.op).replace(':','_')
if len(applystr)>max_label_size:
applystr = applystr[:max_label_size-3]+'...'
if (applystr in all_strings) or with_ids:
applystr = applystr+' id='+str(topo.index(node))
applystr += prof_str
......@@ -567,6 +570,8 @@ def pydotprint(fct, outfile=None,
for id,var in enumerate(node.inputs):
varstr=var_name(var)
label=str(var.type)
if len(label)>max_label_size:
label = label[:max_label_size-3]+'...'
if len(node.inputs)>1:
label=str(id)+' '+label
if var.owner is None:
......@@ -590,6 +595,8 @@ def pydotprint(fct, outfile=None,
label=str(var.type)
if len(node.outputs)>1:
label=str(id)+' '+label
if len(label)>max_label_size:
label = label[:max_label_size-3]+'...'
if out:
g.add_edge(pd.Edge(astr, varstr, label=label))
if high_contrast:
......@@ -627,7 +634,8 @@ def pydotprint_variables(vars,
outfile=None,
format='png',
depth = -1,
high_contrast = True, colorCodes = None):
high_contrast = True, colorCodes = None,
max_label_size=50):
''' Identical to pydotprint just that it starts from a variable instead
of a compiled function. Could be useful ? '''
......@@ -657,18 +665,21 @@ def pydotprint_variables(vars,
dstr = 'val='+str(var.data)
if '\n' in dstr:
dstr = dstr[:dstr.index('\n')]
if len(dstr) > 30:
dstr = dstr[:27]+'...'
varstr = '%s [%s]'% (dstr, str(var.type))
else:
#a var id is needed as otherwise var with the same type will be merged in the graph.
varstr = str(var.type)
if len(dstr) > max_label_size:
dstr = dstr[:max_label_size-1]+'...'
varstr += ' ' + str(len(var_str))
var_str[var]=varstr
return varstr
def apply_name(node):
return str(node.op).replace(':','_')
name = str(node.op).replace(':','_')
if len(name) > max_label_size:
name = name[:max_label_size-3]+'...'
return name
def plot_apply(app, d):
if d == 0:
......@@ -676,6 +687,8 @@ def pydotprint_variables(vars,
if app in my_list:
return
astr = apply_name(app) + '_' + str(len(my_list.keys()))
if len(astr) > max_label_size:
astr = astr[:max_label_size-3]+'...'
my_list[app] = astr
use_color = None
......@@ -695,6 +708,8 @@ def pydotprint_variables(vars,
for i,nd in enumerate(app.inputs):
if nd not in my_list:
varastr = var_name(nd) + '_' + str(len(my_list.keys()))
if len(varastr) > max_label_size:
varastr = varastr[:max_label_size-3]+'...'
my_list[nd] = varastr
if nd.owner is not None:
g.add_node(pd.Node(varastr))
......@@ -713,6 +728,8 @@ def pydotprint_variables(vars,
for i,nd in enumerate(app.outputs):
if nd not in my_list:
varastr = var_name(nd) + '_' + str(len(my_list.keys()))
if len(varastr) > max_label_size:
varastr = varastr[:max_label_size-3]+'...'
my_list[nd] = varastr
color = None
if nd in vars:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论