提交 cd15ea08 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

More thorough tests for compute_test_value.

上级 5c907045
...@@ -3,6 +3,7 @@ import unittest ...@@ -3,6 +3,7 @@ import unittest
import theano import theano
from theano import tensor as T from theano import tensor as T
from theano.tensor.basic import _allclose
class TestComputeTestValue(unittest.TestCase): class TestComputeTestValue(unittest.TestCase):
...@@ -19,6 +20,10 @@ class TestComputeTestValue(unittest.TestCase): ...@@ -19,6 +20,10 @@ class TestComputeTestValue(unittest.TestCase):
# should work # should work
z = T.dot(x,y) z = T.dot(x,y)
assert hasattr(z.tag, 'test_value')
f = theano.function([x,y], z)
assert _allclose(f(x.tag.test_value, y.tag.test_value),
z.tag.test_value)
# this test should fail # this test should fail
y.tag.test_value = numpy.random.rand(6,5) y.tag.test_value = numpy.random.rand(6,5)
...@@ -37,6 +42,7 @@ class TestComputeTestValue(unittest.TestCase): ...@@ -37,6 +42,7 @@ class TestComputeTestValue(unittest.TestCase):
# should skip computation of test value # should skip computation of test value
theano.config.compute_test_value = False theano.config.compute_test_value = False
z = T.dot(x,y) z = T.dot(x,y)
assert not hasattr(z.tag, 'test_value')
# should fail one or another when flag is set # should fail one or another when flag is set
theano.config.compute_test_value = 'warn' theano.config.compute_test_value = 'warn'
...@@ -60,6 +66,12 @@ class TestComputeTestValue(unittest.TestCase): ...@@ -60,6 +66,12 @@ class TestComputeTestValue(unittest.TestCase):
# should work # should work
out = T.dot(T.dot(x,y), z) out = T.dot(T.dot(x,y), z)
assert hasattr(out.tag, 'test_value')
tf = theano.function([x,y], out)
assert _allclose(
tf(x.tag.test_value, y.tag.test_value),
out.tag.test_value)
def f(x,y,z): def f(x,y,z):
return T.dot(T.dot(x,y),z) return T.dot(T.dot(x,y),z)
...@@ -80,6 +92,9 @@ class TestComputeTestValue(unittest.TestCase): ...@@ -80,6 +92,9 @@ class TestComputeTestValue(unittest.TestCase):
# should work # should work
z = T.dot(x,y) z = T.dot(x,y)
assert hasattr(z.tag, 'test_value')
f = theano.function([x], z)
assert _allclose(f(x.tag.test_value), z.tag.test_value)
# this test should fail # this test should fail
y.set_value(numpy.random.rand(5,6)) y.set_value(numpy.random.rand(5,6))
...@@ -97,6 +112,9 @@ class TestComputeTestValue(unittest.TestCase): ...@@ -97,6 +112,9 @@ class TestComputeTestValue(unittest.TestCase):
# should work # should work
z = T.dot(x,y) z = T.dot(x,y)
assert hasattr(z.tag, 'test_value')
f = theano.function([], z)
assert _allclose(f(), z.tag.test_value)
# this test should fail # this test should fail
x = numpy.random.rand(2,4) x = numpy.random.rand(2,4)
...@@ -114,6 +132,9 @@ class TestComputeTestValue(unittest.TestCase): ...@@ -114,6 +132,9 @@ class TestComputeTestValue(unittest.TestCase):
# should work # should work
z = T.dot(x,y) z = T.dot(x,y)
assert hasattr(z.tag, 'test_value')
f = theano.function([], z)
assert _allclose(f(), z.tag.test_value)
# this test should fail # this test should fail
x = T.constant(numpy.random.rand(2,4)) x = T.constant(numpy.random.rand(2,4))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论