提交 d3c57238 authored 作者: ChienliMa's avatar ChienliMa

Add some docs to FunctionGraph and minor changes of function_module.copy()

上级 df77628b
......@@ -595,6 +595,7 @@ class Function(object):
# But to be safe for now as it isn't documented and we aren't sure
# it is well tested, we don't share the part of the storage_map.
for key in storage_map.keys():
<<<<<<< HEAD
if key not in self.maker.fgraph.outputs and
not isinstance(key, theano.tensor.Constant):
new_storage_map[memo[key]] = storage_map[key]
......@@ -610,6 +611,33 @@ class Function(object):
input_storage.append( copy.deepcopy[storage])
new_func = new_maker.create(input_storage, storage_map=new_storage_map)
=======
# output_storages should not be shared
# if key not in self.maker.fgraph.outputs and \
# memo.has_key(key):
new_storage_map[memo[key]] = storage_map[key]
# copy input storages if it's mutable
input_storage = []
for i in self.maker.inputs:
storage = getattr(i, 'value', None)
if isinstance(i.variable, theano.tensor.Constant) or\
not i.mutable:
input_storage.append(storage )
else:
input_storage.append( copy.deepcopy[storage])
new_func = new_maker.create(input_storage, \
storage_map=new_storage_map)
# share immutable SharedVariable's storage
# for (input, _1, _2), here, there in zip(self.indices,
# self.input_storage,
# new_func.input_storage):
# if isinstance(i.variable, theano.tensor.Constant) or \
# not input.mutable:
# there.data = here.data
>>>>>>> 4acf04b... Add some docs to FunctionGraph and minor changes of function_module.copy()
return new_func
......
......@@ -255,17 +255,26 @@ class T_function(unittest.TestCase):
fgraph_ori = ori.maker.fgraph
fgraph_cpy = cpy.maker.fgraph
<<<<<<< HEAD
# assert intermediate and Constants storages are shared.
# and output stoarges are not shared
=======
# assert intermediate and Constants storages are shared
>>>>>>> 4acf04b... Add some docs to FunctionGraph and minor changes of function_module.copy()
i_o_variables = fgraph_cpy.inputs + fgraph_cpy.outputs
ori_storages = storage_map_ori.values()
for key in storage_map_cpy.keys():
storage = storage_map_cpy[key]
storage_is_shared = any([ storage is s for s in ori_storages])
if key not in i_o_variables or isinstance(key, theano.tensor.Constant):
<<<<<<< HEAD
self.assertTrue(storage_is_shared)
elif key in fgraph_cpy.outputs:
self.assertFalse(storage_is_shared)
=======
storage = storage_map_cpy[key]
self.assertTrue( any([ storage is s for s in ori_storages]))
>>>>>>> 4acf04b... Add some docs to FunctionGraph and minor changes of function_module.copy()
# assert storages of SharedVariable without updates are shared
for (input, _1, _2), here, there in zip(ori.indices,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论