提交 9ad495f1 authored 作者: Frederic Bastien's avatar Frederic Bastien

small update to the new pydotprint scan_graph option.

上级 00cd33ba
......@@ -401,6 +401,8 @@ def pydotprint(fct, outfile=None,
:param outfile: the output file where to put the graph.
:param compact: if True, will remove intermediate var that don't have name.
:param format: the file format of the output.
:param with_ids: Print the toposort index of the node in the node name.
and an index number in the variable ellipse.
:param high_contrast: if true, the color that describes the respective
node is filled with its corresponding color, instead of coloring
the border
......@@ -415,7 +417,8 @@ def pydotprint(fct, outfile=None,
:param scan_graphs: if true it will plot the inner graph of each scan op
in files with the same name as the name given for the main
file to which the name of the scan op is concatenated and
some unique id
the index in the toposort of the scan.
This index can be printed in the graph with the option with_ids.
In the graph, box are an Apply Node(the execution of an op) and ellipse are variable.
If variable have name they are used as the text(if multiple var have the same name, they will be merged in the graph).
......@@ -630,15 +633,15 @@ def pydotprint(fct, outfile=None,
print 'The output file is available at',outfile
if scan_graphs:
scan_ops = [x for x in fct_env.toposort() if x.op.__class__.__name__ == 'Scan']
scan_ops = [(idx, x) for idx,x in enumerate(fct_env.toposort()) if isinstance(x.op, theano.scan_module.scan_op.Scan)]
path, fn = os.path.split(outfile)
basename = fn.split('.')[0]
basename = '.'.join(fn.split('.')[:-1])
# Safe way of doing things .. a file name may contain multiple .
ext = fn[len(basename):]
for idx, scan_op in enumerate(scan_ops):
# is there a change that name is not defined?
for idx, scan_op in scan_ops:
# is there a chance that name is not defined?
if hasattr(scan_op.op,'name'):
new_name = basename+'_'+scan_op.op.name+'_'+str(idx)
else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论