提交 a8a12e8e authored 作者: ChienliMa's avatar ChienliMa

Allow to swap SharedVariable by SharedVariable Instances and update corresponding docs.

上级 52201224
...@@ -547,10 +547,6 @@ class Function(object): ...@@ -547,10 +547,6 @@ class Function(object):
Copy this function. Copied function will have separated maker and Copy this function. Copied function will have separated maker and
fgraph with original function. User can choose whether to separate fgraph with original function. User can choose whether to separate
storage by changing the share_memory arguments. storage by changing the share_memory arguments.
Note:
Originally, variables in In and Out instances are those variables
that defined by user. After being copied, variables in In/Outs are
those variables in Function.maker.fgraph.
--------------------- ---------------------
Params: Params:
share_memory -- { boolean } Default is False. When True, two share_memory -- { boolean } Default is False. When True, two
...@@ -559,11 +555,11 @@ class Function(object): ...@@ -559,11 +555,11 @@ class Function(object):
storages and same maker. If two functions share memory and storages and same maker. If two functions share memory and
allow_gc=False, this will increase executing speed and save memory. allow_gc=False, this will increase executing speed and save memory.
swap -- { dict } Dictionary<String, theano.SharedVariable> that swap -- { dict } Dictionary that map old SharedVariables to new
map old SharedVariable's name to new SharedVariable. Default is SharedVariables. Default is None. The computational relationship is
None. The computational relationship is modified within the inner modified within the inner fgraph, especially for this copied
fgraph, especially for this copied function. SharedVariables aren't function. SharedVariables aren't swapped in the relationship that
swapped in the relationship that user defined. user defined.
delete_updates -- { boolean } Default is False. If True, Copied delete_updates -- { boolean } Default is False. If True, Copied
function will not have update. function will not have update.
...@@ -571,11 +567,27 @@ class Function(object): ...@@ -571,11 +567,27 @@ class Function(object):
Returns: Returns:
func -- Copied theano.Function func -- Copied theano.Function
""" """
maker = self.maker # helper function
# Copy Ins, so that they have different storage as their value def checkSV(sv_ori, sv_rpl):
ins = copy.deepcopy(maker.inputs) """
Assert two SharedVariable follow some restirctions:
1. same type
2. same shape or dim?
"""
assert sv_ori.type == sv_rpl.type, (
"Type of given SharedVariable conflicts with origianl one",
"Type of given SharedVariable:", sv_rpl.type,
"Type of original SharedVariable:", sv_ori.type)
# Delete update output in fgraph and updates In instancesis needed maker = self.maker
# Copy Ins and their storage.
# so that they have different storage as their value
ins = [copy.copy(input) for input in maker.inputs]
for in_ori, in_cpy in zip(ins, maker.inputs):
in_ori.value = copy.deepcopy(in_cpy.value)
# Delete update output in fgraph and updates In instances if needed
if delete_updates: if delete_updates:
# The first len(maker.outputs) variabels are original variables. # The first len(maker.outputs) variabels are original variables.
# The rest are the updates. # The rest are the updates.
...@@ -590,11 +602,10 @@ class Function(object): ...@@ -590,11 +602,10 @@ class Function(object):
[memo[o] for o in out_vars], [memo[o] for o in out_vars],
clone=False) clone=False)
# Swap varaible in Outs. # Re initialize Outs and swap update and variable in Ins
outs = map(SymbolicOutput, maker.fgraph.outputs[:len(maker.outputs)])
# Swap update and variable in Ins
# By doing this, we can pass FunctionMaker._check_unused_inputs() # By doing this, we can pass FunctionMaker._check_unused_inputs()
outs = map(SymbolicOutput, fg_cpy.outputs[:len(maker.outputs)])
update_i = len(outs) update_i = len(outs)
for i, in_var in zip(ins, fg_cpy.inputs): for i, in_var in zip(ins, fg_cpy.inputs):
i.variable = in_var i.variable = in_var
...@@ -607,33 +618,26 @@ class Function(object): ...@@ -607,33 +618,26 @@ class Function(object):
# swap SharedVariable # swap SharedVariable
if swap is not None: if swap is not None:
def checkSV(sv_ori, sv_rpl): swap_svs_ori = swap.keys()
"""
Assert two SharedVariable follow some restirctions: # Check if given ShareVariables exist
1. same type for sv in swap_svs_ori:
2. same shape or dim? exist_svs = [i.variable for i in maker.inputs]
""" if sv not in exist_svs:
assert sv_ori.type == sv_rpl.type, ( warnings.warn("SharedVairable: %s not found" % (sv.name))
"Type of given SharedVariable conflicts with origianl one",
"Type of given SharedVariable:", sv_rpl.type,
"Type of original SharedVariable:", sv_ori.type)
exist_names = [i.variable.name for i in ins]
swap_names = swap.keys()
# Check if given names exist
for name in swap_names:
if name not in exist_names:
warnings.warn("Given name: %s wasn't found" % (name))
# Swap SharedVairable in fgraph and ins # Swap SharedVairable in fgraph and ins
for index, (i, in_v) in enumerate(zip(ins, fg_cpy.inputs)): for index, (i, in_v) in enumerate(zip(ins, fg_cpy.inputs)):
assert i.variable.name == in_v.name # Variables in maker.ipnuts are defined by user, therefore we
# use them to make comparision and do the mapping.
if in_v.name in swap_names: # Otherwise we don't touch them.
checkSV(in_v, swap[in_v.name]) swap_sv = maker.inputs[index].variable
if swap_sv in swap_svs_ori:
checkSV(i.variable, swap_sv)
# In the fgraph we use the cloned SharedVariable # In the fgraph we use the cloned SharedVariable
swap_sv = swap[in_v.name].clone() swap_sv = swap_sv.clone()
# Swap SharedVariable in fgraph # Swap SharedVariable in fgraph
fg_cpy.inputs[index] = swap_sv fg_cpy.inputs[index] = swap_sv
...@@ -669,26 +673,29 @@ class Function(object): ...@@ -669,26 +673,29 @@ class Function(object):
accept_inplace=maker.accept_inplace).create( accept_inplace=maker.accept_inplace).create(
input_storage, storage_map=new_storage_map) input_storage, storage_map=new_storage_map)
# Share immutable and constant input's storage
for in_ori, in_cpy, ori, cpy in zip(maker.inputs, f_cpy.maker.inputs, for in_ori, in_cpy, ori, cpy in zip(maker.inputs, f_cpy.maker.inputs,
self.input_storage, self.input_storage,
f_cpy.input_storage): f_cpy.input_storage):
is_const = isinstance(in_ori.variable, theano.tensor.Constant)
# In instances' name default to vairables' name # Share immutable ShareVariable and constant input's storage
swapped = swap is not None and in_ori.name in swap_names swapped = swap is not None and in_ori.variable in swap_svs_ori
is_const = isinstance(in_ori.variable, theano.tensor.Constant)
if (is_const or not in_ori.mutable) and not swapped: if (is_const or not in_ori.mutable) and not swapped:
cpy.data = ori.data cpy.data = ori.data
in_cpy.value = in_ori.value in_cpy.value = in_ori.value
# Reconstruct Function.finder. # Reconstruct Function.finder which map Variable defined by user
# Function.value and Function.data work # to container, to make Function.value and Function.data work well.
for ori, cpy in zip(maker.inputs, f_cpy.maker.inputs): # Replace variable in new maker.inputs by the original ones.
swapped = swap is not None and ori.name in swap_names # So that user can swap SharedVariable in a swapped function
container = f_cpy.finder.pop(in_cpy.variable)
if not swapped: if not swapped:
f_cpy.finder[ori.variable] = f_cpy.finder.pop(cpy.variable) f_cpy.finder[in_ori.variable] = container
in_cpy.vairable = in_ori.variable
else: else:
f_cpy.finder[swap[ori.name]] = f_cpy.finder.pop(cpy.variable) f_cpy.finder[swap[in_ori.variable]] = container
in_cpy.variable = swap[in_ori.variable]
return f_cpy return f_cpy
......
...@@ -289,7 +289,7 @@ class T_function(unittest.TestCase): ...@@ -289,7 +289,7 @@ class T_function(unittest.TestCase):
# SharedVariable to replace # SharedVariable to replace
y_rpl = theano.shared(value=3,name ='y_rpl') y_rpl = theano.shared(value=3,name ='y_rpl')
z_rpl = theano.shared(value=4, name='z_rpl') z_rpl = theano.shared(value=4, name='z_rpl')
swap = {'y':y_rpl, 'z':z_rpl} swap = {y:y_rpl, y:z_rpl}
map_SV = {'y_rpl':y_rpl, 'z_rpl':z_rpl} map_SV = {'y_rpl':y_rpl, 'z_rpl':z_rpl}
out = x+y+z out = x+y+z
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论