提交 e2f3cebf authored 作者: ChienliMa's avatar ChienliMa

Format Change; Modify inproper comment;Change constructor class;Add test for output storage.

上级 cc53776c
...@@ -576,7 +576,7 @@ class Function(object): ...@@ -576,7 +576,7 @@ class Function(object):
fg_cpy, memo = maker.fgraph.clone_get_equiv(attach_feature=False) fg_cpy, memo = maker.fgraph.clone_get_equiv(attach_feature=False)
# use copied ins, outs and fgraph to init a maker # use copied ins, outs and fgraph to init a maker
new_maker = FunctionMaker(inputs=ins, outputs=outs, mode=maker.mode, new_maker = maker.__class__(inputs=ins, outputs=outs, mode=maker.mode,
fgraph=fg_cpy, profile=maker.profile, fgraph=fg_cpy, profile=maker.profile,
accept_inplace=maker.accept_inplace, accept_inplace=maker.accept_inplace,
function_builder=maker.function_builder, function_builder=maker.function_builder,
...@@ -588,8 +588,7 @@ class Function(object): ...@@ -588,8 +588,7 @@ class Function(object):
storage_map = self.fn.storage_map storage_map = self.fn.storage_map
for key in storage_map.keys(): for key in storage_map.keys():
# output_storages should not be shared # output_storages should not be shared
# if key not in self.maker.fgraph.outputs and \ if key not in self.maker.fgraph.outputs and memo.has_key(key):
# memo.has_key(key):
new_storage_map[memo[key]] = storage_map[key] new_storage_map[memo[key]] = storage_map[key]
# copy input storages if it's mutable # copy input storages if it's mutable
...@@ -602,16 +601,7 @@ class Function(object): ...@@ -602,16 +601,7 @@ class Function(object):
else: else:
input_storage.append( copy.deepcopy[storage]) input_storage.append( copy.deepcopy[storage])
new_func = new_maker.create(input_storage, \ new_func = new_maker.create(input_storage, storage_map=new_storage_map)
storage_map=new_storage_map)
# share immutable SharedVariable's storage
# for (input, _1, _2), here, there in zip(self.indices,
# self.input_storage,
# new_func.input_storage):
# if isinstance(i.variable, theano.tensor.Constant) or \
# not input.mutable:
# there.data = here.data
return new_func return new_func
...@@ -1424,7 +1414,7 @@ class FunctionMaker(object): ...@@ -1424,7 +1414,7 @@ class FunctionMaker(object):
"'%s'.\nValid values are 'raise', " "'%s'.\nValid values are 'raise', "
"'warn', and 'ignore'." % on_unused_input) "'warn', and 'ignore'." % on_unused_input)
def create(self, input_storage=None, trustme=False, storage_map = None): def create(self, input_storage=None, trustme=False, storage_map=None):
""" """
Create a function. Create a function.
...@@ -1505,7 +1495,7 @@ class FunctionMaker(object): ...@@ -1505,7 +1495,7 @@ class FunctionMaker(object):
try: try:
theano.config.traceback.limit = 0 theano.config.traceback.limit = 0
_fn, _i, _o = self.linker.make_thunk( _fn, _i, _o = self.linker.make_thunk(
input_storage=input_storage_lists, storage_map = storage_map) input_storage=input_storage_lists, storage_map=storage_map)
finally: finally:
theano.config.traceback.limit = limit_orig theano.config.traceback.limit = limit_orig
......
...@@ -255,13 +255,17 @@ class T_function(unittest.TestCase): ...@@ -255,13 +255,17 @@ class T_function(unittest.TestCase):
fgraph_ori = ori.maker.fgraph fgraph_ori = ori.maker.fgraph
fgraph_cpy = cpy.maker.fgraph fgraph_cpy = cpy.maker.fgraph
# assert intermediate and Constants storages are shared # assert intermediate and Constants storages are shared.
# and output stoarges are not shared
i_o_variables = fgraph_cpy.inputs + fgraph_cpy.outputs i_o_variables = fgraph_cpy.inputs + fgraph_cpy.outputs
ori_storages = storage_map_ori.values() ori_storages = storage_map_ori.values()
for key in storage_map_cpy.keys(): for key in storage_map_cpy.keys():
if key not in i_o_variables or isinstance(key, theano.tensor.Constant):
storage = storage_map_cpy[key] storage = storage_map_cpy[key]
self.assertTrue( any([ storage is s for s in ori_storages])) storage_is_shared = any([ storage is s for s in ori_storages])
if key not in i_o_variables or isinstance(key, theano.tensor.Constant):
self.assertTrue(storage_is_shared)
elif key in fgraph_cpy.outputs:
self.assertFalse(storage_is_shared)
# assert storages of SharedVariable without updates are shared # assert storages of SharedVariable without updates are shared
for (input, _1, _2), here, there in zip(ori.indices, for (input, _1, _2), here, there in zip(ori.indices,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论