提交 7b28bc29 authored 作者: Frederic's avatar Frederic

[CRASH]Fix pydotprint with scan_graphs=True when printing the variables

上级 fd695b52
......@@ -1018,7 +1018,11 @@ def pydotprint(fct, outfile=None,
else:
new_name = basename + '_' + str(idx)
new_name = os.path.join(path, new_name + ext)
pydotprint(scan_op.op.fn, new_name, compact, format, with_ids,
if hasattr(scan_op.op, 'fn'):
to_print = scan_op.op.fn
else:
to_print = scan_op.op.outputs
pydotprint(to_print, new_name, compact, format, with_ids,
high_contrast, cond_highlight, colorCodes,
max_label_size, scan_graphs)
......
......@@ -722,3 +722,24 @@ def test_scan_debugprint5():
for truth, out in zip(expected_output.split("\n"), lines):
assert truth.strip() == out.strip()
def test_printing_scan():
def f_pow2(x_tm1):
return 2 * x_tm1
state = theano.tensor.scalar('state')
n_steps = theano.tensor.iscalar('nsteps')
output, updates = theano.scan(f_pow2,
[],
state,
[],
n_steps=n_steps,
truncate_gradient=-1,
go_backwards=False)
f = theano.function([state, n_steps],
output,
updates=updates,
allow_input_downcast=True)
theano.printing.pydotprint(output, scan_graphs=True)
theano.printing.pydotprint(f, scan_graphs=True)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论