提交 e2a47f19 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Tests for the new default_update mechanism of shared variables.

上级 24dc3617
......@@ -227,10 +227,220 @@ class Test_pfunc(unittest.TestCase):
assert numpy.all(y.value==24)
assert numpy.all(z.value==24**2)
def test_default_updates(self):
x = shared(0)
x.default_update = x+1
f = pfunc([], [x])
f()
print x.value
assert x.value == 1
del x.default_update
f()
assert x.value == 2
g = pfunc([], [x])
g()
assert x.value == 2
def test_no_default_updates(self):
x = shared(0)
y = shared(1)
x.default_update = x+2
# Test that the default update is taken into account in the right cases
f1 = pfunc([], [x], no_default_updates=True)
f1()
print x.value
assert x.value == 0
f2 = pfunc([], [x], no_default_updates=[x])
f2()
print x.value
assert x.value == 0
f3 = pfunc([], [x], no_default_updates=[x, y])
f3()
print x.value
assert x.value == 0
f4 = pfunc([], [x], no_default_updates=[y])
f4()
print x.value
assert x.value == 2
f5 = pfunc([], [x], no_default_updates=[])
f5()
print x.value
assert x.value == 4
f5 = pfunc([], [x], no_default_updates=False)
f5()
print x.value
assert x.value == 6
self.failUnlessRaises(TypeError, pfunc, [], [x], no_default_updates=(x))
self.failUnlessRaises(TypeError, pfunc, [], [x], no_default_updates=x)
self.failUnlessRaises(TypeError, pfunc, [], [x], no_default_updates='canard')
# Mix explicit updates and no_default_updates
g1 = pfunc([], [x], updates=[(x,x-1)], no_default_updates=True)
g1()
print x.value
assert x.value == 5
g2 = pfunc([], [x], updates=[(x,x-1)], no_default_updates=[x])
g2()
print x.value
assert x.value == 4
g3 = pfunc([], [x], updates=[(x,x-1)], no_default_updates=[x, y])
g3()
print x.value
assert x.value == 3
g4 = pfunc([], [x], updates=[(x,x-1)], no_default_updates=[y])
g4()
print x.value
assert x.value == 2
g5 = pfunc([], [x], updates=[(x,x-1)], no_default_updates=[])
g5()
print x.value
assert x.value == 1
g5 = pfunc([], [x], updates=[(x,x-1)], no_default_updates=False)
g5()
print x.value
assert x.value == 0
def test_default_updates_expressions(self):
x = shared(0)
y = shared(1)
a = lscalar('a')
z = a*x
x.default_update = x+y
f1 = pfunc([a], z)
f1(12)
print x
assert x.value == 1
f2 = pfunc([a], z, no_default_updates=True)
assert f2(7) == 7
print x
assert x.value == 1
f3 = pfunc([a], z, no_default_updates=[x])
assert f3(9) == 9
print x
assert x.value == 1
def test_default_updates_multiple(self):
x = shared(0)
y = shared(1)
x.default_update = x-1
y.default_update = y+1
f1 = pfunc([], [x,y])
f1()
assert x.value == -1
assert y.value == 2
f2 = pfunc([], [x,y], updates=[(x, x-2)], no_default_updates=[y])
f2()
assert x.value == -3
assert y.value == 2
f3 = pfunc([], [x,y], updates=[(x, x-2)], no_default_updates=True)
f3()
assert x.value == -5
assert y.value == 2
f4 = pfunc([], [x,y], updates=[(y, y-2)])
f4()
assert x.value == -6
assert y.value == 0
def test_default_updates_chained(self):
x = shared(2)
y = shared(1)
z = shared(-1)
x.default_update = x-y
y.default_update = z
z.default_update = z-1
f1 = pfunc([], [x])
f1()
print x.value, y.value, z.value
assert x.value == 1
assert y.value == -1
assert z.value == -2
f2 = pfunc([], [x, y])
f2()
assert x.value == 2
assert y.value == -2
assert z.value == -3
f3 = pfunc([], [y])
f3()
assert x.value == 2
assert y.value == -3
assert z.value == -4
f4 = pfunc([], [x,y], no_default_updates=[x])
f4()
assert x.value == 2
assert y.value == -4
assert z.value == -5
f5 = pfunc([], [x,y,z], no_default_updates=[z])
f5()
assert x.value == 6
assert y.value == -5
assert z.value == -5
def test_default_updates_input(self):
x = shared(0)
y = shared(1)
a = lscalar('a')
x.default_update = y
y.default_update = y+a
f1 = pfunc([], x, no_default_updates=True)
f1()
assert x.value == 0
assert y.value == 1
f2 = pfunc([], x, no_default_updates=[x])
f2()
assert x.value == 0
assert y.value == 1
f3 = pfunc([], x, no_default_updates=[y])
f3()
assert x.value == 1
assert y.value == 1
f4 = pfunc([a], x)
f4(2)
assert x.value == 1
assert y.value == 3
f5 = pfunc([], x, updates={y:y-1})
f5()
assert x.value == 3
assert y.value == 2
# a is needed as input if y.default_update is used
self.failUnlessRaises(TypeError, pfunc, [], x)
if __name__ == '__main__':
theano.compile.mode.default_mode = 'FAST_COMPILE'
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论