提交 29af0e5b authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5196 from Thrandis/ccw

Added checks for taps values.
...@@ -449,6 +449,16 @@ def scan(fn, ...@@ -449,6 +449,16 @@ def scan(fn,
getattr(outs_info[i]['initial'], 'name', 'None'), getattr(outs_info[i]['initial'], 'name', 'None'),
i) i)
outs_info[i]['taps'] = [-1] outs_info[i]['taps'] = [-1]
elif outs_info[i].get('taps', None) is not None:
# Check that taps are valid (< 0 and all dfferent)
taps = outs_info[i]['taps']
if len(taps) > len(set(taps)):
raise ValueError(('All the taps must be different in '
' `outputs_info`'), outs_info[i])
for t in taps:
if t >= 0:
raise ValueError(('All the tap values must be '
'smaller than 0.'), outs_info[i])
else: else:
# if a None is provided as the output info we replace it # if a None is provided as the output info we replace it
# with an empty OrdereDict() to simplify handling # with an empty OrdereDict() to simplify handling
......
...@@ -5454,3 +5454,15 @@ def test_constant_folding_n_steps(): ...@@ -5454,3 +5454,15 @@ def test_constant_folding_n_steps():
theano.function([], res)() theano.function([], res)()
finally: finally:
theano.config.on_opt_error = on_opt_error theano.config.on_opt_error = on_opt_error
def test_outputs_taps_check():
"""Checks that errors are raised with bad output_info taps."""
x = tensor.fvector('x')
y = tensor.fvector('y')
f = lambda x, y: [x]
outputs_info = {'initial': y, 'taps': [0]}
assert_raises(ValueError, theano.scan, f, x, outputs_info)
outputs_info = {'initial': y, 'taps': [-1, -1]}
assert_raises(ValueError, theano.scan, f, x, outputs_info)
print('done')
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论