提交 cc3fe3ca authored 作者: Frederic Bastien's avatar Frederic Bastien

added test for the new autocast for numpy.int{8,16,32} type.

上级 a4646f2c
......@@ -117,6 +117,18 @@ class Test_SharedVariable(unittest.TestCase):
assert b.type == theano.tensor.lscalar
self.failUnlessRaises(TypeError, f, b, 8.23)
b = shared(numpy.int32(7), strict=True)
assert b.type == theano.tensor.iscalar
self.failUnlessRaises(TypeError, f, b, 8.23)
b = shared(numpy.int16(7), strict=True)
assert b.type == theano.tensor.wscalar
self.failUnlessRaises(TypeError, f, b, 8.23)
b = shared(numpy.int8(7), strict=True)
assert b.type == theano.tensor.bscalar
self.failUnlessRaises(TypeError, f, b, 8.23)
b = shared(numpy.float64(7.234), strict=True)
assert b.type == theano.tensor.dscalar
self.failUnlessRaises(TypeError, f, b, 8)
......@@ -145,6 +157,18 @@ class Test_SharedVariable(unittest.TestCase):
assert b.type == theano.tensor.lvector
self.failUnlessRaises(TypeError, f, b, 8.23)
b = shared(numpy.int32([7]), strict=True)
assert b.type == theano.tensor.ivector
self.failUnlessRaises(TypeError, f, b, 8.23)
b = shared(numpy.int16([7]), strict=True)
assert b.type == theano.tensor.wvector
self.failUnlessRaises(TypeError, f, b, 8.23)
b = shared(numpy.int8([7]), strict=True)
assert b.type == theano.tensor.bvector
self.failUnlessRaises(TypeError, f, b, 8.23)
b = shared(numpy.float64([7.234]), strict=True)
assert b.type == theano.tensor.dvector
self.failUnlessRaises(TypeError, f, b, 8)
......@@ -181,22 +205,42 @@ class Test_SharedVariable(unittest.TestCase):
b = shared(numpy.int64(7))
assert b.type == theano.tensor.lscalar
f(b,8.23)
assert b.value==8
b = shared(numpy.int32(7))
assert b.type == theano.tensor.iscalar
f(b,8.23)
assert b.value==8
b = shared(numpy.int16(7))
assert b.type == theano.tensor.wscalar
f(b,8.23)
assert b.value==8
b = shared(numpy.int8(7))
assert b.type == theano.tensor.bscalar
f(b,8.23)
assert b.value==8
b = shared(numpy.float64(7.234))
assert b.type == theano.tensor.dscalar
f(b,8)
assert b.value==8
b = shared(numpy.float32(7.234))
assert b.type == theano.tensor.fscalar
f(b,8)
assert b.value==8
b = shared(numpy.float(7.234))
assert b.type == theano.tensor.dscalar
f(b,8)
assert b.value==8
b = shared(7.234)
assert b.type == theano.tensor.dscalar
f(b,8)
assert b.value==8
c = shared(numpy.zeros((5,5), dtype='float32'))
self.failUnlessRaises(TypeError, f, b, numpy.random.rand(5,5))
......@@ -209,14 +253,32 @@ class Test_SharedVariable(unittest.TestCase):
b = shared(numpy.int64([7]))
assert b.type == theano.tensor.lvector
f(b,[8.23])
assert b.value == 8
b = shared(numpy.int32([7]))
assert b.type == theano.tensor.ivector
f(b,[8.23])
assert b.value == 8
b = shared(numpy.int16([7]))
assert b.type == theano.tensor.wvector
f(b,[8.23])
assert b.value == 8
b = shared(numpy.int8([7]))
assert b.type == theano.tensor.bvector
f(b,[8.23])
assert b.value == 8
b = shared(numpy.float64([7.234]))
assert b.type == theano.tensor.dvector
f(b,[8])
assert b.value == 8
b = shared(numpy.float32([7.234]))
assert b.type == theano.tensor.fvector
f(b,[8])
assert b.value == 8
#numpy.float([7.234]) don't work
# b = shared(numpy.float([7.234]))
......@@ -231,6 +293,7 @@ class Test_SharedVariable(unittest.TestCase):
b = shared(numpy.asarray([7.234],dtype=theano.config.floatX))
assert b.dtype == theano.config.floatX
f(b,[8])
assert b.value == 8
c = shared(numpy.zeros((5,5), dtype='float32'))
self.failUnlessRaises(TypeError, f, b, numpy.random.rand(5,5))
......
......@@ -2504,6 +2504,13 @@ def test_autocast():
assert (fvector()+ numpy.float32(1.1)).dtype == 'float32'
assert (fvector()+ numpy.float64(1.1)).dtype == 'float64'
assert (fvector()+ numpy.float(1.1)).dtype == theano.config.floatX
assert (lvector()+ numpy.int64(1)).dtype == 'int64'
assert (lvector()+ numpy.int32(1)).dtype == 'int64'
assert (lvector()+ numpy.int16(1)).dtype == 'int64'
assert (lvector()+ numpy.int8(1)).dtype == 'int64'
assert (ivector()+ numpy.int8(1)).dtype == 'int32'
assert (wvector()+ numpy.int8(1)).dtype == 'int16'
assert (bvector()+ numpy.int8(1)).dtype == 'int8'
try: #ghetto 2.4 version of with
ac2 = autocast_float_as('float64')
ac2.__enter__()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论