提交 9ccf8eda authored 作者: Frederic's avatar Frederic

Make pydotprint support Variables and apply node

Fix gh-2108
上级 77523c8e
...@@ -528,7 +528,8 @@ def pydotprint(fct, outfile=None, ...@@ -528,7 +528,8 @@ def pydotprint(fct, outfile=None,
""" """
Print to a file (png format) the graph of a compiled theano function's ops. Print to a file (png format) the graph of a compiled theano function's ops.
:param fct: the theano fct returned by theano.function. :param fct: a compiled Theano function, a Variable, an Apply or
a list of Variable.
:param outfile: the output file where to put the graph. :param outfile: the output file where to put the graph.
:param compact: if True, will remove intermediate var that don't have name. :param compact: if True, will remove intermediate var that don't have name.
:param format: the file format of the output. :param format: the file format of the output.
...@@ -597,9 +598,19 @@ def pydotprint(fct, outfile=None, ...@@ -597,9 +598,19 @@ def pydotprint(fct, outfile=None,
outputs = fct.outputs outputs = fct.outputs
topo = fct.toposort() topo = fct.toposort()
else: else:
raise ValueError(('pydotprint expects as input a theano.function or ' if isinstance(fct, gof.Variable):
'the FunctionGraph of a function!'), fct) fct = [fct]
elif isinstance(fct, gof.Apply):
fct = fct.outputs
assert isinstance(fct, (list, tuple))
assert all(isinstance(v, gof.Variable) for v in fct)
inputs = gof.graph.inputs(fct)
fct = gof.FunctionGraph(inputs=gof.graph.inputs(fct),
outputs=fct)
mode = None
profile = None
outputs = fct.outputs
topo = fct.toposort()
if not pydot_imported: if not pydot_imported:
raise RuntimeError("Failed to import pydot. You must install pydot" raise RuntimeError("Failed to import pydot. You must install pydot"
" for `pydotprint` to work.") " for `pydotprint` to work.")
......
...@@ -68,6 +68,7 @@ def test_pydotprint_variables(): ...@@ -68,6 +68,7 @@ def test_pydotprint_variables():
theano.theano_logger.removeHandler(orig_handler) theano.theano_logger.removeHandler(orig_handler)
theano.theano_logger.addHandler(new_handler) theano.theano_logger.addHandler(new_handler)
try: try:
theano.printing.pydotprint(x * 2)
theano.printing.pydotprint_variables(x * 2) theano.printing.pydotprint_variables(x * 2)
finally: finally:
theano.theano_logger.addHandler(orig_handler) theano.theano_logger.addHandler(orig_handler)
...@@ -92,14 +93,13 @@ def test_pydotprint_long_name(): ...@@ -92,14 +93,13 @@ def test_pydotprint_long_name():
f = theano.function([x], [x * 2, x + x], mode=mode) f = theano.function([x], [x * 2, x + x], mode=mode)
f([1, 2, 3, 4]) f([1, 2, 3, 4])
s = StringIO()
new_handler = logging.StreamHandler(s)
new_handler.setLevel(logging.DEBUG)
orig_handler = theano.logging_default_handler
theano.printing.pydotprint(f, max_label_size=5, theano.printing.pydotprint(f, max_label_size=5,
print_output_file=False, print_output_file=False,
assert_nb_all_strings=6) assert_nb_all_strings=6)
theano.printing.pydotprint([x * 2, x + x],
max_label_size=5,
print_output_file=False,
assert_nb_all_strings=8)
def test_pydotprint_profile(): def test_pydotprint_profile():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论