提交 2fc8a039 authored 作者: rman@rpad's avatar rman@rpad

corrected bug with map

上级 b6b6449a
...@@ -392,6 +392,10 @@ def scan(fn, sequences=[], info_outputs=[], non_sequences=[], ...@@ -392,6 +392,10 @@ def scan(fn, sequences=[], info_outputs=[], non_sequences=[],
# are required to have any sort of time taps # are required to have any sort of time taps
# we just need to update the number of actual outputs # we just need to update the number of actual outputs
n_outs = len(ls_outputs) n_outs = len(ls_outputs)
# other updates :
for i in xrange(n_outs):
info_outs += [ dict() ]
else: else:
raise ValueError('There has been a terrible mistake in our input arguments' raise ValueError('There has been a terrible mistake in our input arguments'
' and scan is totally lost. Make sure that you indicate for every ' ' and scan is totally lost. Make sure that you indicate for every '
......
...@@ -392,5 +392,13 @@ class T_Scan(unittest.TestCase): ...@@ -392,5 +392,13 @@ class T_Scan(unittest.TestCase):
assert compareArrays(f2(v_u), v_u+3) assert compareArrays(f2(v_u), v_u+3)
def test_map(self):
from theano.scan import map as T_map
v = theano.tensor.vector()
abs_expr,abs_updates = T_map(lambda x: abs(x), [v])
abser = theano.function([v],abs_expr,updates = abs_updates)
assert compareArrays( abser(numpy.array([1.,-1])), [1.,1.])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论