提交 2fc8a039 authored 作者: rman@rpad's avatar rman@rpad

corrected bug with map

上级 b6b6449a
......@@ -392,6 +392,10 @@ def scan(fn, sequences=[], info_outputs=[], non_sequences=[],
# are required to have any sort of time taps
# we just need to update the number of actual outputs
n_outs = len(ls_outputs)
# other updates :
for i in xrange(n_outs):
info_outs += [ dict() ]
else:
raise ValueError('There has been a terrible mistake in our input arguments'
' and scan is totally lost. Make sure that you indicate for every '
......
......@@ -392,5 +392,13 @@ class T_Scan(unittest.TestCase):
assert compareArrays(f2(v_u), v_u+3)
def test_map(self):
from theano.scan import map as T_map
v = theano.tensor.vector()
abs_expr,abs_updates = T_map(lambda x: abs(x), [v])
abser = theano.function([v],abs_expr,updates = abs_updates)
assert compareArrays( abser(numpy.array([1.,-1])), [1.,1.])
if __name__ == '__main__':
unittest.main()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论