提交 a98e4311 authored 作者: James Bergstra's avatar James Bergstra

Modified test_basic and test_nnet to work with floatX=float32

上级 ef7193a9
...@@ -871,7 +871,7 @@ class T_subtensor(unittest.TestCase): ...@@ -871,7 +871,7 @@ class T_subtensor(unittest.TestCase):
print gval print gval
good = numpy.zeros_like(data) good = numpy.zeros_like(data)
good[1,0] = numpy.exp(data[1,0]) good[1,0] = numpy.exp(data[1,0])
self.failUnless(numpy.all(gval == good), (gval, good)) self.failUnless(numpy.allclose(gval, good), (gval, good))
class T_Join_and_Split(unittest.TestCase): class T_Join_and_Split(unittest.TestCase):
...@@ -1992,7 +1992,7 @@ class TestPermuteRowElements(unittest.TestCase): ...@@ -1992,7 +1992,7 @@ class TestPermuteRowElements(unittest.TestCase):
def test_1_1(self): def test_1_1(self):
"""Test PermuteRowElements(vector, vector)""" """Test PermuteRowElements(vector, vector)"""
input = vector() input = dvector()
p = ivector() p = ivector()
out = permute_row_elements(input, p) out = permute_row_elements(input, p)
permute = function([input, p], out) permute = function([input, p], out)
...@@ -2014,7 +2014,7 @@ class TestPermuteRowElements(unittest.TestCase): ...@@ -2014,7 +2014,7 @@ class TestPermuteRowElements(unittest.TestCase):
def test_2_1(self): def test_2_1(self):
"""Test broadcasting in PermuteRowElements(matrix, vector)""" """Test broadcasting in PermuteRowElements(matrix, vector)"""
input = matrix() input = dmatrix()
p = ivector() p = ivector()
out = permute_row_elements(input, p) out = permute_row_elements(input, p)
permute = function([input, p], out) permute = function([input, p], out)
...@@ -2036,7 +2036,7 @@ class TestPermuteRowElements(unittest.TestCase): ...@@ -2036,7 +2036,7 @@ class TestPermuteRowElements(unittest.TestCase):
def test_2_2(self): def test_2_2(self):
"""Test PermuteRowElements(matrix, matrix)""" """Test PermuteRowElements(matrix, matrix)"""
input = matrix() input = dmatrix()
p = imatrix() p = imatrix()
out = permute_row_elements(input, p) out = permute_row_elements(input, p)
permute = function([input, p], out) permute = function([input, p], out)
...@@ -2060,7 +2060,7 @@ class TestPermuteRowElements(unittest.TestCase): ...@@ -2060,7 +2060,7 @@ class TestPermuteRowElements(unittest.TestCase):
def test_1_2(self): def test_1_2(self):
"""Test PermuteRowElements(vector, matrix) """Test PermuteRowElements(vector, matrix)
Different permutations will be applied to the same input vector""" Different permutations will be applied to the same input vector"""
input = vector() input = dvector()
p = imatrix() p = imatrix()
out = permute_row_elements(input, p) out = permute_row_elements(input, p)
permute = function([input, p], out) permute = function([input, p], out)
...@@ -2228,7 +2228,7 @@ def test_sum_overflow(): ...@@ -2228,7 +2228,7 @@ def test_sum_overflow():
assert f([1]*300) == 300 assert f([1]*300) == 300
def test_default(): def test_default():
x, y = dscalars('xy') x, y = scalars('xy')
z = default(x, y) z = default(x, y)
f = function([x, y], z) f = function([x, y], z)
assert f(1, 2) == 1 assert f(1, 2) == 1
...@@ -2236,14 +2236,17 @@ def test_default(): ...@@ -2236,14 +2236,17 @@ def test_default():
assert f(1, None) == 1 assert f(1, None) == 1
def test_default_state(): def test_default_state():
x, y = dscalars('xy') x, y = scalars('xy')
print config.floatX
print x.type
print y.type
z = default(x, 3.8) z = default(x, 3.8)
new_x = y + z new_x = y + z
f = function([y, compile.In(x, update = new_x, value = 12.0)], new_x) f = function([y, compile.In(x, update = new_x, value = 12.0)], new_x)
assert f(3) == 15 assert f(3) == 15
f['x'] = None f['x'] = None
assert f(1) == 4.8 assert numpy.allclose(f(1), 4.8)
assert f(2.2) == 7 assert numpy.allclose(f(2.2), 7)
def test_autocast(): def test_autocast():
orig_autocast = autocast_float.dtypes orig_autocast = autocast_float.dtypes
......
...@@ -104,7 +104,6 @@ class T_prepend(unittest.TestCase): ...@@ -104,7 +104,6 @@ class T_prepend(unittest.TestCase):
f=theano.function([x],y) f=theano.function([x],y)
m=numpy.ones((3,5),dtype="float32") m=numpy.ones((3,5),dtype="float32")
my = f(m) my = f(m)
self.failUnless(str(my.dtype) == 'float64')
self.failUnless(my.shape == (3, 6)) self.failUnless(my.shape == (3, 6))
self.failUnless(numpy.all(my[:,0] == 5.0)) self.failUnless(numpy.all(my[:,0] == 5.0))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论