提交 f7e91849 authored 作者: Frederic Bastien's avatar Frederic Bastien

put a max size of the label in pydotprint*

上级 b2641d3b
...@@ -391,8 +391,9 @@ default_colorCodes = {'GpuFromHost' : 'red', ...@@ -391,8 +391,9 @@ default_colorCodes = {'GpuFromHost' : 'red',
def pydotprint(fct, outfile=None, def pydotprint(fct, outfile=None,
compact=True, format='png', with_ids=False, compact=True, format='png', with_ids=False,
high_contrast=False, cond_highlight = None, colorCodes = None): high_contrast=False, cond_highlight = None, colorCodes = None,
max_label_size=50):
""" """
print to a file in png format the graph of op of a compile theano fct. print to a file in png format the graph of op of a compile theano fct.
...@@ -491,8 +492,6 @@ def pydotprint(fct, outfile=None, ...@@ -491,8 +492,6 @@ def pydotprint(fct, outfile=None,
dstr = 'val='+str(numpy.asarray(var.data)) dstr = 'val='+str(numpy.asarray(var.data))
if '\n' in dstr: if '\n' in dstr:
dstr = dstr[:dstr.index('\n')] dstr = dstr[:dstr.index('\n')]
if len(dstr) > 30:
dstr = dstr[:27]+'...'
varstr = '%s [%s]'% (dstr, str(var.type)) varstr = '%s [%s]'% (dstr, str(var.type))
elif var in input_update and input_update[var].variable.name is not None: elif var in input_update and input_update[var].variable.name is not None:
varstr = input_update[var].variable.name+" "+str(var.type) varstr = input_update[var].variable.name+" "+str(var.type)
...@@ -501,6 +500,8 @@ def pydotprint(fct, outfile=None, ...@@ -501,6 +500,8 @@ def pydotprint(fct, outfile=None,
varstr = str(var.type) varstr = str(var.type)
if (varstr in all_strings) or with_ids: if (varstr in all_strings) or with_ids:
varstr += ' id=' + str(len(var_str)) varstr += ' id=' + str(len(var_str))
if len(varstr) > max_label_size:
varstr = varstr[:max_label_size-3]+'...'
var_str[var]=varstr var_str[var]=varstr
all_strings.add(varstr) all_strings.add(varstr)
...@@ -522,6 +523,8 @@ def pydotprint(fct, outfile=None, ...@@ -522,6 +523,8 @@ def pydotprint(fct, outfile=None,
else: pf = time*100/mode.fct_call_time[fct] else: pf = time*100/mode.fct_call_time[fct]
prof_str=' (%.3fs,%.3f%%,%.3f%%)'%(time,pt,pf) prof_str=' (%.3fs,%.3f%%,%.3f%%)'%(time,pt,pf)
applystr = str(node.op).replace(':','_') applystr = str(node.op).replace(':','_')
if len(applystr)>max_label_size:
applystr = applystr[:max_label_size-3]+'...'
if (applystr in all_strings) or with_ids: if (applystr in all_strings) or with_ids:
applystr = applystr+' id='+str(topo.index(node)) applystr = applystr+' id='+str(topo.index(node))
applystr += prof_str applystr += prof_str
...@@ -567,6 +570,8 @@ def pydotprint(fct, outfile=None, ...@@ -567,6 +570,8 @@ def pydotprint(fct, outfile=None,
for id,var in enumerate(node.inputs): for id,var in enumerate(node.inputs):
varstr=var_name(var) varstr=var_name(var)
label=str(var.type) label=str(var.type)
if len(label)>max_label_size:
label = label[:max_label_size-3]+'...'
if len(node.inputs)>1: if len(node.inputs)>1:
label=str(id)+' '+label label=str(id)+' '+label
if var.owner is None: if var.owner is None:
...@@ -590,6 +595,8 @@ def pydotprint(fct, outfile=None, ...@@ -590,6 +595,8 @@ def pydotprint(fct, outfile=None,
label=str(var.type) label=str(var.type)
if len(node.outputs)>1: if len(node.outputs)>1:
label=str(id)+' '+label label=str(id)+' '+label
if len(label)>max_label_size:
label = label[:max_label_size-3]+'...'
if out: if out:
g.add_edge(pd.Edge(astr, varstr, label=label)) g.add_edge(pd.Edge(astr, varstr, label=label))
if high_contrast: if high_contrast:
...@@ -627,7 +634,8 @@ def pydotprint_variables(vars, ...@@ -627,7 +634,8 @@ def pydotprint_variables(vars,
outfile=None, outfile=None,
format='png', format='png',
depth = -1, depth = -1,
high_contrast = True, colorCodes = None): high_contrast = True, colorCodes = None,
max_label_size=50):
''' Identical to pydotprint just that it starts from a variable instead ''' Identical to pydotprint just that it starts from a variable instead
of a compiled function. Could be useful ? ''' of a compiled function. Could be useful ? '''
...@@ -657,18 +665,21 @@ def pydotprint_variables(vars, ...@@ -657,18 +665,21 @@ def pydotprint_variables(vars,
dstr = 'val='+str(var.data) dstr = 'val='+str(var.data)
if '\n' in dstr: if '\n' in dstr:
dstr = dstr[:dstr.index('\n')] dstr = dstr[:dstr.index('\n')]
if len(dstr) > 30:
dstr = dstr[:27]+'...'
varstr = '%s [%s]'% (dstr, str(var.type)) varstr = '%s [%s]'% (dstr, str(var.type))
else: else:
#a var id is needed as otherwise var with the same type will be merged in the graph. #a var id is needed as otherwise var with the same type will be merged in the graph.
varstr = str(var.type) varstr = str(var.type)
if len(dstr) > max_label_size:
dstr = dstr[:max_label_size-1]+'...'
varstr += ' ' + str(len(var_str)) varstr += ' ' + str(len(var_str))
var_str[var]=varstr var_str[var]=varstr
return varstr return varstr
def apply_name(node): def apply_name(node):
return str(node.op).replace(':','_') name = str(node.op).replace(':','_')
if len(name) > max_label_size:
name = name[:max_label_size-3]+'...'
return name
def plot_apply(app, d): def plot_apply(app, d):
if d == 0: if d == 0:
...@@ -676,6 +687,8 @@ def pydotprint_variables(vars, ...@@ -676,6 +687,8 @@ def pydotprint_variables(vars,
if app in my_list: if app in my_list:
return return
astr = apply_name(app) + '_' + str(len(my_list.keys())) astr = apply_name(app) + '_' + str(len(my_list.keys()))
if len(astr) > max_label_size:
astr = astr[:max_label_size-3]+'...'
my_list[app] = astr my_list[app] = astr
use_color = None use_color = None
...@@ -695,6 +708,8 @@ def pydotprint_variables(vars, ...@@ -695,6 +708,8 @@ def pydotprint_variables(vars,
for i,nd in enumerate(app.inputs): for i,nd in enumerate(app.inputs):
if nd not in my_list: if nd not in my_list:
varastr = var_name(nd) + '_' + str(len(my_list.keys())) varastr = var_name(nd) + '_' + str(len(my_list.keys()))
if len(varastr) > max_label_size:
varastr = varastr[:max_label_size-3]+'...'
my_list[nd] = varastr my_list[nd] = varastr
if nd.owner is not None: if nd.owner is not None:
g.add_node(pd.Node(varastr)) g.add_node(pd.Node(varastr))
...@@ -713,6 +728,8 @@ def pydotprint_variables(vars, ...@@ -713,6 +728,8 @@ def pydotprint_variables(vars,
for i,nd in enumerate(app.outputs): for i,nd in enumerate(app.outputs):
if nd not in my_list: if nd not in my_list:
varastr = var_name(nd) + '_' + str(len(my_list.keys())) varastr = var_name(nd) + '_' + str(len(my_list.keys()))
if len(varastr) > max_label_size:
varastr = varastr[:max_label_size-3]+'...'
my_list[nd] = varastr my_list[nd] = varastr
color = None color = None
if nd in vars: if nd in vars:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论