提交 0a082171 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #2349 from nouiz/arange_infer_shape

[BUG]fix gh-2348
......@@ -4182,6 +4182,7 @@ class ARange(Op):
return Apply(self, inputs, outputs)
def infer_shape(self, node, i_shapes):
# Note start, stop and step can be float numbers.
start, stop, step = node.inputs
def is_constant_value(var, value):
......@@ -4192,12 +4193,27 @@ class ARange(Op):
pass
return False
def upcast(var):
if ('int' in var.dtype and
# We do not want to cast uint64 to int64 as this can
# loose information. If we upcast uint64 with int64,
# this give float64. This is safer then checking for
# uint64 in case we support [u]int128 or other in the
# future.
scal.upcast(var.dtype, 'int64') == 'int64'):
return cast(var, 'int64')
return var
if is_constant_value(step, 1):
if is_constant_value(start, 0):
return [(cast(stop, 'int64'),)]
else:
stop = upcast(stop)
start = upcast(start)
return [(maximum(cast(stop - start, 'int64'), 0),)]
else:
stop = upcast(stop)
start = upcast(start)
return [(maximum(cast(ceil(cast((stop - start), 'float64')
/ step), 'int64'), 0),)]
......
......@@ -5101,8 +5101,7 @@ class TestARange(unittest.TestCase):
mode = 'FAST_RUN'
mode = compile.mode.get_mode(mode).excluding('fusion')
f = function([start, stop, step], out.shape, mode=mode)
assert len(f.maker.fgraph.toposort()) == 7
#7 [Elemwise{sub,no_inplace}(stop, start), Elemwise{Cast{float64}}(Elemwise{sub,no_inplace}.0), Elemwise{TrueDiv{output_types_preference=transfer_type{0}}}[(0, 0)](Elemwise{Cast{float64}}.0, step), Elemwise{Ceil{output_types_preference=transfer_type{0}}}[(0, 0)](Elemwise{TrueDiv{output_types_preference=transfer_type{0}}}[(0, 0)].0), Elemwise{Cast{int64}}(Elemwise{Ceil{output_types_preference=transfer_type{0}}}[(0, 0)].0), Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)](Elemwise{Cast{int64}}.0, 0), MakeVector(Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)].0)]
assert len(f.maker.fgraph.toposort()) == 9
if config.cast_policy == 'custom':
assert out.dtype == start.type.dtype
......@@ -5123,7 +5122,7 @@ class TestARange(unittest.TestCase):
out = arange(start, stop, 1)
f = function([start, stop], out.shape, mode=mode)
assert len(f.maker.fgraph.toposort()) == 4
assert len(f.maker.fgraph.toposort()) == 5
#4 [Elemwise{sub,no_inplace}(stop, start), Elemwise{Cast{int64}}(Elemwise{sub,no_inplace}.0), Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)](Elemwise{Cast{int64}}.0, 0), MakeVector(Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)].0)]
if config.cast_policy == 'custom':
assert out.dtype == start.type.dtype
......@@ -5138,6 +5137,9 @@ class TestARange(unittest.TestCase):
assert numpy.all(f(10, 2) == len(numpy.arange(10, 2)))
assert numpy.all(f(10, 2) == len(numpy.arange(10, 2)))
assert numpy.all(f(0, 0) == len(numpy.arange(0, 0)))
assert numpy.all(f(-64, 64) == len(numpy.arange(-64, 64)))
assert arange(-64, 64).shape.eval() == [128]
assert arange(-64, 64, 2).shape.eval() == [64]
out = arange(0, stop, 1)
f = function([stop], out.shape, mode=mode)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论