提交 d71a80f6 authored 作者: ChienliMa's avatar ChienliMa

add test for sharedvar

上级 51ff4318
......@@ -637,22 +637,22 @@ class Function(object):
# Variables in maker.inputs are defined by user, therefore we
# use them to make comparision and do the mapping.
# Otherwise we don't touch them.
swap_sv = maker.inputs[index].variable
var = maker.inputs[index].variable
if swap_sv in swap_svs_ori:
if var in swap_svs_ori:
swap_sv = swap[var]
checkSV(i.variable, swap_sv)
# swap variable and value of In instances
i.variable = swap_sv
# In the fgraph we use the cloned SharedVariable
swap_sv = swap_sv.clone()
# Swap SharedVariable in fgraph
fg_cpy.inputs[index] = swap_sv
# fg_cpy.inputs[index] = swap_sv
fg_cpy.replace(in_v, swap_sv, reason="Swap SV")
# swap variable and value of In instances
i.variable = swap_sv
i.value = swap_sv.container
# Construct new storage_map that map new variable to old storage,
# so that the ensuing function shares storage with the original one
storage_map = self.fn.storage_map
......
......@@ -284,7 +284,7 @@ class T_function(unittest.TestCase):
# SharedVariable to replace
y_rpl = theano.shared(value=3,name ='y_rpl')
z_rpl = theano.shared(value=4, name='z_rpl')
swap = {y:y_rpl, y:z_rpl}
swap = {y:y_rpl, z:z_rpl}
map_SV = {'y_rpl':y_rpl, 'z_rpl':z_rpl}
out = x+y+z+m
......@@ -301,9 +301,21 @@ class T_function(unittest.TestCase):
# assert same SharedVariable are update in different function
if not second_time:
# m should be updated 3 times
assert m.get_value() == 6
# z should be updated once
assert z.get_value() == 3
# z_rpl should be updated twice
assert z_rpl.get_value() == 6
# y and y_rpl should not be updated
assert y_rpl.get_value() == 3
assert y.get_value() == 1
elif second_time:
# doule update for sharedvariable
assert m.get_value() == 12
assert z.get_value() == 3
assert z_rpl.get_value() == 8
assert y_rpl.get_value() == 3
# test cpy function:
# 2. SharedVariable is updatable -> values did update(z == 5)
......@@ -313,10 +325,7 @@ class T_function(unittest.TestCase):
if key.name in names:
assert map_SV[key.name].container.storage[0] ==\
cpy.fn.storage_map[key][0]
if key.name == 'z_rpl' and not second_time:
assert cpy.fn.storage_map[key][0] == 6
elif key.name == 'z_rpl' and second_time:
assert cpy.fn.storage_map[key][0] == 8
second_time = True
def test_copy_delete_updates(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论