提交 253dde5c authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix compute_test_values test in float32, ensure test_value has the right type.

上级 b58395fc
......@@ -337,7 +337,8 @@ class PureOp(object):
elif isinstance(ins,SharedVariable):
input_vals.append(ins.get_value(borrow=True))
elif isinstance(ins,graph.Variable) and hasattr(ins.tag, 'test_value'):
input_vals.append(ins.tag.test_value)
# ensure that the test value is correct
input_vals.append(ins.type.filter(ins.tag.test_value))
else:
# no test-value was specified, act accordingly
if config.compute_test_value == 'warn':
......
......@@ -2,6 +2,7 @@ import numpy
import unittest
import theano
from theano import config
from theano import tensor as T
from theano.tensor.basic import _allclose
......@@ -14,9 +15,9 @@ class TestComputeTestValue(unittest.TestCase):
theano.config.compute_test_value = 'True'
x = T.matrix('x')
x.tag.test_value = numpy.random.rand(3,4)
x.tag.test_value = numpy.random.rand(3,4).astype(config.floatX)
y = T.matrix('y')
y.tag.test_value = numpy.random.rand(4,5)
y.tag.test_value = numpy.random.rand(4,5).astype(config.floatX)
# should work
z = T.dot(x,y)
......@@ -26,7 +27,7 @@ class TestComputeTestValue(unittest.TestCase):
z.tag.test_value)
# this test should fail
y.tag.test_value = numpy.random.rand(6,5)
y.tag.test_value = numpy.random.rand(6,5).astype(config.floatX)
self.assertRaises(ValueError, T.dot, x, y)
finally:
theano.config.compute_test_value = orig_compute_test_value
......@@ -37,7 +38,7 @@ class TestComputeTestValue(unittest.TestCase):
try:
x = T.matrix('x')
y = T.matrix('y')
y.tag.test_value = numpy.random.rand(4,5)
y.tag.test_value = numpy.random.rand(4,5).astype(config.floatX)
# should skip computation of test value
theano.config.compute_test_value = 'False'
......@@ -58,11 +59,11 @@ class TestComputeTestValue(unittest.TestCase):
theano.config.compute_test_value = 'True'
x = T.matrix('x')
x.tag.test_value = numpy.random.rand(3,4)
x.tag.test_value = numpy.random.rand(3,4).astype(config.floatX)
y = T.matrix('y')
y.tag.test_value = numpy.random.rand(4,5)
y.tag.test_value = numpy.random.rand(4,5).astype(config.floatX)
z = theano.shared(numpy.random.rand(5,6))
z = theano.shared(numpy.random.rand(5,6).astype(config.floatX))
# should work
out = T.dot(T.dot(x,y), z)
......@@ -76,7 +77,7 @@ class TestComputeTestValue(unittest.TestCase):
return T.dot(T.dot(x,y),z)
# this test should fail
z.set_value(numpy.random.rand(7,6))
z.set_value(numpy.random.rand(7,6).astype(config.floatX))
self.assertRaises(ValueError, f, x, y, z)
finally:
theano.config.compute_test_value = orig_compute_test_value
......@@ -87,8 +88,8 @@ class TestComputeTestValue(unittest.TestCase):
theano.config.compute_test_value = 'True'
x = T.matrix('x')
x.tag.test_value = numpy.random.rand(3,4)
y = theano.shared(numpy.random.rand(4,6), 'y')
x.tag.test_value = numpy.random.rand(3,4).astype(config.floatX)
y = theano.shared(numpy.random.rand(4,6).astype(config.floatX), 'y')
# should work
z = T.dot(x,y)
......@@ -97,7 +98,7 @@ class TestComputeTestValue(unittest.TestCase):
assert _allclose(f(x.tag.test_value), z.tag.test_value)
# this test should fail
y.set_value(numpy.random.rand(5,6))
y.set_value(numpy.random.rand(5,6).astype(config.floatX))
self.assertRaises(ValueError, T.dot, x, y)
finally:
theano.config.compute_test_value = orig_compute_test_value
......@@ -107,8 +108,8 @@ class TestComputeTestValue(unittest.TestCase):
try:
theano.config.compute_test_value = 'True'
x = numpy.random.rand(2,3)
y = theano.shared(numpy.random.rand(3,6), 'y')
x = numpy.random.rand(2,3).astype(config.floatX)
y = theano.shared(numpy.random.rand(3,6).astype(config.floatX), 'y')
# should work
z = T.dot(x,y)
......@@ -117,7 +118,7 @@ class TestComputeTestValue(unittest.TestCase):
assert _allclose(f(), z.tag.test_value)
# this test should fail
x = numpy.random.rand(2,4)
x = numpy.random.rand(2,4).astype(config.floatX)
self.assertRaises(ValueError, T.dot, x, y)
finally:
theano.config.compute_test_value = orig_compute_test_value
......@@ -127,8 +128,8 @@ class TestComputeTestValue(unittest.TestCase):
try:
theano.config.compute_test_value = 'True'
x = T.constant(numpy.random.rand(2,3))
y = theano.shared(numpy.random.rand(3,6), 'y')
x = T.constant(numpy.random.rand(2,3), dtype=config.floatX)
y = theano.shared(numpy.random.rand(3,6).astype(config.floatX), 'y')
# should work
z = T.dot(x,y)
......@@ -137,7 +138,22 @@ class TestComputeTestValue(unittest.TestCase):
assert _allclose(f(), z.tag.test_value)
# this test should fail
x = T.constant(numpy.random.rand(2,4))
x = T.constant(numpy.random.rand(2,4), dtype=config.floatX)
self.assertRaises(ValueError, T.dot, x, y)
finally:
theano.config.compute_test_value = orig_compute_test_value
def test_incorrect_type(self):
orig_compute_test_value = theano.config.compute_test_value
try:
theano.config.compute_test_value = 'True'
x = T.fmatrix('x')
# Incorrect dtype (float64) for test_value
x.tag.test_value = numpy.random.rand(3,4)
y = T.dmatrix('y')
y.tag.test_value = numpy.random.rand(4,5)
self.assertRaises(TypeError, T.dot, x, y)
finally:
theano.config.compute_test_value = orig_compute_test_value
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论