提交 0a082171 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #2349 from nouiz/arange_infer_shape

[BUG]fix gh-2348
...@@ -4182,6 +4182,7 @@ class ARange(Op): ...@@ -4182,6 +4182,7 @@ class ARange(Op):
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
def infer_shape(self, node, i_shapes): def infer_shape(self, node, i_shapes):
# Note start, stop and step can be float numbers.
start, stop, step = node.inputs start, stop, step = node.inputs
def is_constant_value(var, value): def is_constant_value(var, value):
...@@ -4192,12 +4193,27 @@ class ARange(Op): ...@@ -4192,12 +4193,27 @@ class ARange(Op):
pass pass
return False return False
def upcast(var):
if ('int' in var.dtype and
# We do not want to cast uint64 to int64 as this can
# loose information. If we upcast uint64 with int64,
# this give float64. This is safer then checking for
# uint64 in case we support [u]int128 or other in the
# future.
scal.upcast(var.dtype, 'int64') == 'int64'):
return cast(var, 'int64')
return var
if is_constant_value(step, 1): if is_constant_value(step, 1):
if is_constant_value(start, 0): if is_constant_value(start, 0):
return [(cast(stop, 'int64'),)] return [(cast(stop, 'int64'),)]
else: else:
stop = upcast(stop)
start = upcast(start)
return [(maximum(cast(stop - start, 'int64'), 0),)] return [(maximum(cast(stop - start, 'int64'), 0),)]
else: else:
stop = upcast(stop)
start = upcast(start)
return [(maximum(cast(ceil(cast((stop - start), 'float64') return [(maximum(cast(ceil(cast((stop - start), 'float64')
/ step), 'int64'), 0),)] / step), 'int64'), 0),)]
......
...@@ -5101,8 +5101,7 @@ class TestARange(unittest.TestCase): ...@@ -5101,8 +5101,7 @@ class TestARange(unittest.TestCase):
mode = 'FAST_RUN' mode = 'FAST_RUN'
mode = compile.mode.get_mode(mode).excluding('fusion') mode = compile.mode.get_mode(mode).excluding('fusion')
f = function([start, stop, step], out.shape, mode=mode) f = function([start, stop, step], out.shape, mode=mode)
assert len(f.maker.fgraph.toposort()) == 7 assert len(f.maker.fgraph.toposort()) == 9
#7 [Elemwise{sub,no_inplace}(stop, start), Elemwise{Cast{float64}}(Elemwise{sub,no_inplace}.0), Elemwise{TrueDiv{output_types_preference=transfer_type{0}}}[(0, 0)](Elemwise{Cast{float64}}.0, step), Elemwise{Ceil{output_types_preference=transfer_type{0}}}[(0, 0)](Elemwise{TrueDiv{output_types_preference=transfer_type{0}}}[(0, 0)].0), Elemwise{Cast{int64}}(Elemwise{Ceil{output_types_preference=transfer_type{0}}}[(0, 0)].0), Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)](Elemwise{Cast{int64}}.0, 0), MakeVector(Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)].0)]
if config.cast_policy == 'custom': if config.cast_policy == 'custom':
assert out.dtype == start.type.dtype assert out.dtype == start.type.dtype
...@@ -5123,7 +5122,7 @@ class TestARange(unittest.TestCase): ...@@ -5123,7 +5122,7 @@ class TestARange(unittest.TestCase):
out = arange(start, stop, 1) out = arange(start, stop, 1)
f = function([start, stop], out.shape, mode=mode) f = function([start, stop], out.shape, mode=mode)
assert len(f.maker.fgraph.toposort()) == 4 assert len(f.maker.fgraph.toposort()) == 5
#4 [Elemwise{sub,no_inplace}(stop, start), Elemwise{Cast{int64}}(Elemwise{sub,no_inplace}.0), Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)](Elemwise{Cast{int64}}.0, 0), MakeVector(Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)].0)] #4 [Elemwise{sub,no_inplace}(stop, start), Elemwise{Cast{int64}}(Elemwise{sub,no_inplace}.0), Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)](Elemwise{Cast{int64}}.0, 0), MakeVector(Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)].0)]
if config.cast_policy == 'custom': if config.cast_policy == 'custom':
assert out.dtype == start.type.dtype assert out.dtype == start.type.dtype
...@@ -5138,6 +5137,9 @@ class TestARange(unittest.TestCase): ...@@ -5138,6 +5137,9 @@ class TestARange(unittest.TestCase):
assert numpy.all(f(10, 2) == len(numpy.arange(10, 2))) assert numpy.all(f(10, 2) == len(numpy.arange(10, 2)))
assert numpy.all(f(10, 2) == len(numpy.arange(10, 2))) assert numpy.all(f(10, 2) == len(numpy.arange(10, 2)))
assert numpy.all(f(0, 0) == len(numpy.arange(0, 0))) assert numpy.all(f(0, 0) == len(numpy.arange(0, 0)))
assert numpy.all(f(-64, 64) == len(numpy.arange(-64, 64)))
assert arange(-64, 64).shape.eval() == [128]
assert arange(-64, 64, 2).shape.eval() == [64]
out = arange(0, stop, 1) out = arange(0, stop, 1)
f = function([stop], out.shape, mode=mode) f = function([stop], out.shape, mode=mode)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论