提交 42cd329d authored 作者: Frederic Bastien's avatar Frederic Bastien 提交者: Frederic

[BUG]fix gh-2348

when constant are used, we use the smallest dtype for them. But this isn't enought when computing on them. So upcast to int64 to make sure compuration is right.
上级 5895d5a2
......@@ -4182,6 +4182,7 @@ class ARange(Op):
return Apply(self, inputs, outputs)
def infer_shape(self, node, i_shapes):
# Note start, stop and step can be float numbers.
start, stop, step = node.inputs
def is_constant_value(var, value):
......@@ -4196,8 +4197,16 @@ class ARange(Op):
if is_constant_value(start, 0):
return [(cast(stop, 'int64'),)]
else:
if 'int' in stop.dtype:
stop = cast(stop, 'int64')
elif 'int' in start.dtype:
start = cast(start, 'int64')
return [(maximum(cast(stop - start, 'int64'), 0),)]
else:
if 'int' in stop.dtype:
stop = cast(stop, 'int64')
elif 'int' in start.dtype:
start = cast(start, 'int64')
return [(maximum(cast(ceil(cast((stop - start), 'float64')
/ step), 'int64'), 0),)]
......
......@@ -5101,8 +5101,7 @@ class TestARange(unittest.TestCase):
mode = 'FAST_RUN'
mode = compile.mode.get_mode(mode).excluding('fusion')
f = function([start, stop, step], out.shape, mode=mode)
assert len(f.maker.fgraph.toposort()) == 7
#7 [Elemwise{sub,no_inplace}(stop, start), Elemwise{Cast{float64}}(Elemwise{sub,no_inplace}.0), Elemwise{TrueDiv{output_types_preference=transfer_type{0}}}[(0, 0)](Elemwise{Cast{float64}}.0, step), Elemwise{Ceil{output_types_preference=transfer_type{0}}}[(0, 0)](Elemwise{TrueDiv{output_types_preference=transfer_type{0}}}[(0, 0)].0), Elemwise{Cast{int64}}(Elemwise{Ceil{output_types_preference=transfer_type{0}}}[(0, 0)].0), Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)](Elemwise{Cast{int64}}.0, 0), MakeVector(Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)].0)]
assert len(f.maker.fgraph.toposort()) == 8
if config.cast_policy == 'custom':
assert out.dtype == start.type.dtype
......@@ -5138,6 +5137,9 @@ class TestARange(unittest.TestCase):
assert numpy.all(f(10, 2) == len(numpy.arange(10, 2)))
assert numpy.all(f(10, 2) == len(numpy.arange(10, 2)))
assert numpy.all(f(0, 0) == len(numpy.arange(0, 0)))
assert numpy.all(f(-64, 64) == len(numpy.arange(-64, 64)))
assert arange(-64, 64).shape.eval() == [128]
assert arange(-64, 64, 2).shape.eval() == [64]
out = arange(0, stop, 1)
f = function([stop], out.shape, mode=mode)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论