提交 dfe644b7 authored 作者: ChienliMa's avatar ChienliMa

Small modification to copy(). Testcase is finished. Start debuging.

上级 9101711d
...@@ -573,11 +573,16 @@ class Function(object): ...@@ -573,11 +573,16 @@ class Function(object):
# copy fgraph and get memo # copy fgraph and get memo
maker = self.maker maker = self.maker
<<<<<<< HEAD
fg_cpy, memo = maker.fgraph.clone_get_equiv(attach_feature=False) fg_cpy, memo = maker.fgraph.clone_get_equiv(attach_feature=False)
# use copied ins, outs and fgraph to init a maker # use copied ins, outs and fgraph to init a maker
new_maker = maker.__class__(inputs=ins, outputs=outs, mode=maker.mode, new_maker = maker.__class__(inputs=ins, outputs=outs, mode=maker.mode,
fgraph=fg_cpy, profile=maker.profile, fgraph=fg_cpy, profile=maker.profile,
=======
new_maker = FunctionMaker(inputs=ins, outputs=outs, mode=maker.mode,
fgraph=new_fgraph, profile=maker.profile,
>>>>>>> Small modification to copy(). Testcase is finished. Start debuging.
accept_inplace=maker.accept_inplace, accept_inplace=maker.accept_inplace,
function_builder=maker.function_builder, function_builder=maker.function_builder,
on_unused_input=maker.on_unused_input) on_unused_input=maker.on_unused_input)
...@@ -588,6 +593,7 @@ class Function(object): ...@@ -588,6 +593,7 @@ class Function(object):
storage_map = self.fn.storage_map storage_map = self.fn.storage_map
for key in storage_map.keys(): for key in storage_map.keys():
# output_storages should not be shared # output_storages should not be shared
<<<<<<< HEAD
if key not in self.maker.fgraph.outputs and memo.has_key(key): if key not in self.maker.fgraph.outputs and memo.has_key(key):
new_storage_map[memo[key]] = storage_map[key] new_storage_map[memo[key]] = storage_map[key]
...@@ -603,6 +609,23 @@ class Function(object): ...@@ -603,6 +609,23 @@ class Function(object):
new_func = new_maker.create(input_storage, storage_map=new_storage_map) new_func = new_maker.create(input_storage, storage_map=new_storage_map)
=======
if key not in self.maker.fgraph.outputs and \
memo.has_key(key):
new_storage_map[memo[key]] = storage_map[key]
# copy input storages and link function with new storage_map
input_storage = copy.copy([getattr(i, 'value', None) for i in ins])
new_func = new_maker.create(input_storage, storage_map=new_storage_map)
# share immutable SharedVariable's storage
for (input, _1, _2), here, there in zip(self.indices,
self.input_storage,
new_func.input_storage):
if not input.mutable:
there.data = here.data
>>>>>>> Small modification to copy(). Testcase is finished. Start debuging.
return new_func return new_func
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论