提交 1ec1cd9b authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #3465 from abergeron/fix_arange

Fix arange to avoid overflow.
......@@ -4824,6 +4824,12 @@ def arange(start, stop=None, step=1, dtype=None):
# If dtype is not provided, infer it from the other arguments
if dtype is None:
dtype = scal.upcast(start.type.dtype, stop.type.dtype, step.type.dtype)
# don't try to be stingy and byte-optimize, this leads to
# overflow problems.
if dtype.startswith('int'):
dtype = 'int64'
if dtype.startswith('uint'):
dtype = 'uint64'
if config.cast_policy in ('numpy', 'numpy+floatX'):
# We enforce numpy semantics, except in the special case where
# `config.cast_policy` is 'numpy+floatX' and we want to use float32
......
......@@ -5316,7 +5316,7 @@ class TestARange(unittest.TestCase):
f = function([start, stop, step], out)
if config.cast_policy == 'custom':
assert out.dtype == start.type.dtype
assert out.dtype == 'int64'
elif config.cast_policy in ('numpy', 'numpy+floatX'):
numpy_dtype = numpy.arange(numpy.array(1, dtype='int32')).dtype
assert out.dtype == numpy_dtype
......@@ -5393,7 +5393,7 @@ class TestARange(unittest.TestCase):
f = function([start, stop], out)
if config.cast_policy == 'custom':
assert out.dtype == start.type.dtype
assert out.dtype == 'int64'
elif config.cast_policy in ('numpy', 'numpy+floatX'):
assert out.dtype == numpy.arange(numpy.int32(0),
numpy.int32(1)).dtype
......@@ -5421,7 +5421,7 @@ class TestARange(unittest.TestCase):
f = function([stop], out)
if config.cast_policy == 'custom':
assert out.dtype == stop.type.dtype
assert out.dtype == 'int64'
elif config.cast_policy in ('numpy', 'numpy+floatX'):
assert out.dtype == numpy.arange(numpy.int32(1)).dtype
else:
......@@ -5453,7 +5453,7 @@ class TestARange(unittest.TestCase):
def test_upcast(self):
"""Test that arange computes output type adequately"""
if config.cast_policy == 'custom':
assert arange(iscalar()).dtype == iscalar().dtype
assert arange(iscalar()).dtype == 'int64'
assert arange(fscalar()).dtype == fscalar().dtype
assert arange(dscalar()).dtype == dscalar().dtype
......@@ -5547,7 +5547,7 @@ class TestARange(unittest.TestCase):
assert len(f.maker.fgraph.toposort()) == 9
if config.cast_policy == 'custom':
assert out.dtype == start.type.dtype
assert out.dtype == 'int64'
elif config.cast_policy in ('numpy', 'numpy+floatX'):
numpy_dtype = numpy.arange(numpy.array(0, dtype=start.dtype),
numpy.array(1, dtype=stop.dtype),
......@@ -5568,7 +5568,7 @@ class TestARange(unittest.TestCase):
assert len(f.maker.fgraph.toposort()) == 5
# 4 [Elemwise{sub,no_inplace}(stop, start), Elemwise{Cast{int64}}(Elemwise{sub,no_inplace}.0), Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)](Elemwise{Cast{int64}}.0, 0), MakeVector(Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)].0)]
if config.cast_policy == 'custom':
assert out.dtype == start.type.dtype
assert out.dtype == 'int64'
elif config.cast_policy in ('numpy', 'numpy+floatX'):
assert out.dtype == numpy.arange(
numpy.int32(0), numpy.int32(1), numpy.int32(1)).dtype
......@@ -5590,7 +5590,7 @@ class TestARange(unittest.TestCase):
#[Elemwise{Cast{int64}}(stop), MakeVector(Elemwise{Cast{int64}}.0)]
if config.cast_policy == 'custom':
assert out.dtype == start.type.dtype
assert out.dtype == 'int64'
elif config.cast_policy in ('numpy', 'numpy+floatX'):
numpy_dtype = numpy.arange(0,
numpy.array(1, dtype=stop.dtype),
......
......@@ -342,7 +342,7 @@ def test_scan_debugprint2():
| |Subtensor{int64} [@J] ''
| |Shape [@K] ''
| | |Subtensor{int64::} [@L] ''
| | |ARange{dtype='int16'} [@M] ''
| | |ARange{dtype='int64'} [@M] ''
| | | |TensorConstant{0} [@N]
| | | |TensorConstant{10000} [@O]
| | | |TensorConstant{1} [@P]
......@@ -366,7 +366,7 @@ def test_scan_debugprint2():
> |coefficients[t] [@Y] -> [@S]
> |Elemwise{pow,no_inplace} [@Z] ''
> |x_copy [@BA] -> [@W]
> |<TensorType(int16, scalar)> [@BB] -> [@U]"""
> |<TensorType(int64, scalar)> [@BB] -> [@U]"""
for truth, out in zip(expected_output.split("\n"), lines):
assert truth.strip() == out.strip()
......@@ -425,7 +425,7 @@ def test_scan_debugprint3():
| |Subtensor{int64} [@J] ''
| |Shape [@K] ''
| | |Subtensor{int64::} [@L] ''
| | |ARange{dtype='int8'} [@M] ''
| | |ARange{dtype='int64'} [@M] ''
| | | |TensorConstant{0} [@N]
| | | |TensorConstant{10} [@O]
| | | |TensorConstant{1} [@P]
......@@ -479,7 +479,7 @@ def test_scan_debugprint3():
> | | |Constant{1} [@BX]
> | |Constant{-1} [@BY]
> |DimShuffle{x} [@BZ] ''
> |<TensorType(int8, scalar)> [@CA] -> [@U]
> |<TensorType(int64, scalar)> [@CA] -> [@U]
for{cpu,scan_fn} [@BE] ''
>Elemwise{mul,no_inplace} [@CB] ''
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论