提交 c6875ba4 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Merge pull request #869 from pascanur/fix_sandbox_scan

fix to annoying error message of sandbox/scan
...@@ -262,7 +262,9 @@ def scan(fn, ...@@ -262,7 +262,9 @@ def scan(fn,
# makes code much cleaner for those who do not use taps. Otherwise # makes code much cleaner for those who do not use taps. Otherwise
# they would always had to shape_padleft the initial state .. # they would always had to shape_padleft the initial state ..
# which is ugly # which is ugly
if init_out['taps'] == [-1]:
# Note, 'taps' might not be in the dictionary
if 'taps' in init_out and init_out['taps'] == [-1]:
actual_arg = init_out['membuf'] actual_arg = init_out['membuf']
arg = safe_new(init_out['membuf'][0]) arg = safe_new(init_out['membuf'][0])
......
...@@ -62,10 +62,22 @@ def test_004(): ...@@ -62,10 +62,22 @@ def test_004():
val_sq = numpy.float32([1,2,3,4,5]) val_sq = numpy.float32([1,2,3,4,5])
assert numpy.all(fn(val_sq, 5) == val_sq +1) assert numpy.all(fn(val_sq, 5) == val_sq +1)
def test_005():
sq = theano.tensor.fvector('sq')
nst = theano.tensor.iscalar('nst')
out, _ = scan.scan(lambda s: s+numpy.float32(1),
sequences=sq,
states = [None],
n_steps = nst)
fn = theano.function([sq, nst], out)
val_sq = numpy.float32([1,2,3,4,5])
assert numpy.all(fn(val_sq, 5) == val_sq +1)
if __name__=='__main__': if __name__=='__main__':
test_001() test_001()
test_002() test_002()
test_003() test_003()
test_004() test_004()
test_005()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论