提交 7b8c07ab authored 作者: carriepl's avatar carriepl

Add new tests for In wrapped

上级 388d68c1
...@@ -34,22 +34,35 @@ class TestFunctionIn(unittest.TestCase): ...@@ -34,22 +34,35 @@ class TestFunctionIn(unittest.TestCase):
def test_in_strict(self): def test_in_strict(self):
a = theano.tensor.dvector() a = theano.tensor.dvector()
b = theano.shared(7) b = theano.shared(7)
out = a + b out = a + b
f = theano.function([In(a, strict=False)], out) f = theano.function([In(a, strict=False)], out)
# works, rand generates float64 by default # works, rand generates float64 by default
f(numpy.random.rand(8)) f(numpy.random.rand(8))
# works, casting is allowed # works, casting is allowed
f(numpy.array([1, 2, 3, 4], dtype='int32'))
f = theano.function([In(a, strict=True)], out)
try:
# fails, f expects float64
f(numpy.array([1, 2, 3, 4], dtype='int32')) f(numpy.array([1, 2, 3, 4], dtype='int32'))
except TypeError:
f = theano.function([In(a, strict=True)], out) pass
try:
# fails, f expects float64 def test_explicit_shared_input(self):
f(numpy.array([1, 2, 3, 4], dtype='int32')) # This is not a test of the In class per se, but the In class relies
except TypeError: # on the fact that shared variables cannot be explicit inputs
pass a = theano.shared(1.0)
self.assertRaises(TypeError, theano.function, [a], a + 1)
def test_in_shared_variable(self):
# Ensure that an error is raised if the In wrapped is used to wrap
# a shared variable
a = theano.shared(1.0)
a_wrapped = In(a, update=a+1)
self.assertRaises(TypeError, theano.function, [a_wrapped])
def test_in_mutable(self): def test_in_mutable(self):
a = theano.tensor.dvector() a = theano.tensor.dvector()
...@@ -71,7 +84,8 @@ class TestFunctionIn(unittest.TestCase): ...@@ -71,7 +84,8 @@ class TestFunctionIn(unittest.TestCase):
def test_in_update(self): def test_in_update(self):
a = theano.tensor.dscalar('a') a = theano.tensor.dscalar('a')
f = theano.function([In(a, value=0.0, update=a+1)], a, mode='FAST_RUN') f = theano.function([In(a, value=0.0, update=a + 1)], a,
mode='FAST_RUN')
# Ensure that, through the executions of the function, the state of the # Ensure that, through the executions of the function, the state of the
# input is persistent and is updated as it should # input is persistent and is updated as it should
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论