提交 255792af authored 作者: Lijun Xue's avatar Lijun Xue

update test file

上级 ec6483da
...@@ -14,42 +14,32 @@ class Test_reallocation(unittest.TestCase): ...@@ -14,42 +14,32 @@ class Test_reallocation(unittest.TestCase):
def test_reallocation(self): def test_reallocation(self):
pre_config = theano.config.allow_gc x = T.scalar('x')
y = T.scalar('y')
try: z = T.tanh(3*x + y) + T.cosh(x + 5*y)
theano.config.allow_gc = False
x = T.scalar('x') m = theano.compile.get_mode(theano.Mode(linker='vm_nogc'))
y = T.scalar('y') m = m.excluding('fusion', 'inplace')
z = T.tanh(x + y) + T.cosh(x + y) f = theano.function([x, y], z, name="test_reduce_memory",
mode=m)
if theano.config.mode in ["DebugMode", "DEBUG_MODE", "FAST_COMPILE"]: output = f(1, 2)
m = "FAST_RUN" storage_map = f.fn.storage_map
else:
m = None
m = theano.compile.get_mode(m).excluding('fusion', 'inplace') def check_storage(storage_map):
from theano.tensor.var import TensorConstant
f = theano.function([x, y], z, name="test_reduce_memory", for i in storage_map.keys():
mode=m) if not isinstance(i, TensorConstant):
output = f(1, 2)
storage_map = f.fn.storage_map
def check_storage(storage_map):
for i in storage_map.keys():
keys_copy = storage_map.keys()[:] keys_copy = storage_map.keys()[:]
keys_copy.remove(i) keys_copy.remove(i)
for o in keys_copy: for o in keys_copy:
if storage_map[i][0] == storage_map[o][0]: if storage_map[i][0] and storage_map[i][0] == storage_map[o][0]:
return True return [True, storage_map[o][0]]
return False return [False, None]
assert check_storage(storage_map)
finally: assert check_storage(storage_map)[0]
theano.config.allow_gc = pre_config
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论