提交 809e6adb authored 作者: ChienliMa's avatar ChienliMa

modify test to assert SharedVariables are shared

上级 35677c96
......@@ -273,6 +273,7 @@ class T_function(unittest.TestCase):
# SharedVariable for tests, one of them has update
y = theano.shared(value=1, name='y')
z = theano.shared(value=2, name='z')
m = theano.shared(value=0, name='m')
# SharedVariable to replace
y_rpl = theano.shared(value=3,name ='y_rpl')
......@@ -280,18 +281,24 @@ class T_function(unittest.TestCase):
swap = {y:y_rpl, y:z_rpl}
map_SV = {'y_rpl':y_rpl, 'z_rpl':z_rpl}
out = x+y+z
out = x+y+z+m
# Test for different linkers
# for mode in ["FAST_RUN","FAST_COMPILE"]:
second_time = False
for mode in ["FAST_RUN","FAST_COMPILE"]:
ori = theano.function([x], [out], mode=mode,updates={z:z+1})
ori = theano.function([x], [out], mode=mode,updates=[(z,z+1),(m,m+2)])
cpy = ori.copy(swap=swap)
# run fuction several time
ori(1), cpy(1),cpy(2)
# assert same SharedVariable are update in different function
if not second_time:
assert m.get_value() == 6
elif second_time:
assert m.get_value() == 12
# test cpy function:
# 2. SharedVariable is updatable -> values did update(z == 5)
# 1. sharedvariable is swap -> Rpl sharedvariables share storage
......@@ -302,9 +309,9 @@ class T_function(unittest.TestCase):
cpy.fn.storage_map[key][0]
if key.name == 'z_rpl' and not second_time:
assert cpy.fn.storage_map[key][0] == 6
second_time = True
elif key.name == 'z_rpl' and second_time:
assert cpy.fn.storage_map[key][0] == 8
second_time = True
def test_copy_delete_updates(self):
x = T.fscalar('x')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论