提交 b7cbe8fb authored 作者: Frederic Bastien's avatar Frederic Bastien

Added an option in pydotprint* to put less information in the graph when we have…

Added an option in pydotprint* to put less information in the graph when we have gived a name to a variable.
上级 a102d8e7
...@@ -392,8 +392,10 @@ default_colorCodes = {'GpuFromHost' : 'red', ...@@ -392,8 +392,10 @@ default_colorCodes = {'GpuFromHost' : 'red',
def pydotprint(fct, outfile=None, def pydotprint(fct, outfile=None,
compact=True, format='png', with_ids=False, compact=True, format='png', with_ids=False,
high_contrast=True, cond_highlight = None, colorCodes = None, high_contrast=True, cond_highlight=None, colorCodes=None,
max_label_size=50, scan_graphs = False): max_label_size=50, scan_graphs=False,
var_with_name_simple=False
):
""" """
print to a file in png format the graph of op of a compile theano fct. print to a file in png format the graph of op of a compile theano fct.
...@@ -493,6 +495,9 @@ def pydotprint(fct, outfile=None, ...@@ -493,6 +495,9 @@ def pydotprint(fct, outfile=None,
return var_str[var] return var_str[var]
if var.name is not None: if var.name is not None:
if var_with_name_simple:
varstr = var.name
else:
varstr = 'name='+var.name+" "+str(var.type) varstr = 'name='+var.name+" "+str(var.type)
elif isinstance(var,gof.Constant): elif isinstance(var,gof.Constant):
dstr = 'val='+str(numpy.asarray(var.data)) dstr = 'val='+str(numpy.asarray(var.data))
...@@ -500,6 +505,9 @@ def pydotprint(fct, outfile=None, ...@@ -500,6 +505,9 @@ def pydotprint(fct, outfile=None,
dstr = dstr[:dstr.index('\n')] dstr = dstr[:dstr.index('\n')]
varstr = '%s [%s]'% (dstr, str(var.type)) varstr = '%s [%s]'% (dstr, str(var.type))
elif var in input_update and input_update[var].variable.name is not None: elif var in input_update and input_update[var].variable.name is not None:
if var_with_name_simple:
varstr = input_update[var].variable.name
else:
varstr = input_update[var].variable.name+" "+str(var.type) varstr = input_update[var].variable.name+" "+str(var.type)
else: else:
#a var id is needed as otherwise var with the same type will be merged in the graph. #a var id is needed as otherwise var with the same type will be merged in the graph.
...@@ -667,7 +675,8 @@ def pydotprint_variables(vars, ...@@ -667,7 +675,8 @@ def pydotprint_variables(vars,
format='png', format='png',
depth=-1, depth=-1,
high_contrast=True, colorCodes=None, high_contrast=True, colorCodes=None,
max_label_size=50): max_label_size=50,
var_with_name_simple=False):
''' Identical to pydotprint just that it starts from a variable instead ''' Identical to pydotprint just that it starts from a variable instead
of a compiled function. Could be useful ? ''' of a compiled function. Could be useful ? '''
...@@ -692,7 +701,10 @@ def pydotprint_variables(vars, ...@@ -692,7 +701,10 @@ def pydotprint_variables(vars,
return var_str[var] return var_str[var]
if var.name is not None: if var.name is not None:
varstr = 'name='+var.name if var_with_name_simple:
varstr = var.name
else:
varstr = 'name='+var.name+" "+str(var.type)
elif isinstance(var,gof.Constant): elif isinstance(var,gof.Constant):
dstr = 'val='+str(var.data) dstr = 'val='+str(var.data)
if '\n' in dstr: if '\n' in dstr:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论