提交 0c1cdd89 authored 作者: ChienliMa's avatar ChienliMa

Add feature 'delete_update'. Merge three features in one function. Add test cases.

上级 3055ab4d
......@@ -18,6 +18,7 @@ import theano
from theano import gof
from functools import partial
from theano.compat import izip
from theano.gof import graph
import theano.compile.mode
from theano.compile.io import (
In, SymbolicInput, SymbolicInputKit, SymbolicOutput)
......@@ -541,7 +542,7 @@ class Function(object):
"""
return self.copy()
def copy(self, share_memory=False, swap=None):
def copy(self, share_memory=False, swap=None, delete_updates=False):
"""
Copy this function. Copied function will have separated maker and
fgraph with original function. User can choose whether to separate
......@@ -562,23 +563,54 @@ class Function(object):
swap -- { dict } Dictionary<String, theano.SharedVariable> that
map old SharedVariable's name to new SharedVariable. Default is
None.
delete_updates -- { boolean } Default is False. If True, Copied
function will not have update.
---------------------
Returns:
func -- Copied theano.Function
"""
maker = self.maker
# copy Ins, so that they have different storage as their value
ins = copy.deepcopy(maker.inputs)
# copy fgraph and get memo
fg_cpy, memo = maker.fgraph.clone_get_equiv(attach_feature=False)
# Copy Ins, so that they have different storage as their value
ins, outs = copy.deepcopy([maker.inputs, maker.outputs])
# Delete update output in fgraph and updates In instancesis needed
if delete_updates:
# The first len(outs) variabels would be original variables.
# The rest variables in fgraph.outputs are the updates.
out_vars = maker.fgraph.outputs[:len(outs)]
else:
out_vars = maker.fgraph.outputs
# Init new fgraph using copied variables and get memo
# memo: a dict that map old variables to new variabls
memo = graph.clone_get_equiv(maker.fgraph.inputs, out_vars)
fg_cpy = gof.fg.FunctionGraph([memo[i] for i in maker.fgraph.inputs],
[memo[o] for o in out_vars],
clone=False)
# Swap varaible in Outs.
for i in xrange(len(outs)):
outs[i].variable = fg_cpy.outputs[i]
# Swap update and variable in Ins
# By doing this, we can pass FunctionMaker._check_unused_inputs()
update_i = len(outs)
for i, in_var in zip(ins, fg_cpy.inputs):
i.variable = in_var
if not delete_updates and i.update != None:
i.update = fg_cpy.outputs[update_i]
update_i += 1
else:
i.update = None
# swap SharedVariable if need
# swap SharedVariable
if swap != None:
self.__swapSV(swap, ins, fg_cpy, memo)
#used to prevent
self.__swapSV(swap, ins, fg_cpy)
# the name of SV we swapped
swapped_sv = swap.keys()
print (fg_cpy)
storage_map = self.fn.storage_map
new_storage_map = {}
# Construct new storage_map that map new variable to old storage,
......@@ -595,19 +627,9 @@ class Function(object):
if key not in i_o_vars:
new_storage_map[memo[key]] = storage_map[key]
input_storage = []
assert len(ins) == len(fg_cpy.inputs)
for in_ori, in_cpy, in_v in zip(maker.inputs, ins, fg_cpy.inputs):
# Since we reuse original Out instances, copied In instances
# should use the original variabls as their variables and
# updates. Otherwise the compilation will fail at function
# FunctionMaker._check_unused_inputs()
in_cpy.variable = in_ori.variable
in_cpy.update = in_ori.update
input_storage.append(in_cpy.value)
input_storage = [ i.value for i in ins ]
# reinitialize new maker and create new function
f_cpy = maker.__class__(inputs=ins, outputs=maker.outputs,
f_cpy = maker.__class__(inputs=ins, outputs=outs,
fgraph=fg_cpy,
mode=maker.mode, profile=maker.profile,
on_unused_input=maker.on_unused_input,
......@@ -620,19 +642,25 @@ class Function(object):
self.input_storage,
f_cpy.input_storage):
is_const = isinstance(in_ori.variable, theano.tensor.Constant)
# In instances' name default to vairables' name
swapped = swap != None and in_ori.name in swapped_sv
if (is_const or not in_ori.mutable) and not swapped:
cpy.data = ori.data
in_cpy.value = in_ori.value
# swap SharedVariable in In instances
if swapped and in_cpy.variable.name in swap.keys():
in_cpy.variable = swap[in_cpy.variable.name]
# Reconstruct Function.finder.
# Function.value and Function.data work
for ori, cpy in zip(maker.inputs, f_cpy.maker.inputs):
swapped = swap != None and ori.name in swapped_sv
if not swapped:
f_cpy.finder[ori.variable] = f_cpy.finder.pop(cpy.variable)
else:
f_cpy.finder[swap[ori.name]] = f_cpy.finder.pop(cpy.variable)
return f_cpy
def __swapSV(self, swap, ins, fg_cpy, memo):
def __swapSV(self, swap, ins, fg_cpy):
"""
Hidden auxiliary function, used to swap SharedVariable in maker.inputs
and maker.fgraph.
......@@ -641,7 +669,6 @@ class Function(object):
swap -- { dict } The same as docs in copy()
ins -- { list } List of In instances to be modified.
fg_cpy -- { FounctionGraph } Copied FunctionGraph.
memo -- { dict } Dict that map old variables to new variables.
--------------------
Returns:
None
......@@ -657,39 +684,30 @@ class Function(object):
"Type of given SharedVariable:", sv_rpl.type,
"Type of original SharedVariable:", sv_ori.type )
# Main code start here
exist_name = [ i.variable.name for i in ins ]
for name, sv in swap.iteritems():
assert type(name) is str, ("The SharedVariable need to be replaced"
"should be indicated by its name(string). Given:",type(name))
# check if SharedVariable exist
if name not in exist_name:
print ("WARNING: SharedVariable named: %s not found") % name
continue
# we don't use the originally defined variable
sv = sv.clone()
# replace SharedVariable in maker's In instances
for i in ins:
if i.variable.name == name:
checkSV( i.variable, sv)
i.variable, i.value = sv, sv.container
break
# modify fg_cpy.iputs
for i in xrange(len(fg_cpy.inputs)):
if fg_cpy.inputs[i].name == name:
fg_cpy.inputs[i] = sv
break
# replace SharedVariable in fgraph and memo
for var_ori, var_cpy in memo.iteritems():
if isinstance(var_ori, theano.compile.SharedVariable) and var_ori.name == name:
checkSV(var_cpy, sv)
fg_cpy.replace(var_cpy, sv)
# modify memo so that ori_var->this shared var
memo[var_ori] = sv
break
exist_names = [i.variable.name for i in ins]
swap_names = swap.keys()
# Check if given names exist
for name in swap_names:
if name not in exist_names:
warnings.warn( "Given name: %s wasn't found" % (name) )
# Swap SharedVairable in fgraph and ins
for index, (i, in_v) in enumerate(zip(ins, fg_cpy.inputs)):
assert i.variable.name == in_v.name
if in_v.name in swap_names:
# In the fgraph we use the cloned SharedVariable
swap_sv = swap[in_v.name].clone()
checkSV( in_v, swap_sv)
# Swap SharedVariable in fgraph
fg_cpy.inputs[index] = swap_sv
fg_cpy.replace( in_v, swap_sv, reason="Swap SV")
# swap variable and value of In instances
i.variable = swap_sv
i.value = swap_sv.container
def __call__(self, *args, **kwargs):
profile = self.profile
......@@ -1377,7 +1395,8 @@ class FunctionMaker(object):
else:
# fgraph is already an optimized one
need_opt = False
_, additional_outputs = std_fgraph(inputs, outputs, accept_inplace)
updates = [spec.update for spec in inputs if spec.update]
additional_outputs = map(SymbolicOutput, updates)
pass
self.fgraph = fgraph
......
......@@ -318,6 +318,24 @@ class T_function(unittest.TestCase):
elif key.name == 'z_rpl' and second_time:
assert cpy.fn.storage_map[key][0] == 8
def test_copy_delete_updates(self):
x = T.fscalar('x')
# SharedVariable for tests, one of them has update
y = theano.shared(value=1, name='y')
z = theano.shared(value=2, name='z')
out = x+y+z
# Test for different linkers
# for mode in ["FAST_RUN","FAST_COMPILE"]:
second_time = False
for mode in ["FAST_RUN","FAST_COMPILE"]:
ori = theano.function([x], out, mode=mode,updates={z:z*2})
cpy = ori.copy(delete_updates=True)
print cpy(1)
assert cpy(1)[0] == 4
assert cpy(1)[0] == 4
assert cpy(1)[0] == 4
def test_shared_state0(self):
a = T.scalar() # the a is for 'anonymous' (un-named).
x, s = T.scalars('xs')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论