提交 86082148 authored 作者: --global's avatar --global

Update memory consumption tests to reflect changes in SaveMem

上级 3da5407b
......@@ -2953,7 +2953,14 @@ class T_Scan(unittest.TestCase):
assert scan_nodes is not None
scan_node = scan_nodes[0]
f1 = theano.function(inputs, scan_node.inputs[2])
assert f1().shape[0] == 1
# Originally, the shape would have been 1 due to the SaveMem
# optimization reducing the size to the number of taps (in this case
# 1) provided to the inner function. Now, because of the memory-reuse
# feature in Scan it should be 2 because SaveMem needs to keep a
# larger buffer to avoid aliasing between the inputs and the outputs.
assert f1().shape[0] == 2
gx = theano.tensor.grad(o, x)
f2 = theano.function([], gx)
utt.assert_allclose(f2(), numpy.ones((10,)))
......@@ -2976,7 +2983,14 @@ class T_Scan(unittest.TestCase):
assert scan_nodes is not None
scan_node = scan_nodes[0]
f1 = theano.function(inputs, scan_node.inputs[2])
assert f1().shape[0] == 1
# Originally, the shape would have been 1 due to the SaveMem
# optimization reducing the size to the number of taps (in this case
# 1) provided to the inner function. Now, because of the memory-reuse
# feature in Scan it should be 2 because SaveMem needs to keep a
# larger buffer to avoid aliasing between the inputs and the outputs.
assert f1().shape[0] == 2
gx = theano.tensor.grad(o, x)
f2 = theano.function([], gx)
utt.assert_allclose(f2(), numpy.ones((10,)))
......@@ -3000,7 +3014,14 @@ class T_Scan(unittest.TestCase):
assert scan_nodes is not None
scan_node = scan_nodes[0]
f1 = theano.function(inputs, scan_node.inputs[2])
assert f1().shape[0] == 1
# Originally, the shape would have been 1 due to the SaveMem
# optimization reducing the size to the number of taps (in this case
# 1) provided to the inner function. Now, because of the memory-reuse
# feature in Scan it should be 2 because SaveMem needs to keep a
# larger buffer to avoid aliasing between the inputs and the outputs.
assert f1().shape[0] == 2
gx = theano.tensor.grad(o, x)
f2 = theano.function([], gx)
utt.assert_allclose(f2(), numpy.ones((10,)))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论