提交 32a83c07 authored 作者: ChienliMa's avatar ChienliMa

Small modification to copy(). Testcase is finished. Start debuging.

上级 aa701eaf
...@@ -590,24 +590,33 @@ class Function(object): ...@@ -590,24 +590,33 @@ class Function(object):
# re-initialize new FunctionMaker # re-initialize new FunctionMaker
maker = self.maker maker = self.maker
new_maker = FunctionMaker( inputs=ins, outputs=outs, fgraph=new_fgraph, new_maker = FunctionMaker(inputs=ins, outputs=outs, mode=maker.mode,
mode=maker.mode, profile=maker.profile, fgraph=new_fgraph, profile=maker.profile,
accept_inplace=maker.accept_inplace, accept_inplace=maker.accept_inplace,
function_builder=maker.function_builder, function_builder=maker.function_builder,
on_unused_input=maker.on_unused_input ) on_unused_input=maker.on_unused_input)
# construct new storage_map that map new variable to old storage # construct new storage_map that map new variable to old storage
# so that the ensuing function shares storage with the original one # so that the ensuing function shares storage with the original one
new_storage_map = {} new_storage_map = {}
storage_map = self.fn.storage_map storage_map = self.fn.storage_map
for key in storage_map.keys(): for key in storage_map.keys():
if not isinstance(key, theano.tensor.Constant) and \ # output_storages should not be shared
if key not in self.maker.fgraph.outputs and \
memo.has_key(key): memo.has_key(key):
new_storage_map[memo[key]] = storage_map[key] new_storage_map[memo[key]] = storage_map[key]
# copy input storages and use new storage_map to link function # copy input storages and link function with new storage_map
input_storage = copy.copy([getattr(input, 'value', None) for input in ins]) input_storage = copy.copy([getattr(i, 'value', None) for i in ins])
new_func = new_maker.create( input_storage, storage_map = new_storage_map ) new_func = new_maker.create(input_storage, storage_map=new_storage_map)
# share immutable SharedVariable's storage
for (input, _1, _2), here, there in zip(self.indices,
self.input_storage,
new_func.input_storage):
if not input.mutable:
there.data = here.data
return new_func return new_func
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
......
...@@ -242,8 +242,34 @@ class T_function(unittest.TestCase): ...@@ -242,8 +242,34 @@ class T_function(unittest.TestCase):
self.assertFalse(f(1, 2) == g(1, 2)) # they should not be equal anymore. self.assertFalse(f(1, 2) == g(1, 2)) # they should not be equal anymore.
def test_copy_share_memory(self): def test_copy_share_memory(self):
# Todo: finish the test. x = T.fscalar('x')
pass y = T.tanh((x+2)/(x+0.2)**2)
# test for PerformaLinker, will cover VM_linker later
ori = theano.function([x], [y], mode="FAST_COMPILE")
cpy = func.copy(share_memory=True)
# test if memories shared
storage_map_ori = ori.fn.storage_map
storage_map_cpy = cpy.fn.storage_map
fgraph_ori = ori.maker.fgraph
fgraph_cpy = cpy.maker.fgraph
# assert intermediate and Constants storages are shared
i_o_variables = fgraph_cpy.inputs
ori_storages = storage_map_ori.values()
for key in storage_map_cpy.keys()
if key not in i_o_variables or isinstance(key, theano.tensor.Constant):
self.assertTrue(storage_map_cpy[key] in ori_storages)
# assert storages of SharedVariable without updates are shared
for (input, _1, _2), here, there in zip(ori.indices,
ori.input_storage,
cpy.input_storage):
if not input.mutable:
self.assertTrue(here.data is there.data)
else:
self.assertFalse(here.data is there.data)
def test_shared_state0(self): def test_shared_state0(self):
a = T.scalar() # the a is for 'anonymous' (un-named). a = T.scalar() # the a is for 'anonymous' (un-named).
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论