提交 e100f4a4 authored 作者: Roy Xue's avatar Roy Xue

Merge pull request #8 from nouiz/royxue-reduce_temp

Fix test and make it work
...@@ -349,26 +349,28 @@ def test_reallocation(): ...@@ -349,26 +349,28 @@ def test_reallocation():
x = tensor.scalar('x') x = tensor.scalar('x')
y = tensor.scalar('y') y = tensor.scalar('y')
z = tensor.tanh(3 * x + y) + tensor.cosh(x + 5 * y) z = tensor.tanh(3 * x + y) + tensor.cosh(x + 5 * y)
for l in ['vm_nogc', 'vm', 'vm_nogc', 'vm']:
m = theano.compile.get_mode(theano.Mode(linker='vm_nogc')) m = theano.compile.get_mode(theano.Mode(linker=l))
m = m.excluding('fusion', 'inplace') m = m.excluding('fusion', 'inplace')
f = theano.function([x, y], z, name="test_reduce_memory", f = theano.function([x, y], z, name="test_reduce_memory",
mode=m) mode=m)
output = f(1, 2) output = f(1, 2)
assert output assert output
storage_map = f.fn.storage_map storage_map = f.fn.storage_map
def check_storage(storage_map): def check_storage(storage_map):
from theano.tensor.var import TensorConstant from theano.tensor.var import TensorConstant
for i in storage_map.keys(): for i in storage_map.keys():
if not isinstance(i, TensorConstant): if not isinstance(i, TensorConstant):
keys_copy = storage_map.keys()[:] keys_copy = storage_map.keys()[:]
keys_copy.remove(i) keys_copy.remove(i)
for o in keys_copy: for o in keys_copy:
if (storage_map[i][0] and if (storage_map[i][0] and
storage_map[i][0] == storage_map[o][0]): storage_map[i][0] is storage_map[o][0]):
return [True, storage_map[o][0]] return [True, storage_map[o][0]]
return [False, None] return [False, None]
assert check_storage(storage_map)[0] assert check_storage(storage_map)[0]
assert len(set([id(v) for v in
storage_map.values()])) < len(storage_map)
...@@ -1006,7 +1006,7 @@ class VM_Linker(link.LocalLinker): ...@@ -1006,7 +1006,7 @@ class VM_Linker(link.LocalLinker):
lazy = not all([(not th.lazy) for th in thunks]) lazy = not all([(not th.lazy) for th in thunks])
if not (lazy or (config.profile and config.profile_memory) or self.use_cloop or self.callback): if not (lazy or (config.profile and config.profile_memory) or self.use_cloop or self.callback):
for pair in reallocated_info.values(): for pair in reallocated_info.values():
storage_map[pair[1]][0] = storage_map[pair[0]][0] storage_map[pair[1]] = storage_map[pair[0]]
computed, last_user = link.gc_helper(order) computed, last_user = link.gc_helper(order)
if self.allow_gc: if self.allow_gc:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论