提交 d71a80f6 authored 作者: ChienliMa's avatar ChienliMa

add test for sharedvar

上级 51ff4318
...@@ -637,22 +637,22 @@ class Function(object): ...@@ -637,22 +637,22 @@ class Function(object):
# Variables in maker.inputs are defined by user, therefore we # Variables in maker.inputs are defined by user, therefore we
# use them to make comparision and do the mapping. # use them to make comparision and do the mapping.
# Otherwise we don't touch them. # Otherwise we don't touch them.
swap_sv = maker.inputs[index].variable var = maker.inputs[index].variable
if swap_sv in swap_svs_ori: if var in swap_svs_ori:
swap_sv = swap[var]
checkSV(i.variable, swap_sv) checkSV(i.variable, swap_sv)
# swap variable and value of In instances
i.variable = swap_sv
# In the fgraph we use the cloned SharedVariable # In the fgraph we use the cloned SharedVariable
swap_sv = swap_sv.clone() swap_sv = swap_sv.clone()
# Swap SharedVariable in fgraph # Swap SharedVariable in fgraph
fg_cpy.inputs[index] = swap_sv # fg_cpy.inputs[index] = swap_sv
fg_cpy.replace(in_v, swap_sv, reason="Swap SV") fg_cpy.replace(in_v, swap_sv, reason="Swap SV")
# swap variable and value of In instances
i.variable = swap_sv
i.value = swap_sv.container
# Construct new storage_map that map new variable to old storage, # Construct new storage_map that map new variable to old storage,
# so that the ensuing function shares storage with the original one # so that the ensuing function shares storage with the original one
storage_map = self.fn.storage_map storage_map = self.fn.storage_map
......
...@@ -284,7 +284,7 @@ class T_function(unittest.TestCase): ...@@ -284,7 +284,7 @@ class T_function(unittest.TestCase):
# SharedVariable to replace # SharedVariable to replace
y_rpl = theano.shared(value=3,name ='y_rpl') y_rpl = theano.shared(value=3,name ='y_rpl')
z_rpl = theano.shared(value=4, name='z_rpl') z_rpl = theano.shared(value=4, name='z_rpl')
swap = {y:y_rpl, y:z_rpl} swap = {y:y_rpl, z:z_rpl}
map_SV = {'y_rpl':y_rpl, 'z_rpl':z_rpl} map_SV = {'y_rpl':y_rpl, 'z_rpl':z_rpl}
out = x+y+z+m out = x+y+z+m
...@@ -301,9 +301,21 @@ class T_function(unittest.TestCase): ...@@ -301,9 +301,21 @@ class T_function(unittest.TestCase):
# assert same SharedVariable are update in different function # assert same SharedVariable are update in different function
if not second_time: if not second_time:
# m should be updated 3 times
assert m.get_value() == 6 assert m.get_value() == 6
# z should be updated once
assert z.get_value() == 3
# z_rpl should be updated twice
assert z_rpl.get_value() == 6
# y and y_rpl should not be updated
assert y_rpl.get_value() == 3
assert y.get_value() == 1
elif second_time: elif second_time:
# doule update for sharedvariable
assert m.get_value() == 12 assert m.get_value() == 12
assert z.get_value() == 3
assert z_rpl.get_value() == 8
assert y_rpl.get_value() == 3
# test cpy function: # test cpy function:
# 2. SharedVariable is updatable -> values did update(z == 5) # 2. SharedVariable is updatable -> values did update(z == 5)
...@@ -313,10 +325,7 @@ class T_function(unittest.TestCase): ...@@ -313,10 +325,7 @@ class T_function(unittest.TestCase):
if key.name in names: if key.name in names:
assert map_SV[key.name].container.storage[0] ==\ assert map_SV[key.name].container.storage[0] ==\
cpy.fn.storage_map[key][0] cpy.fn.storage_map[key][0]
if key.name == 'z_rpl' and not second_time:
assert cpy.fn.storage_map[key][0] == 6
elif key.name == 'z_rpl' and second_time:
assert cpy.fn.storage_map[key][0] == 8
second_time = True second_time = True
def test_copy_delete_updates(self): def test_copy_delete_updates(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论