提交 a98e4311 authored 作者: James Bergstra's avatar James Bergstra

Modified test_basic and test_nnet to work with floatX=float32

上级 ef7193a9
......@@ -871,7 +871,7 @@ class T_subtensor(unittest.TestCase):
print gval
good = numpy.zeros_like(data)
good[1,0] = numpy.exp(data[1,0])
self.failUnless(numpy.all(gval == good), (gval, good))
self.failUnless(numpy.allclose(gval, good), (gval, good))
class T_Join_and_Split(unittest.TestCase):
......@@ -1992,7 +1992,7 @@ class TestPermuteRowElements(unittest.TestCase):
def test_1_1(self):
"""Test PermuteRowElements(vector, vector)"""
input = vector()
input = dvector()
p = ivector()
out = permute_row_elements(input, p)
permute = function([input, p], out)
......@@ -2014,7 +2014,7 @@ class TestPermuteRowElements(unittest.TestCase):
def test_2_1(self):
"""Test broadcasting in PermuteRowElements(matrix, vector)"""
input = matrix()
input = dmatrix()
p = ivector()
out = permute_row_elements(input, p)
permute = function([input, p], out)
......@@ -2036,7 +2036,7 @@ class TestPermuteRowElements(unittest.TestCase):
def test_2_2(self):
"""Test PermuteRowElements(matrix, matrix)"""
input = matrix()
input = dmatrix()
p = imatrix()
out = permute_row_elements(input, p)
permute = function([input, p], out)
......@@ -2060,7 +2060,7 @@ class TestPermuteRowElements(unittest.TestCase):
def test_1_2(self):
"""Test PermuteRowElements(vector, matrix)
Different permutations will be applied to the same input vector"""
input = vector()
input = dvector()
p = imatrix()
out = permute_row_elements(input, p)
permute = function([input, p], out)
......@@ -2228,7 +2228,7 @@ def test_sum_overflow():
assert f([1]*300) == 300
def test_default():
x, y = dscalars('xy')
x, y = scalars('xy')
z = default(x, y)
f = function([x, y], z)
assert f(1, 2) == 1
......@@ -2236,14 +2236,17 @@ def test_default():
assert f(1, None) == 1
def test_default_state():
x, y = dscalars('xy')
x, y = scalars('xy')
print config.floatX
print x.type
print y.type
z = default(x, 3.8)
new_x = y + z
f = function([y, compile.In(x, update = new_x, value = 12.0)], new_x)
assert f(3) == 15
f['x'] = None
assert f(1) == 4.8
assert f(2.2) == 7
assert numpy.allclose(f(1), 4.8)
assert numpy.allclose(f(2.2), 7)
def test_autocast():
orig_autocast = autocast_float.dtypes
......
......@@ -104,7 +104,6 @@ class T_prepend(unittest.TestCase):
f=theano.function([x],y)
m=numpy.ones((3,5),dtype="float32")
my = f(m)
self.failUnless(str(my.dtype) == 'float64')
self.failUnless(my.shape == (3, 6))
self.failUnless(numpy.all(my[:,0] == 5.0))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论