提交 30556a7c authored 作者: Frederic's avatar Frederic

Fix test error with NumPy 1.7.1.

This is for python long that fit in uint64, but not in an int64. NumPy 1.7 raise an Overflow error, but NumPy 1.7.1 return an uint64.
上级 c1179017
...@@ -366,9 +366,14 @@ def constant_or_value(x, rtype, name=None, ndim=None, dtype=None): ...@@ -366,9 +366,14 @@ def constant_or_value(x, rtype, name=None, ndim=None, dtype=None):
# Theano graph, because on Windows 64, all shapes are expressed # Theano graph, because on Windows 64, all shapes are expressed
# with longs. # with longs.
# If a long fits in int64, we convert it into an int64, like # If a long fits in int64, we convert it into an int64, like
# numpy.asarray() does. # numpy.asarray() does up to 1.7. NumPy 1.7.1 upcaset to int64
# if possible, but fallback to uint64 if int64 isn't possible but
# uint64 is. We always do as NumPy 1.7.1 here.
# If x is too big, an OverflowError will be raised by numpy. # If x is too big, an OverflowError will be raised by numpy.
x_ = theano._asarray(x, dtype='int64') try:
x_ = theano._asarray(x, dtype='int64')
except OverflowError:
x_ = theano._asarray(x, dtype='uint64')
elif isinstance(x, numpy.ndarray): elif isinstance(x, numpy.ndarray):
x_ = x x_ = x
# Currently we do not have a bool dtype in Theano. # Currently we do not have a bool dtype in Theano.
...@@ -377,6 +382,10 @@ def constant_or_value(x, rtype, name=None, ndim=None, dtype=None): ...@@ -377,6 +382,10 @@ def constant_or_value(x, rtype, name=None, ndim=None, dtype=None):
if x.dtype == 'bool': if x.dtype == 'bool':
x_ = numpy.asarray(x_, dtype='uint8') x_ = numpy.asarray(x_, dtype='uint8')
else: else:
# Here x is probably a list or a tuple. If it contain a long,
# we will behave like the current NumPy version: 1.7 and bellow,
# it will only work if the long fit in int64. For NumPy 1.7.1+,
# it will work if the long git in int64 or uint64.
x_ = numpy.asarray(x) x_ = numpy.asarray(x)
assert type(x_) == numpy.ndarray assert type(x_) == numpy.ndarray
......
...@@ -6392,6 +6392,30 @@ class T_long_tensor(unittest.TestCase): ...@@ -6392,6 +6392,30 @@ class T_long_tensor(unittest.TestCase):
def test_too_big(self): def test_too_big(self):
val = 2L ** 63 val = 2L ** 63
#NumPy 1.7 this will raise an exception
#NumPy 1.7.1 this will work
try:
cst = constant(val)
assert cst.value == val
assert cst.dtype == "uint64"
except Exception:
pass
try:
cst = constant([val, val])
assert cst.value == val
assert cst.dtype == "uint64"
except Exception:
pass
try:
cst = constant([[val, val]])
assert cst.value == val
assert cst.dtype == "uint64"
except Exception:
pass
val = 2L ** 64
# This fail for all NumPy version.
self.assertRaises(Exception, constant, val) self.assertRaises(Exception, constant, val)
self.assertRaises(Exception, constant, [val, val]) self.assertRaises(Exception, constant, [val, val])
self.assertRaises(Exception, constant, [[val, val]]) self.assertRaises(Exception, constant, [[val, val]])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论