提交 0563dc8f authored 作者: ChienliMa's avatar ChienliMa

Add test for VM_Linker; Add test for SharedVariable with/without updates;…

Add test for VM_Linker; Add test for SharedVariable with/without updates; Simplify code and delete extra line.
上级 55815f17
......@@ -1826,7 +1826,6 @@ class _Linker(gof.link.LocalLinker):
self.no_recycling = no_recycling
return self
def make_all(self, profiler=None, input_storage=None,
output_storage=None, storage_map=None):
# can't import at toplevel because of circular import TODO:
......
......@@ -567,7 +567,6 @@ class Function(object):
Returns:
func -- Copied theano.Function
"""
if not share_memory:
return self.__copy__()
else:
......@@ -578,26 +577,26 @@ class Function(object):
# copy fgraph and get memo
fg_cpy, memo = maker.fgraph.clone_get_equiv(attach_feature=False)
# construct new storage_map that map new variable to old storage
# Construct new storage_map that map new variable to old storage,
# so that the ensuing function shares storage with the original one
new_storage_map = {}
storage_map = self.fn.storage_map
# TODO: We could share the output storage, but we must make sure
# 2 different function call won't override each other values. This
# is already done elsewhere, so to reuse it the user would need to
# use Out(var, borrow=True) and maybe the mutable=True flag too.
# But to be safe for now as it isn't documented and we aren't sure
# it is well tested, we don't share the part of the storage_map.
storage_map = self.fn.storage_map
new_storage_map = {}
for key in storage_map.keys():
if key not in maker.fgraph.outputs:
new_storage_map[memo[key]] = storage_map[key]
input_storage = []
for in_ori, in_cpy in zip(maker.inputs, ins):
# Since we reuse original Out instances, the copied In
# instances should use the original variabls as their variables
# and updates. Otherwise the compilation will fail at function
assert len(ins) == len(fg_cpy.inputs)
for in_ori, in_cpy, in_v in zip(maker.inputs, ins, fg_cpy.inputs):
# Since we reuse original Out instances, copied In instances
            # should use the original variables as their variables and
# updates. Otherwise the compilation will fail at function
# FunctionMaker._check_unused_inputs()
in_cpy.variable = in_ori.variable
in_cpy.update = in_ori.update
......@@ -611,12 +610,10 @@ class Function(object):
storage = getattr(in_cpy, 'value', None)
input_storage.append(storage)
# pop out input_storage in storage_map and only use storage of In
# instances to initialize the make, so that we can avoid
# storage conflictions in link.map_storage()
assert len(fg_cpy.inputs) == len(input_storage)
for in_var in fg_cpy.inputs:
new_storage_map.pop(in_var)
            # pop out input_storage in storage_map and only use storage of
            # In instances to initialize the make, to avoid storage
            # conflicts in link.map_storage()
new_storage_map.pop(in_v)
# reinitialize new maker and create new function
return maker.__class__(inputs=ins, outputs=maker.outputs,
......
......@@ -243,38 +243,42 @@ class T_function(unittest.TestCase):
def test_copy_share_memory(self):
    """Check that Function.copy(share_memory=True) shares storage with
    the original function where it is safe to do so.

    Expected sharing rules (asserted below):
      - intermediate-variable and Constant storages are shared;
      - output storages are NOT shared (two functions must not clobber
        each other's results);
      - input storages are shared only for immutable inputs (e.g. a
        SharedVariable without updates); mutable inputs get their own
        storage.
    """
    x = T.fscalar('x')
    # SharedVariables for the test; z has an update so its input
    # becomes mutable, while y stays immutable.
    y = theano.shared(value=1)
    z = theano.shared(value=2)
    out = T.tanh((x + y + 2) / (x + z - 0.2) ** 2)
    # Exercise the different linkers (C/VM via FAST_RUN, python via
    # FAST_COMPILE).
    for mode in ["FAST_RUN", "FAST_COMPILE"]:
        ori = theano.function([x], [out], mode=mode, updates={z: z + 1})
        cpy = ori.copy(share_memory=True)

        storage_map_ori = ori.fn.storage_map
        storage_map_cpy = cpy.fn.storage_map
        fgraph_cpy = cpy.maker.fgraph

        # Intermediate and Constant storages must be shared; output
        # storages must not be.
        i_o_variables = fgraph_cpy.inputs + fgraph_cpy.outputs
        ori_storages = storage_map_ori.values()
        for key in storage_map_cpy.keys():
            storage = storage_map_cpy[key]
            # Identity check: the copy must reuse the very same storage
            # cell, not an equal one.
            storage_is_shared = any(storage is s for s in ori_storages)
            if key not in i_o_variables or isinstance(key, theano.tensor.Constant):
                self.assertTrue(storage_is_shared)
            elif key in fgraph_cpy.outputs:
                self.assertFalse(storage_is_shared)

        # Immutable inputs (e.g. SharedVariable without updates) share
        # their storage; mutable inputs (e.g. z, which has an update)
        # must not.
        for (input, _1, _2), here, there in zip(ori.indices,
                                                ori.input_storage,
                                                cpy.input_storage):
            if not input.mutable:
                self.assertTrue(here.data is there.data)
            else:
                self.assertFalse(here.data is there.data)
def test_shared_state0(self):
a = T.scalar() # the a is for 'anonymous' (un-named).
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论