提交 aaac6d22 authored 作者: ChienliMa's avatar ChienliMa

merge function.copy() with function.__copy__() and fix #3049

上级 be58f8b4
......@@ -539,22 +539,13 @@ class Function(object):
Copy a function. Copied function have separate intermediate
storages and output storages with original function
"""
defaults = [default for _1, _2, default in self.defaults]
cpy = self.maker.create(defaults, trustme=True)
for (input, _1, _2), here, there in zip(self.indices,
self.input_storage,
cpy.input_storage):
if input.mutable and here is not None:
there.data = copy.copy(here.data)
else:
there.data = here.data
return cpy
return self.copy()
def copy(self, share_memory=False):
"""
Copy this function. Copied function will have separated maker and
fgraph with original function if share_memory=True. User can choose
whether to separate storage by changing the share_memory arguments
fgraph with original function. User can choose whether to separate
storage by changing the share_memory arguments
---------------------
Params:
share_memory -- { boolean } Default is False. When True, two
......@@ -567,9 +558,6 @@ class Function(object):
Returns:
func -- Copied theano.Function
"""
if not share_memory:
return self.__copy__()
else:
maker = self.maker
# copy Ins, so that they have different storage as their value
ins = copy.deepcopy(maker.inputs)
......@@ -577,46 +565,36 @@ class Function(object):
# copy fgraph and get memo
fg_cpy, memo = maker.fgraph.clone_get_equiv(attach_feature=False)
# Construct new storage_map that map new variable to old storage,
# so that the ensuing function shares storage with the original one
storage_map = self.fn.storage_map
new_storage_map = {}
# If share_memory, Construct new storage_map that map new variable
# to old storage, so that the ensuing function shares storage with
# the original one.
# TODO: We could share the output storage, but we must make sure
# 2 different function call won't override each other values. This
# is already done elsewhere, so to reuse it the user would need to
# use Out(var, borrow=True) and maybe the mutable=True flag too.
# But to be safe for now as it isn't documented and we aren't sure
# it is well tested, we don't share the part of the storage_map.
storage_map = self.fn.storage_map
new_storage_map = {}
if share_memory:
i_o_vars = maker.fgraph.outputs + maker.fgraph.inputs
for key in storage_map.keys():
if key not in maker.fgraph.outputs:
if key not in i_o_vars:
new_storage_map[memo[key]] = storage_map[key]
input_storage = []
assert len(ins) == len(fg_cpy.inputs)
for in_ori, in_cpy, in_v in zip(maker.inputs, ins, fg_cpy.inputs):
for in_ori, in_cpy in zip(maker.inputs, ins):
# Since we reuse original Out instances, copied In instances
# should use the original variabls as their variables and
# updates. Otherwise the compilation will fail at function
# FunctionMaker._check_unused_inputs()
in_cpy.variable = in_ori.variable
in_cpy.update = in_ori.update
# share input storages if it's immutable
is_const = isinstance(in_ori.variable, theano.tensor.Constant)
if is_const or not in_ori.mutable:
storage = getattr(in_ori, 'value', None)
in_cpy.value = in_ori.value
else:
storage = getattr(in_cpy, 'value', None)
input_storage.append(storage)
# pop out input_storage in storage_map and only use storage of
# In instances to initialize the make, to avoid storage
# conflictions in link.map_storage()
new_storage_map.pop(in_v)
input_storage.append(in_cpy.value)
# reinitialize new maker and create new function
return maker.__class__(inputs=ins, outputs=maker.outputs,
f_cpy = maker.__class__(inputs=ins, outputs=maker.outputs,
fgraph=fg_cpy,
mode=maker.mode, profile=maker.profile,
on_unused_input=maker.on_unused_input,
......@@ -624,6 +602,17 @@ class Function(object):
accept_inplace=maker.accept_inplace).create(
input_storage, storage_map=new_storage_map)
# Share immutable and constant input storage
for in_ori, in_cpy, ori, cpy in zip(maker.inputs, f_cpy.maker.inputs,
self.input_storage,
f_cpy.input_storage):
is_const = isinstance(in_ori.variable, theano.tensor.Constant)
if is_const or not in_ori.mutable:
cpy.data = ori.data
in_cpy.value = in_ori.value
return f_cpy
def __call__(self, *args, **kwargs):
profile = self.profile
t0 = time.time()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论