提交 0563dc8f authored 作者: ChienliMa's avatar ChienliMa

Add test for VM_Linker; Add test for SharedVariable with/without updates;…

Add test for VM_Linker; Add test for SharedVariable with/without updates; Simplify codes and delete extra line.
上级 55815f17
...@@ -1826,7 +1826,6 @@ class _Linker(gof.link.LocalLinker): ...@@ -1826,7 +1826,6 @@ class _Linker(gof.link.LocalLinker):
self.no_recycling = no_recycling self.no_recycling = no_recycling
return self return self
def make_all(self, profiler=None, input_storage=None, def make_all(self, profiler=None, input_storage=None,
output_storage=None, storage_map=None): output_storage=None, storage_map=None):
# can't import at toplevel because of circular import TODO: # can't import at toplevel because of circular import TODO:
......
...@@ -567,7 +567,6 @@ class Function(object): ...@@ -567,7 +567,6 @@ class Function(object):
Returns: Returns:
func -- Copied theano.Function func -- Copied theano.Function
""" """
if not share_memory: if not share_memory:
return self.__copy__() return self.__copy__()
else: else:
...@@ -578,26 +577,26 @@ class Function(object): ...@@ -578,26 +577,26 @@ class Function(object):
# copy fgraph and get memo # copy fgraph and get memo
fg_cpy, memo = maker.fgraph.clone_get_equiv(attach_feature=False) fg_cpy, memo = maker.fgraph.clone_get_equiv(attach_feature=False)
# construct new storage_map that map new variable to old storage # Construct new storage_map that map new variable to old storage,
# so that the ensuing function shares storage with the original one # so that the ensuing function shares storage with the original one
new_storage_map = {}
storage_map = self.fn.storage_map
# TODO: We could share the output storage, but we must make sure # TODO: We could share the output storage, but we must make sure
# 2 different function call won't override each other values. This # 2 different function call won't override each other values. This
# is already done elsewhere, so to reuse it the user would need to # is already done elsewhere, so to reuse it the user would need to
# use Out(var, borrow=True) and maybe the mutable=True flag too. # use Out(var, borrow=True) and maybe the mutable=True flag too.
# But to be safe for now as it isn't documented and we aren't sure # But to be safe for now as it isn't documented and we aren't sure
# it is well tested, we don't share the part of the storage_map. # it is well tested, we don't share the part of the storage_map.
storage_map = self.fn.storage_map
new_storage_map = {}
for key in storage_map.keys(): for key in storage_map.keys():
if key not in maker.fgraph.outputs: if key not in maker.fgraph.outputs:
new_storage_map[memo[key]] = storage_map[key] new_storage_map[memo[key]] = storage_map[key]
input_storage = [] input_storage = []
for in_ori, in_cpy in zip(maker.inputs, ins): assert len(ins) == len(fg_cpy.inputs)
# Since we reuse original Out instances, the copied In for in_ori, in_cpy, in_v in zip(maker.inputs, ins, fg_cpy.inputs):
# instances should use the original variabls as their variables # Since we reuse original Out instances, copied In instances
# and updates. Otherwise the compilation will fail at function # should use the original variabls as their variables and
# updates. Otherwise the compilation will fail at function
# FunctionMaker._check_unused_inputs() # FunctionMaker._check_unused_inputs()
in_cpy.variable = in_ori.variable in_cpy.variable = in_ori.variable
in_cpy.update = in_ori.update in_cpy.update = in_ori.update
...@@ -611,12 +610,10 @@ class Function(object): ...@@ -611,12 +610,10 @@ class Function(object):
storage = getattr(in_cpy, 'value', None) storage = getattr(in_cpy, 'value', None)
input_storage.append(storage) input_storage.append(storage)
# pop out input_storage in storage_map and only use storage of In # pop out input_storage in storage_map and only use storage of
# instances to initialize the make, so that we can avoid # In instances to initialize the make, to avoid storage
# storage conflictions in link.map_storage() # conflictions in link.map_storage()
assert len(fg_cpy.inputs) == len(input_storage) new_storage_map.pop(in_v)
for in_var in fg_cpy.inputs:
new_storage_map.pop(in_var)
# reinitialize new maker and create new function # reinitialize new maker and create new function
return maker.__class__(inputs=ins, outputs=maker.outputs, return maker.__class__(inputs=ins, outputs=maker.outputs,
......
...@@ -243,38 +243,42 @@ class T_function(unittest.TestCase): ...@@ -243,38 +243,42 @@ class T_function(unittest.TestCase):
def test_copy_share_memory(self): def test_copy_share_memory(self):
x = T.fscalar('x') x = T.fscalar('x')
y = T.tanh((x+2)/(x-0.2)**2) # SharedVariable for tests, one of them has update
y = theano.shared(value=1)
# test for PerformaLinker, will cover VM_linker later z = theano.shared(value=2)
ori = theano.function([x], [y], mode="FAST_COMPILE") out = T.tanh((x+y+2)/(x+z-0.2)**2)
cpy = ori.copy(share_memory=True)
# Test for different linkers
# test if memories shared for mode in ["FAST_RUN","FAST_COMPILE"]:
storage_map_ori = ori.fn.storage_map ori = theano.function([x], [out], mode=mode,updates={z:z+1})
storage_map_cpy = cpy.fn.storage_map cpy = ori.copy(share_memory=True)
fgraph_ori = ori.maker.fgraph
fgraph_cpy = cpy.maker.fgraph # Test if memories shared
storage_map_ori = ori.fn.storage_map
# assert intermediate and Constants storages are shared. storage_map_cpy = cpy.fn.storage_map
# and output stoarges are not shared fgraph_ori = ori.maker.fgraph
i_o_variables = fgraph_cpy.inputs + fgraph_cpy.outputs fgraph_cpy = cpy.maker.fgraph
ori_storages = storage_map_ori.values()
for key in storage_map_cpy.keys(): # Assert intermediate and Constants storages are shared.
storage = storage_map_cpy[key] # and output stoarges are not shared
storage_is_shared = any([ storage is s for s in ori_storages]) i_o_variables = fgraph_cpy.inputs + fgraph_cpy.outputs
if key not in i_o_variables or isinstance(key, theano.tensor.Constant): ori_storages = storage_map_ori.values()
self.assertTrue(storage_is_shared) for key in storage_map_cpy.keys():
elif key in fgraph_cpy.outputs: storage = storage_map_cpy[key]
self.assertFalse(storage_is_shared) storage_is_shared = any([ storage is s for s in ori_storages])
if key not in i_o_variables or isinstance(key, theano.tensor.Constant):
# assert storages of SharedVariable without updates are shared self.assertTrue(storage_is_shared)
for (input, _1, _2), here, there in zip(ori.indices, elif key in fgraph_cpy.outputs:
ori.input_storage, self.assertFalse(storage_is_shared)
cpy.input_storage):
if not input.mutable: # Assert storages of SharedVariable without updates are shared
self.assertTrue(here.data is there.data) for (input, _1, _2), here, there in zip(ori.indices,
else: ori.input_storage,
self.assertFalse(here.data is there.data) cpy.input_storage):
if not input.mutable:
self.assertTrue(here.data is there.data)
else:
self.assertFalse(here.data is there.data)
def test_shared_state0(self): def test_shared_state0(self):
a = T.scalar() # the a is for 'anonymous' (un-named). a = T.scalar() # the a is for 'anonymous' (un-named).
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论