提交 9ccf8eda authored 作者: Frederic's avatar Frederic

Make pydotprint support Variables and apply node

Fix gh-2108
上级 77523c8e
......@@ -528,7 +528,8 @@ def pydotprint(fct, outfile=None,
"""
Print to a file (png format) the graph of a compiled theano function's ops.
:param fct: the theano fct returned by theano.function.
:param fct: a compiled Theano function, a Variable, an Apply or
a list of Variable.
:param outfile: the output file where to put the graph.
:param compact: if True, will remove intermediate var that don't have name.
:param format: the file format of the output.
......@@ -597,9 +598,19 @@ def pydotprint(fct, outfile=None,
outputs = fct.outputs
topo = fct.toposort()
else:
raise ValueError(('pydotprint expects as input a theano.function or '
'the FunctionGraph of a function!'), fct)
if isinstance(fct, gof.Variable):
fct = [fct]
elif isinstance(fct, gof.Apply):
fct = fct.outputs
assert isinstance(fct, (list, tuple))
assert all(isinstance(v, gof.Variable) for v in fct)
inputs = gof.graph.inputs(fct)
fct = gof.FunctionGraph(inputs=gof.graph.inputs(fct),
outputs=fct)
mode = None
profile = None
outputs = fct.outputs
topo = fct.toposort()
if not pydot_imported:
raise RuntimeError("Failed to import pydot. You must install pydot"
" for `pydotprint` to work.")
......
......@@ -68,6 +68,7 @@ def test_pydotprint_variables():
theano.theano_logger.removeHandler(orig_handler)
theano.theano_logger.addHandler(new_handler)
try:
theano.printing.pydotprint(x * 2)
theano.printing.pydotprint_variables(x * 2)
finally:
theano.theano_logger.addHandler(orig_handler)
......@@ -92,14 +93,13 @@ def test_pydotprint_long_name():
f = theano.function([x], [x * 2, x + x], mode=mode)
f([1, 2, 3, 4])
s = StringIO()
new_handler = logging.StreamHandler(s)
new_handler.setLevel(logging.DEBUG)
orig_handler = theano.logging_default_handler
theano.printing.pydotprint(f, max_label_size=5,
print_output_file=False,
assert_nb_all_strings=6)
theano.printing.pydotprint([x * 2, x + x],
max_label_size=5,
print_output_file=False,
assert_nb_all_strings=8)
def test_pydotprint_profile():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论