提交 1ec1cd9b authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #3465 from abergeron/fix_arange

Fix arange to avoid overflow.
...@@ -4824,6 +4824,12 @@ def arange(start, stop=None, step=1, dtype=None): ...@@ -4824,6 +4824,12 @@ def arange(start, stop=None, step=1, dtype=None):
# If dtype is not provided, infer it from the other arguments # If dtype is not provided, infer it from the other arguments
if dtype is None: if dtype is None:
dtype = scal.upcast(start.type.dtype, stop.type.dtype, step.type.dtype) dtype = scal.upcast(start.type.dtype, stop.type.dtype, step.type.dtype)
# don't try to be stingy and byte-optimize, this leads to
# overflow problems.
if dtype.startswith('int'):
dtype = 'int64'
if dtype.startswith('uint'):
dtype = 'uint64'
if config.cast_policy in ('numpy', 'numpy+floatX'): if config.cast_policy in ('numpy', 'numpy+floatX'):
# We enforce numpy semantics, except in the special case where # We enforce numpy semantics, except in the special case where
# `config.cast_policy` is 'numpy+floatX' and we want to use float32 # `config.cast_policy` is 'numpy+floatX' and we want to use float32
......
...@@ -5316,7 +5316,7 @@ class TestARange(unittest.TestCase): ...@@ -5316,7 +5316,7 @@ class TestARange(unittest.TestCase):
f = function([start, stop, step], out) f = function([start, stop, step], out)
if config.cast_policy == 'custom': if config.cast_policy == 'custom':
assert out.dtype == start.type.dtype assert out.dtype == 'int64'
elif config.cast_policy in ('numpy', 'numpy+floatX'): elif config.cast_policy in ('numpy', 'numpy+floatX'):
numpy_dtype = numpy.arange(numpy.array(1, dtype='int32')).dtype numpy_dtype = numpy.arange(numpy.array(1, dtype='int32')).dtype
assert out.dtype == numpy_dtype assert out.dtype == numpy_dtype
...@@ -5393,7 +5393,7 @@ class TestARange(unittest.TestCase): ...@@ -5393,7 +5393,7 @@ class TestARange(unittest.TestCase):
f = function([start, stop], out) f = function([start, stop], out)
if config.cast_policy == 'custom': if config.cast_policy == 'custom':
assert out.dtype == start.type.dtype assert out.dtype == 'int64'
elif config.cast_policy in ('numpy', 'numpy+floatX'): elif config.cast_policy in ('numpy', 'numpy+floatX'):
assert out.dtype == numpy.arange(numpy.int32(0), assert out.dtype == numpy.arange(numpy.int32(0),
numpy.int32(1)).dtype numpy.int32(1)).dtype
...@@ -5421,7 +5421,7 @@ class TestARange(unittest.TestCase): ...@@ -5421,7 +5421,7 @@ class TestARange(unittest.TestCase):
f = function([stop], out) f = function([stop], out)
if config.cast_policy == 'custom': if config.cast_policy == 'custom':
assert out.dtype == stop.type.dtype assert out.dtype == 'int64'
elif config.cast_policy in ('numpy', 'numpy+floatX'): elif config.cast_policy in ('numpy', 'numpy+floatX'):
assert out.dtype == numpy.arange(numpy.int32(1)).dtype assert out.dtype == numpy.arange(numpy.int32(1)).dtype
else: else:
...@@ -5453,7 +5453,7 @@ class TestARange(unittest.TestCase): ...@@ -5453,7 +5453,7 @@ class TestARange(unittest.TestCase):
def test_upcast(self): def test_upcast(self):
"""Test that arange computes output type adequately""" """Test that arange computes output type adequately"""
if config.cast_policy == 'custom': if config.cast_policy == 'custom':
assert arange(iscalar()).dtype == iscalar().dtype assert arange(iscalar()).dtype == 'int64'
assert arange(fscalar()).dtype == fscalar().dtype assert arange(fscalar()).dtype == fscalar().dtype
assert arange(dscalar()).dtype == dscalar().dtype assert arange(dscalar()).dtype == dscalar().dtype
...@@ -5547,7 +5547,7 @@ class TestARange(unittest.TestCase): ...@@ -5547,7 +5547,7 @@ class TestARange(unittest.TestCase):
assert len(f.maker.fgraph.toposort()) == 9 assert len(f.maker.fgraph.toposort()) == 9
if config.cast_policy == 'custom': if config.cast_policy == 'custom':
assert out.dtype == start.type.dtype assert out.dtype == 'int64'
elif config.cast_policy in ('numpy', 'numpy+floatX'): elif config.cast_policy in ('numpy', 'numpy+floatX'):
numpy_dtype = numpy.arange(numpy.array(0, dtype=start.dtype), numpy_dtype = numpy.arange(numpy.array(0, dtype=start.dtype),
numpy.array(1, dtype=stop.dtype), numpy.array(1, dtype=stop.dtype),
...@@ -5568,7 +5568,7 @@ class TestARange(unittest.TestCase): ...@@ -5568,7 +5568,7 @@ class TestARange(unittest.TestCase):
assert len(f.maker.fgraph.toposort()) == 5 assert len(f.maker.fgraph.toposort()) == 5
# 4 [Elemwise{sub,no_inplace}(stop, start), Elemwise{Cast{int64}}(Elemwise{sub,no_inplace}.0), Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)](Elemwise{Cast{int64}}.0, 0), MakeVector(Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)].0)] # 4 [Elemwise{sub,no_inplace}(stop, start), Elemwise{Cast{int64}}(Elemwise{sub,no_inplace}.0), Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)](Elemwise{Cast{int64}}.0, 0), MakeVector(Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)].0)]
if config.cast_policy == 'custom': if config.cast_policy == 'custom':
assert out.dtype == start.type.dtype assert out.dtype == 'int64'
elif config.cast_policy in ('numpy', 'numpy+floatX'): elif config.cast_policy in ('numpy', 'numpy+floatX'):
assert out.dtype == numpy.arange( assert out.dtype == numpy.arange(
numpy.int32(0), numpy.int32(1), numpy.int32(1)).dtype numpy.int32(0), numpy.int32(1), numpy.int32(1)).dtype
...@@ -5590,7 +5590,7 @@ class TestARange(unittest.TestCase): ...@@ -5590,7 +5590,7 @@ class TestARange(unittest.TestCase):
#[Elemwise{Cast{int64}}(stop), MakeVector(Elemwise{Cast{int64}}.0)] #[Elemwise{Cast{int64}}(stop), MakeVector(Elemwise{Cast{int64}}.0)]
if config.cast_policy == 'custom': if config.cast_policy == 'custom':
assert out.dtype == start.type.dtype assert out.dtype == 'int64'
elif config.cast_policy in ('numpy', 'numpy+floatX'): elif config.cast_policy in ('numpy', 'numpy+floatX'):
numpy_dtype = numpy.arange(0, numpy_dtype = numpy.arange(0,
numpy.array(1, dtype=stop.dtype), numpy.array(1, dtype=stop.dtype),
......
...@@ -342,7 +342,7 @@ def test_scan_debugprint2(): ...@@ -342,7 +342,7 @@ def test_scan_debugprint2():
| |Subtensor{int64} [@J] '' | |Subtensor{int64} [@J] ''
| |Shape [@K] '' | |Shape [@K] ''
| | |Subtensor{int64::} [@L] '' | | |Subtensor{int64::} [@L] ''
| | |ARange{dtype='int16'} [@M] '' | | |ARange{dtype='int64'} [@M] ''
| | | |TensorConstant{0} [@N] | | | |TensorConstant{0} [@N]
| | | |TensorConstant{10000} [@O] | | | |TensorConstant{10000} [@O]
| | | |TensorConstant{1} [@P] | | | |TensorConstant{1} [@P]
...@@ -366,7 +366,7 @@ def test_scan_debugprint2(): ...@@ -366,7 +366,7 @@ def test_scan_debugprint2():
> |coefficients[t] [@Y] -> [@S] > |coefficients[t] [@Y] -> [@S]
> |Elemwise{pow,no_inplace} [@Z] '' > |Elemwise{pow,no_inplace} [@Z] ''
> |x_copy [@BA] -> [@W] > |x_copy [@BA] -> [@W]
> |<TensorType(int16, scalar)> [@BB] -> [@U]""" > |<TensorType(int64, scalar)> [@BB] -> [@U]"""
for truth, out in zip(expected_output.split("\n"), lines): for truth, out in zip(expected_output.split("\n"), lines):
assert truth.strip() == out.strip() assert truth.strip() == out.strip()
...@@ -425,7 +425,7 @@ def test_scan_debugprint3(): ...@@ -425,7 +425,7 @@ def test_scan_debugprint3():
| |Subtensor{int64} [@J] '' | |Subtensor{int64} [@J] ''
| |Shape [@K] '' | |Shape [@K] ''
| | |Subtensor{int64::} [@L] '' | | |Subtensor{int64::} [@L] ''
| | |ARange{dtype='int8'} [@M] '' | | |ARange{dtype='int64'} [@M] ''
| | | |TensorConstant{0} [@N] | | | |TensorConstant{0} [@N]
| | | |TensorConstant{10} [@O] | | | |TensorConstant{10} [@O]
| | | |TensorConstant{1} [@P] | | | |TensorConstant{1} [@P]
...@@ -479,7 +479,7 @@ def test_scan_debugprint3(): ...@@ -479,7 +479,7 @@ def test_scan_debugprint3():
> | | |Constant{1} [@BX] > | | |Constant{1} [@BX]
> | |Constant{-1} [@BY] > | |Constant{-1} [@BY]
> |DimShuffle{x} [@BZ] '' > |DimShuffle{x} [@BZ] ''
> |<TensorType(int8, scalar)> [@CA] -> [@U] > |<TensorType(int64, scalar)> [@CA] -> [@U]
for{cpu,scan_fn} [@BE] '' for{cpu,scan_fn} [@BE] ''
>Elemwise{mul,no_inplace} [@CB] '' >Elemwise{mul,no_inplace} [@CB] ''
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论