提交 576dfaf3 authored 作者: Frederic's avatar Frederic

arange.infer_shape upcast more and be safe again uint64. The old upcast was ok, as the elemw

上级 42cd329d
...@@ -4193,20 +4193,27 @@ class ARange(Op): ...@@ -4193,20 +4193,27 @@ class ARange(Op):
pass pass
return False return False
def upcast(var):
if ('int' in var.dtype and
# We do not want to cast uint64 to int64 as this can
# loose information. If we upcast uint64 with int64,
# this give float64. This is safer then checking for
# uint64 in case we support [u]int128 or other in the
# future.
scal.upcast(var.dtype, 'int64') == 'int64'):
return cast(var, 'int64')
return var
if is_constant_value(step, 1): if is_constant_value(step, 1):
if is_constant_value(start, 0): if is_constant_value(start, 0):
return [(cast(stop, 'int64'),)] return [(cast(stop, 'int64'),)]
else: else:
if 'int' in stop.dtype: stop = upcast(stop)
stop = cast(stop, 'int64') start = upcast(start)
elif 'int' in start.dtype:
start = cast(start, 'int64')
return [(maximum(cast(stop - start, 'int64'), 0),)] return [(maximum(cast(stop - start, 'int64'), 0),)]
else: else:
if 'int' in stop.dtype: stop = upcast(stop)
stop = cast(stop, 'int64') start = upcast(start)
elif 'int' in start.dtype:
start = cast(start, 'int64')
return [(maximum(cast(ceil(cast((stop - start), 'float64') return [(maximum(cast(ceil(cast((stop - start), 'float64')
/ step), 'int64'), 0),)] / step), 'int64'), 0),)]
......
...@@ -5101,7 +5101,7 @@ class TestARange(unittest.TestCase): ...@@ -5101,7 +5101,7 @@ class TestARange(unittest.TestCase):
mode = 'FAST_RUN' mode = 'FAST_RUN'
mode = compile.mode.get_mode(mode).excluding('fusion') mode = compile.mode.get_mode(mode).excluding('fusion')
f = function([start, stop, step], out.shape, mode=mode) f = function([start, stop, step], out.shape, mode=mode)
assert len(f.maker.fgraph.toposort()) == 8 assert len(f.maker.fgraph.toposort()) == 9
if config.cast_policy == 'custom': if config.cast_policy == 'custom':
assert out.dtype == start.type.dtype assert out.dtype == start.type.dtype
...@@ -5122,7 +5122,7 @@ class TestARange(unittest.TestCase): ...@@ -5122,7 +5122,7 @@ class TestARange(unittest.TestCase):
out = arange(start, stop, 1) out = arange(start, stop, 1)
f = function([start, stop], out.shape, mode=mode) f = function([start, stop], out.shape, mode=mode)
assert len(f.maker.fgraph.toposort()) == 4 assert len(f.maker.fgraph.toposort()) == 5
#4 [Elemwise{sub,no_inplace}(stop, start), Elemwise{Cast{int64}}(Elemwise{sub,no_inplace}.0), Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)](Elemwise{Cast{int64}}.0, 0), MakeVector(Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)].0)] #4 [Elemwise{sub,no_inplace}(stop, start), Elemwise{Cast{int64}}(Elemwise{sub,no_inplace}.0), Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)](Elemwise{Cast{int64}}.0, 0), MakeVector(Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)].0)]
if config.cast_policy == 'custom': if config.cast_policy == 'custom':
assert out.dtype == start.type.dtype assert out.dtype == start.type.dtype
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论