提交 0c1cdd89 authored 作者: ChienliMa's avatar ChienliMa

Add feature 'delete_update'. Merge three features in one function. Add test cases.

上级 3055ab4d
...@@ -18,6 +18,7 @@ import theano ...@@ -18,6 +18,7 @@ import theano
from theano import gof from theano import gof
from functools import partial from functools import partial
from theano.compat import izip from theano.compat import izip
from theano.gof import graph
import theano.compile.mode import theano.compile.mode
from theano.compile.io import ( from theano.compile.io import (
In, SymbolicInput, SymbolicInputKit, SymbolicOutput) In, SymbolicInput, SymbolicInputKit, SymbolicOutput)
...@@ -541,7 +542,7 @@ class Function(object): ...@@ -541,7 +542,7 @@ class Function(object):
""" """
return self.copy() return self.copy()
def copy(self, share_memory=False, swap=None): def copy(self, share_memory=False, swap=None, delete_updates=False):
""" """
Copy this function. Copied function will have separated maker and Copy this function. Copied function will have separated maker and
fgraph with original function. User can choose whether to separate fgraph with original function. User can choose whether to separate
...@@ -562,23 +563,54 @@ class Function(object): ...@@ -562,23 +563,54 @@ class Function(object):
swap -- { dict } Dictionary<String, theano.SharedVariable> that swap -- { dict } Dictionary<String, theano.SharedVariable> that
map old SharedVariable's name to new SharedVariable. Default is map old SharedVariable's name to new SharedVariable. Default is
None. None.
delete_updates -- { boolean } Default is False. If True, Copied
function will not have update.
--------------------- ---------------------
Returns: Returns:
func -- Copied theano.Function func -- Copied theano.Function
""" """
maker = self.maker maker = self.maker
# copy Ins, so that they have different storage as their value # Copy Ins, so that they have different storage as their value
ins = copy.deepcopy(maker.inputs) ins, outs = copy.deepcopy([maker.inputs, maker.outputs])
# copy fgraph and get memo # Delete update output in fgraph and updates In instancesis needed
fg_cpy, memo = maker.fgraph.clone_get_equiv(attach_feature=False) if delete_updates:
# The first len(outs) variabels would be original variables.
# The rest variables in fgraph.outputs are the updates.
out_vars = maker.fgraph.outputs[:len(outs)]
else:
out_vars = maker.fgraph.outputs
# Init new fgraph using copied variables and get memo
# memo: a dict that map old variables to new variabls
memo = graph.clone_get_equiv(maker.fgraph.inputs, out_vars)
fg_cpy = gof.fg.FunctionGraph([memo[i] for i in maker.fgraph.inputs],
[memo[o] for o in out_vars],
clone=False)
# Swap varaible in Outs.
for i in xrange(len(outs)):
outs[i].variable = fg_cpy.outputs[i]
# Swap update and variable in Ins
# By doing this, we can pass FunctionMaker._check_unused_inputs()
update_i = len(outs)
for i, in_var in zip(ins, fg_cpy.inputs):
i.variable = in_var
if not delete_updates and i.update != None:
i.update = fg_cpy.outputs[update_i]
update_i += 1
else:
i.update = None
# swap SharedVariable if need # swap SharedVariable
if swap != None: if swap != None:
self.__swapSV(swap, ins, fg_cpy, memo) self.__swapSV(swap, ins, fg_cpy)
#used to prevent # the name of SV we swapped
swapped_sv = swap.keys() swapped_sv = swap.keys()
print (fg_cpy)
storage_map = self.fn.storage_map storage_map = self.fn.storage_map
new_storage_map = {} new_storage_map = {}
# Construct new storage_map that map new variable to old storage, # Construct new storage_map that map new variable to old storage,
...@@ -595,19 +627,9 @@ class Function(object): ...@@ -595,19 +627,9 @@ class Function(object):
if key not in i_o_vars: if key not in i_o_vars:
new_storage_map[memo[key]] = storage_map[key] new_storage_map[memo[key]] = storage_map[key]
input_storage = [] input_storage = [ i.value for i in ins ]
assert len(ins) == len(fg_cpy.inputs)
for in_ori, in_cpy, in_v in zip(maker.inputs, ins, fg_cpy.inputs):
# Since we reuse original Out instances, copied In instances
# should use the original variabls as their variables and
# updates. Otherwise the compilation will fail at function
# FunctionMaker._check_unused_inputs()
in_cpy.variable = in_ori.variable
in_cpy.update = in_ori.update
input_storage.append(in_cpy.value)
# reinitialize new maker and create new function # reinitialize new maker and create new function
f_cpy = maker.__class__(inputs=ins, outputs=maker.outputs, f_cpy = maker.__class__(inputs=ins, outputs=outs,
fgraph=fg_cpy, fgraph=fg_cpy,
mode=maker.mode, profile=maker.profile, mode=maker.mode, profile=maker.profile,
on_unused_input=maker.on_unused_input, on_unused_input=maker.on_unused_input,
...@@ -620,19 +642,25 @@ class Function(object): ...@@ -620,19 +642,25 @@ class Function(object):
self.input_storage, self.input_storage,
f_cpy.input_storage): f_cpy.input_storage):
is_const = isinstance(in_ori.variable, theano.tensor.Constant) is_const = isinstance(in_ori.variable, theano.tensor.Constant)
# In instances' name default to vairables' name
swapped = swap != None and in_ori.name in swapped_sv swapped = swap != None and in_ori.name in swapped_sv
if (is_const or not in_ori.mutable) and not swapped: if (is_const or not in_ori.mutable) and not swapped:
cpy.data = ori.data cpy.data = ori.data
in_cpy.value = in_ori.value in_cpy.value = in_ori.value
# swap SharedVariable in In instances # Reconstruct Function.finder.
if swapped and in_cpy.variable.name in swap.keys(): # Function.value and Function.data work
in_cpy.variable = swap[in_cpy.variable.name] for ori, cpy in zip(maker.inputs, f_cpy.maker.inputs):
swapped = swap != None and ori.name in swapped_sv
if not swapped:
f_cpy.finder[ori.variable] = f_cpy.finder.pop(cpy.variable)
else:
f_cpy.finder[swap[ori.name]] = f_cpy.finder.pop(cpy.variable)
return f_cpy return f_cpy
def __swapSV(self, swap, ins, fg_cpy, memo): def __swapSV(self, swap, ins, fg_cpy):
""" """
Hidden auxiliary function, used to swap SharedVariable in maker.inputs Hidden auxiliary function, used to swap SharedVariable in maker.inputs
and maker.fgraph. and maker.fgraph.
...@@ -641,7 +669,6 @@ class Function(object): ...@@ -641,7 +669,6 @@ class Function(object):
swap -- { dict } The same as docs in copy() swap -- { dict } The same as docs in copy()
ins -- { list } List of In instances to be modified. ins -- { list } List of In instances to be modified.
fg_cpy -- { FounctionGraph } Copied FunctionGraph. fg_cpy -- { FounctionGraph } Copied FunctionGraph.
memo -- { dict } Dict that map old variables to new variables.
-------------------- --------------------
Returns: Returns:
None None
...@@ -657,39 +684,30 @@ class Function(object): ...@@ -657,39 +684,30 @@ class Function(object):
"Type of given SharedVariable:", sv_rpl.type, "Type of given SharedVariable:", sv_rpl.type,
"Type of original SharedVariable:", sv_ori.type ) "Type of original SharedVariable:", sv_ori.type )
# Main code start here exist_names = [i.variable.name for i in ins]
exist_name = [ i.variable.name for i in ins ] swap_names = swap.keys()
for name, sv in swap.iteritems():
assert type(name) is str, ("The SharedVariable need to be replaced" # Check if given names exist
"should be indicated by its name(string). Given:",type(name)) for name in swap_names:
# check if SharedVariable exist if name not in exist_names:
if name not in exist_name: warnings.warn( "Given name: %s wasn't found" % (name) )
print ("WARNING: SharedVariable named: %s not found") % name
continue # Swap SharedVairable in fgraph and ins
for index, (i, in_v) in enumerate(zip(ins, fg_cpy.inputs)):
# we don't use the originally defined variable assert i.variable.name == in_v.name
sv = sv.clone()
# replace SharedVariable in maker's In instances if in_v.name in swap_names:
for i in ins: # In the fgraph we use the cloned SharedVariable
if i.variable.name == name: swap_sv = swap[in_v.name].clone()
checkSV( i.variable, sv) checkSV( in_v, swap_sv)
i.variable, i.value = sv, sv.container
break # Swap SharedVariable in fgraph
fg_cpy.inputs[index] = swap_sv
# modify fg_cpy.iputs fg_cpy.replace( in_v, swap_sv, reason="Swap SV")
for i in xrange(len(fg_cpy.inputs)):
if fg_cpy.inputs[i].name == name: # swap variable and value of In instances
fg_cpy.inputs[i] = sv i.variable = swap_sv
break i.value = swap_sv.container
# replace SharedVariable in fgraph and memo
for var_ori, var_cpy in memo.iteritems():
if isinstance(var_ori, theano.compile.SharedVariable) and var_ori.name == name:
checkSV(var_cpy, sv)
fg_cpy.replace(var_cpy, sv)
# modify memo so that ori_var->this shared var
memo[var_ori] = sv
break
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
profile = self.profile profile = self.profile
...@@ -1377,7 +1395,8 @@ class FunctionMaker(object): ...@@ -1377,7 +1395,8 @@ class FunctionMaker(object):
else: else:
# fgraph is already an optimized one # fgraph is already an optimized one
need_opt = False need_opt = False
_, additional_outputs = std_fgraph(inputs, outputs, accept_inplace) updates = [spec.update for spec in inputs if spec.update]
additional_outputs = map(SymbolicOutput, updates)
pass pass
self.fgraph = fgraph self.fgraph = fgraph
......
...@@ -318,6 +318,24 @@ class T_function(unittest.TestCase): ...@@ -318,6 +318,24 @@ class T_function(unittest.TestCase):
elif key.name == 'z_rpl' and second_time: elif key.name == 'z_rpl' and second_time:
assert cpy.fn.storage_map[key][0] == 8 assert cpy.fn.storage_map[key][0] == 8
def test_copy_delete_updates(self):
x = T.fscalar('x')
# SharedVariable for tests, one of them has update
y = theano.shared(value=1, name='y')
z = theano.shared(value=2, name='z')
out = x+y+z
# Test for different linkers
# for mode in ["FAST_RUN","FAST_COMPILE"]:
second_time = False
for mode in ["FAST_RUN","FAST_COMPILE"]:
ori = theano.function([x], out, mode=mode,updates={z:z*2})
cpy = ori.copy(delete_updates=True)
print cpy(1)
assert cpy(1)[0] == 4
assert cpy(1)[0] == 4
assert cpy(1)[0] == 4
def test_shared_state0(self): def test_shared_state0(self):
a = T.scalar() # the a is for 'anonymous' (un-named). a = T.scalar() # the a is for 'anonymous' (un-named).
x, s = T.scalars('xs') x, s = T.scalars('xs')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论