提交 7b8c07ab authored 作者: carriepl's avatar carriepl

Add new tests for In wrapped

上级 388d68c1
......@@ -34,22 +34,35 @@ class TestFunctionIn(unittest.TestCase):
def test_in_strict(self):
a = theano.tensor.dvector()
b = theano.shared(7)
out = a + b
f = theano.function([In(a, strict=False)], out)
# works, rand generates float64 by default
f(numpy.random.rand(8))
# works, casting is allowed
a = theano.tensor.dvector()
b = theano.shared(7)
out = a + b
f = theano.function([In(a, strict=False)], out)
# works, rand generates float64 by default
f(numpy.random.rand(8))
# works, casting is allowed
f(numpy.array([1, 2, 3, 4], dtype='int32'))
f = theano.function([In(a, strict=True)], out)
try:
# fails, f expects float64
f(numpy.array([1, 2, 3, 4], dtype='int32'))
f = theano.function([In(a, strict=True)], out)
try:
# fails, f expects float64
f(numpy.array([1, 2, 3, 4], dtype='int32'))
except TypeError:
pass
except TypeError:
pass
def test_explicit_shared_input(self):
# This is not a test of the In class per se, but the In class relies
# on the fact that shared variables cannot be explicit inputs
a = theano.shared(1.0)
self.assertRaises(TypeError, theano.function, [a], a + 1)
def test_in_shared_variable(self):
# Ensure that an error is raised if the In wrapped is used to wrap
# a shared variable
a = theano.shared(1.0)
a_wrapped = In(a, update=a+1)
self.assertRaises(TypeError, theano.function, [a_wrapped])
def test_in_mutable(self):
a = theano.tensor.dvector()
......@@ -71,7 +84,8 @@ class TestFunctionIn(unittest.TestCase):
def test_in_update(self):
a = theano.tensor.dscalar('a')
f = theano.function([In(a, value=0.0, update=a+1)], a, mode='FAST_RUN')
f = theano.function([In(a, value=0.0, update=a + 1)], a,
mode='FAST_RUN')
# Ensure that, through the executions of the function, the state of the
# input is persistent and is updated as it should
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论