提交 78dc12fb authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Check Python long are OK iff they fit in int64

上级 f462a1c0
...@@ -6175,7 +6175,7 @@ def _test_autocast_numpy(): ...@@ -6175,7 +6175,7 @@ def _test_autocast_numpy():
def ok(z): def ok(z):
assert tensor.constant(z).dtype == numpy.asarray(z).dtype assert tensor.constant(z).dtype == numpy.asarray(z).dtype
for x in ([2 ** i for i in xrange(63)] + for x in ([2 ** i for i in xrange(63)] +
[0] + [0, 0L, 1L, 2L ** 63 - 1] +
[0., 1., 1.1, 1.5]): [0., 1., 1.1, 1.5]):
n_x = numpy.asarray(x) n_x = numpy.asarray(x)
# Make sure the data type is the same as the one found by numpy. # Make sure the data type is the same as the one found by numpy.
...@@ -6204,11 +6204,11 @@ def _test_autocast_numpy_floatX(): ...@@ -6204,11 +6204,11 @@ def _test_autocast_numpy_floatX():
for floatX in ('float32', 'float64'): for floatX in ('float32', 'float64'):
config.floatX = floatX config.floatX = floatX
# Go through some typical scalar values. # Go through some typical scalar values.
# Note that we only consider integer values that Python considers # We only consider 'int' and 'long' Python values that can fit
# to be 'int', because 'long' is not supported by Theano (due to # into int64, as that is the maximal integer type that Theano
# the fact it is unbounded). # supports, and that is the maximal type in Python indexing.
for x in ([2 ** i for i in xrange(64) if type(2 ** i) == int] + for x in ([2 ** i - 1 for i in xrange(64)] +
[0] + [0, 0L, 1L, 2L ** 63 - 1] +
[0., 1., 1.1, 1.5]): [0., 1., 1.1, 1.5]):
ok(x, floatX) ok(x, floatX)
ok(-x, floatX) ok(-x, floatX)
...@@ -6372,6 +6372,29 @@ class test_arithmetic_cast(unittest.TestCase): ...@@ -6372,6 +6372,29 @@ class test_arithmetic_cast(unittest.TestCase):
category=DeprecationWarning) category=DeprecationWarning)
class T_long_tensor(unittest.TestCase):
def test_fit_int64(self):
for exp in xrange(64):
val = 2L ** exp - 1
scalar_ct = constant(val)
assert scalar_ct.dtype == 'int64'
assert scalar_ct.value == val
vector_ct = constant([val, val])
assert vector_ct.dtype == 'int64'
assert numpy.all(vector_ct.value == val)
matrix_ct = constant([[val, val]])
assert matrix_ct.dtype == 'int64'
assert numpy.all(matrix_ct.value == val)
def test_too_big(self):
val = 2L ** 63
self.assertRaises(Exception, constant, val)
self.assertRaises(Exception, constant, [val, val])
self.assertRaises(Exception, constant, [[val, val]])
class test_broadcast(unittest.TestCase): class test_broadcast(unittest.TestCase):
def test_broadcast_bigdim(self): def test_broadcast_bigdim(self):
def f(): def f():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论