提交 3055ab4d authored 作者: ChienliMa's avatar ChienliMa

finish swapSV and test

上级 af1fb499
......@@ -575,7 +575,7 @@ class Function(object):
# swap SharedVariable if need
if swap != None:
self._swapSV(swap, ins, fg_cpy, memo)
self.__swapSV(swap, ins, fg_cpy, memo)
#used to prevent
swapped_sv = swap.keys()
......@@ -620,17 +620,19 @@ class Function(object):
self.input_storage,
f_cpy.input_storage):
is_const = isinstance(in_ori.variable, theano.tensor.Constant)
if is_const or not in_ori.mutable:
swapped = swap != None and in_ori.name in swapped_sv
if (is_const or not in_ori.mutable) and not swapped:
cpy.data = ori.data
in_cpy.value = in_ori.value
# swap SharedVariable in In instances
if in_cpy.variable.name in swap.keys():
if swapped and in_cpy.variable.name in swap.keys():
in_cpy.variable = swap[in_cpy.variable.name]
return f_cpy
def _swapSV(self, swap, ins, fg_cpy, memo):
def __swapSV(self, swap, ins, fg_cpy, memo):
"""
Hidden auxiliary function, used to swap SharedVariable in maker.inputs
and maker.fgraph.
......@@ -667,7 +669,6 @@ class Function(object):
# we don't use the originally defined variable
sv = sv.clone()
# replace SharedVariable in maker's In instances
for i in ins:
if i.variable.name == name:
......@@ -675,16 +676,16 @@ class Function(object):
i.variable, i.value = sv, sv.container
break
# modify fg_cpy.iputs
for i in xrange(len(fg_cpy.inputs)):
if fg_cpy.inputs[i].name == name:
fg_cpy.inputs[i] = sv
break
# replace SharedVariable in fgraph and memo
for var_ori, var_cpy in memo.iteritems():
if isinstance(var_ori, theano.compile.SharedVariable) and var_ori.name == name:
checkSV(var_cpy, sv)
# replace variable in fgraph
# modify fg_cpy.iputs
for i in xrange(len(fg_cpy.inputs)):
if fg_cpy.inputs[i].name == name:
fg_cpy.inputs[i] = sv
break
fg_cpy.replace(var_cpy, sv)
# modify memo so that ori_var->this shared var
memo[var_ori] = sv
......
......@@ -287,21 +287,36 @@ class T_function(unittest.TestCase):
z = theano.shared(value=2, name='z')
# SharedVariable to replace
y_rpl = theano.shared(value=3)
z_rpl = theano.shared(value=4)
y_rpl = theano.shared(value=3,name ='y_rpl')
z_rpl = theano.shared(value=4, name='z_rpl')
swap = {'y':y_rpl, 'z':z_rpl}
map_SV = {'y_rpl':y_rpl, 'z_rpl':z_rpl}
out = T.tanh((x+y+2)/(x+z-0.2)**2)
out = x+y+z
# Test for different linkers
# for mode in ["FAST_RUN","FAST_COMPILE"]:
second_time = False
for mode in ["FAST_RUN","FAST_COMPILE"]:
ori = theano.function([x], [out], mode=mode,updates={z:z+1})
cpy = ori.copy(swap={'y':y_rpl, 'z':z_rpl})
# test what:
# 1. is swapped( value equal, use = or is? )
# 2. is updatable( run several time and check value)
# 3. is/isn't separated in two function
# 4. consistence in In and Variable
cpy = ori.copy(swap=swap)
# run fuction several time
ori(1), cpy(1),cpy(2)
# test cpy function:
# 2. SharedVariable is updatable -> values did update(z == 5)
# 1. sharedvariable is swap -> Rpl sharedvariables share storage
names = map_SV.keys()
for key in cpy.fn.storage_map:
if key.name in names:
assert map_SV[key.name].container.storage[0] ==\
cpy.fn.storage_map[key][0]
if key.name == 'z_rpl' and not second_time:
assert cpy.fn.storage_map[key][0] == 6
second_time = True
elif key.name == 'z_rpl' and second_time:
assert cpy.fn.storage_map[key][0] == 8
def test_shared_state0(self):
a = T.scalar() # the a is for 'anonymous' (un-named).
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论