提交 3055ab4d authored 作者: ChienliMa's avatar ChienliMa

finish swapSV and test

上级 af1fb499
...@@ -575,7 +575,7 @@ class Function(object): ...@@ -575,7 +575,7 @@ class Function(object):
# swap SharedVariable if need # swap SharedVariable if need
if swap != None: if swap != None:
self._swapSV(swap, ins, fg_cpy, memo) self.__swapSV(swap, ins, fg_cpy, memo)
#used to prevent #used to prevent
swapped_sv = swap.keys() swapped_sv = swap.keys()
...@@ -620,17 +620,19 @@ class Function(object): ...@@ -620,17 +620,19 @@ class Function(object):
self.input_storage, self.input_storage,
f_cpy.input_storage): f_cpy.input_storage):
is_const = isinstance(in_ori.variable, theano.tensor.Constant) is_const = isinstance(in_ori.variable, theano.tensor.Constant)
if is_const or not in_ori.mutable: swapped = swap != None and in_ori.name in swapped_sv
if (is_const or not in_ori.mutable) and not swapped:
cpy.data = ori.data cpy.data = ori.data
in_cpy.value = in_ori.value in_cpy.value = in_ori.value
# swap SharedVariable in In instances # swap SharedVariable in In instances
if in_cpy.variable.name in swap.keys(): if swapped and in_cpy.variable.name in swap.keys():
in_cpy.variable = swap[in_cpy.variable.name] in_cpy.variable = swap[in_cpy.variable.name]
return f_cpy return f_cpy
def _swapSV(self, swap, ins, fg_cpy, memo): def __swapSV(self, swap, ins, fg_cpy, memo):
""" """
Hidden auxiliary function, used to swap SharedVariable in maker.inputs Hidden auxiliary function, used to swap SharedVariable in maker.inputs
and maker.fgraph. and maker.fgraph.
...@@ -667,7 +669,6 @@ class Function(object): ...@@ -667,7 +669,6 @@ class Function(object):
# we don't use the originally defined variable # we don't use the originally defined variable
sv = sv.clone() sv = sv.clone()
# replace SharedVariable in maker's In instances # replace SharedVariable in maker's In instances
for i in ins: for i in ins:
if i.variable.name == name: if i.variable.name == name:
...@@ -675,16 +676,16 @@ class Function(object): ...@@ -675,16 +676,16 @@ class Function(object):
i.variable, i.value = sv, sv.container i.variable, i.value = sv, sv.container
break break
# replace SharedVariable in fgraph and memo
for var_ori, var_cpy in memo.iteritems():
if isinstance(var_ori, theano.compile.SharedVariable) and var_ori.name == name:
checkSV(var_cpy, sv)
# replace variable in fgraph
# modify fg_cpy.iputs # modify fg_cpy.iputs
for i in xrange(len(fg_cpy.inputs)): for i in xrange(len(fg_cpy.inputs)):
if fg_cpy.inputs[i].name == name: if fg_cpy.inputs[i].name == name:
fg_cpy.inputs[i] = sv fg_cpy.inputs[i] = sv
break break
# replace SharedVariable in fgraph and memo
for var_ori, var_cpy in memo.iteritems():
if isinstance(var_ori, theano.compile.SharedVariable) and var_ori.name == name:
checkSV(var_cpy, sv)
fg_cpy.replace(var_cpy, sv) fg_cpy.replace(var_cpy, sv)
# modify memo so that ori_var->this shared var # modify memo so that ori_var->this shared var
memo[var_ori] = sv memo[var_ori] = sv
......
...@@ -287,21 +287,36 @@ class T_function(unittest.TestCase): ...@@ -287,21 +287,36 @@ class T_function(unittest.TestCase):
z = theano.shared(value=2, name='z') z = theano.shared(value=2, name='z')
# SharedVariable to replace # SharedVariable to replace
y_rpl = theano.shared(value=3) y_rpl = theano.shared(value=3,name ='y_rpl')
z_rpl = theano.shared(value=4) z_rpl = theano.shared(value=4, name='z_rpl')
swap = {'y':y_rpl, 'z':z_rpl}
map_SV = {'y_rpl':y_rpl, 'z_rpl':z_rpl}
out = T.tanh((x+y+2)/(x+z-0.2)**2) out = x+y+z
# Test for different linkers # Test for different linkers
# for mode in ["FAST_RUN","FAST_COMPILE"]:
second_time = False
for mode in ["FAST_RUN","FAST_COMPILE"]: for mode in ["FAST_RUN","FAST_COMPILE"]:
ori = theano.function([x], [out], mode=mode,updates={z:z+1}) ori = theano.function([x], [out], mode=mode,updates={z:z+1})
cpy = ori.copy(swap={'y':y_rpl, 'z':z_rpl}) cpy = ori.copy(swap=swap)
# test what: # run fuction several time
# 1. is swapped( value equal, use = or is? ) ori(1), cpy(1),cpy(2)
# 2. is updatable( run several time and check value)
# 3. is/isn't separated in two function # test cpy function:
# 4. consistence in In and Variable # 2. SharedVariable is updatable -> values did update(z == 5)
# 1. sharedvariable is swap -> Rpl sharedvariables share storage
names = map_SV.keys()
for key in cpy.fn.storage_map:
if key.name in names:
assert map_SV[key.name].container.storage[0] ==\
cpy.fn.storage_map[key][0]
if key.name == 'z_rpl' and not second_time:
assert cpy.fn.storage_map[key][0] == 6
second_time = True
elif key.name == 'z_rpl' and second_time:
assert cpy.fn.storage_map[key][0] == 8
def test_shared_state0(self): def test_shared_state0(self):
a = T.scalar() # the a is for 'anonymous' (un-named). a = T.scalar() # the a is for 'anonymous' (un-named).
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论