提交 255792af authored 作者: Lijun Xue's avatar Lijun Xue

update test file

上级 ec6483da
...@@ -14,22 +14,13 @@ class Test_reallocation(unittest.TestCase): ...@@ -14,22 +14,13 @@ class Test_reallocation(unittest.TestCase):
def test_reallocation(self): def test_reallocation(self):
pre_config = theano.config.allow_gc
try:
theano.config.allow_gc = False
x = T.scalar('x') x = T.scalar('x')
y = T.scalar('y') y = T.scalar('y')
z = T.tanh(x + y) + T.cosh(x + y) z = T.tanh(3*x + y) + T.cosh(x + 5*y)
if theano.config.mode in ["DebugMode", "DEBUG_MODE", "FAST_COMPILE"]: m = theano.compile.get_mode(theano.Mode(linker='vm_nogc'))
m = "FAST_RUN" m = m.excluding('fusion', 'inplace')
else:
m = None
m = theano.compile.get_mode(m).excluding('fusion', 'inplace')
f = theano.function([x, y], z, name="test_reduce_memory", f = theano.function([x, y], z, name="test_reduce_memory",
mode=m) mode=m)
...@@ -38,18 +29,17 @@ class Test_reallocation(unittest.TestCase): ...@@ -38,18 +29,17 @@ class Test_reallocation(unittest.TestCase):
storage_map = f.fn.storage_map storage_map = f.fn.storage_map
def check_storage(storage_map): def check_storage(storage_map):
from theano.tensor.var import TensorConstant
for i in storage_map.keys(): for i in storage_map.keys():
if not isinstance(i, TensorConstant):
keys_copy = storage_map.keys()[:] keys_copy = storage_map.keys()[:]
keys_copy.remove(i) keys_copy.remove(i)
for o in keys_copy: for o in keys_copy:
if storage_map[i][0] == storage_map[o][0]: if storage_map[i][0] and storage_map[i][0] == storage_map[o][0]:
return True return [True, storage_map[o][0]]
return False return [False, None]
assert check_storage(storage_map)
finally: assert check_storage(storage_map)[0]
theano.config.allow_gc = pre_config
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论