提交 7b8c07ab authored 作者: carriepl's avatar carriepl

Add new tests for In wrapped

上级 388d68c1
...@@ -51,6 +51,19 @@ class TestFunctionIn(unittest.TestCase): ...@@ -51,6 +51,19 @@ class TestFunctionIn(unittest.TestCase):
except TypeError: except TypeError:
pass pass
def test_explicit_shared_input(self):
# This is not a test of the In class per se, but the In class relies
# on the fact that shared variables cannot be explicit inputs
a = theano.shared(1.0)
self.assertRaises(TypeError, theano.function, [a], a + 1)
def test_in_shared_variable(self):
# Ensure that an error is raised if the In wrapped is used to wrap
# a shared variable
a = theano.shared(1.0)
a_wrapped = In(a, update=a+1)
self.assertRaises(TypeError, theano.function, [a_wrapped])
def test_in_mutable(self): def test_in_mutable(self):
a = theano.tensor.dvector() a = theano.tensor.dvector()
a_out = a * 2 # assuming the op which makes this "in place" triggers a_out = a * 2 # assuming the op which makes this "in place" triggers
...@@ -71,7 +84,8 @@ class TestFunctionIn(unittest.TestCase): ...@@ -71,7 +84,8 @@ class TestFunctionIn(unittest.TestCase):
def test_in_update(self): def test_in_update(self):
a = theano.tensor.dscalar('a') a = theano.tensor.dscalar('a')
f = theano.function([In(a, value=0.0, update=a+1)], a, mode='FAST_RUN') f = theano.function([In(a, value=0.0, update=a + 1)], a,
mode='FAST_RUN')
# Ensure that, through the executions of the function, the state of the # Ensure that, through the executions of the function, the state of the
# input is persistent and is updated as it should # input is persistent and is updated as it should
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论