Commit aaac6d22 authored by ChienliMa

merge function.copy() with function.__copy__() and fix #3049

Parent be58f8b4
...@@ -539,22 +539,13 @@ class Function(object): ...@@ -539,22 +539,13 @@ class Function(object):
Copy a function. Copied function have separate intermediate Copy a function. Copied function have separate intermediate
storages and output storages with original function storages and output storages with original function
""" """
defaults = [default for _1, _2, default in self.defaults] return self.copy()
cpy = self.maker.create(defaults, trustme=True)
for (input, _1, _2), here, there in zip(self.indices,
self.input_storage,
cpy.input_storage):
if input.mutable and here is not None:
there.data = copy.copy(here.data)
else:
there.data = here.data
return cpy
def copy(self, share_memory=False):
    """
    Copy this function. The copied function has its own maker and
    fgraph, separate from the original function's. The caller chooses
    whether intermediate storage is shared via ``share_memory``.
    ---------------------
    Params:
        share_memory -- { boolean } Default is False. When True, the two
        functions share intermediate storage (all storage except the
        input and output storage), so they must not be called
        concurrently. When False, the copy gets fully separate
        intermediate and output storage.
    Returns:
        func -- Copied theano.Function
    """
    maker = self.maker
    # Copy the In instances, so that they have different storage as
    # their value than the original function's inputs.
    ins = copy.deepcopy(maker.inputs)

    # Copy the fgraph; ``memo`` maps original variables to their clones.
    fg_cpy, memo = maker.fgraph.clone_get_equiv(attach_feature=False)
    storage_map = self.fn.storage_map
    new_storage_map = {}
    # If share_memory, construct a new storage_map that maps each new
    # variable to the old storage, so that the ensuing function shares
    # storage with the original one.
    # TODO: We could share the output storage, but we must make sure
    # 2 different function calls won't override each other's values. This
    # is already done elsewhere, so to reuse it the user would need to
    # use Out(var, borrow=True) and maybe the mutable=True flag too.
    # But to be safe for now, as it isn't documented and we aren't sure
    # it is well tested, we don't share that part of the storage_map.
    if share_memory:
        # Inputs and outputs are excluded from sharing (see TODO above).
        i_o_vars = maker.fgraph.outputs + maker.fgraph.inputs
        for key in storage_map.keys():
            if key not in i_o_vars:
                new_storage_map[memo[key]] = storage_map[key]

    input_storage = []
    assert len(ins) == len(fg_cpy.inputs)
    for in_ori, in_cpy in zip(maker.inputs, ins):
        # Since we reuse the original Out instances, copied In instances
        # should use the original variables as their variables and
        # updates. Otherwise the compilation will fail at
        # FunctionMaker._check_unused_inputs().
        in_cpy.variable = in_ori.variable
        in_cpy.update = in_ori.update
        input_storage.append(in_cpy.value)

    # Reinitialize a new maker and create the new function.
    f_cpy = maker.__class__(inputs=ins, outputs=maker.outputs,
                            fgraph=fg_cpy,
                            mode=maker.mode, profile=maker.profile,
                            on_unused_input=maker.on_unused_input,
                            function_builder=maker.function_builder,
                            accept_inplace=maker.accept_inplace).create(
                                input_storage,
                                storage_map=new_storage_map)

    # Share immutable and constant input storage with the original:
    # such inputs cannot be modified in place, so sharing is safe.
    for in_ori, in_cpy, ori, cpy in zip(maker.inputs, f_cpy.maker.inputs,
                                        self.input_storage,
                                        f_cpy.input_storage):
        is_const = isinstance(in_ori.variable, theano.tensor.Constant)
        if is_const or not in_ori.mutable:
            cpy.data = ori.data
            in_cpy.value = in_ori.value

    return f_cpy
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
profile = self.profile profile = self.profile
......
Markdown formatting is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment