提交 576dfaf3 authored 作者: Frederic's avatar Frederic

arange.infer_shape upcast more and be safe again uint64. The old upcast was ok, as the elemw

上级 42cd329d
......@@ -4193,20 +4193,27 @@ class ARange(Op):
pass
return False
def upcast(var):
if ('int' in var.dtype and
# We do not want to cast uint64 to int64 as this can
# loose information. If we upcast uint64 with int64,
# this give float64. This is safer then checking for
# uint64 in case we support [u]int128 or other in the
# future.
scal.upcast(var.dtype, 'int64') == 'int64'):
return cast(var, 'int64')
return var
if is_constant_value(step, 1):
if is_constant_value(start, 0):
return [(cast(stop, 'int64'),)]
else:
if 'int' in stop.dtype:
stop = cast(stop, 'int64')
elif 'int' in start.dtype:
start = cast(start, 'int64')
stop = upcast(stop)
start = upcast(start)
return [(maximum(cast(stop - start, 'int64'), 0),)]
else:
if 'int' in stop.dtype:
stop = cast(stop, 'int64')
elif 'int' in start.dtype:
start = cast(start, 'int64')
stop = upcast(stop)
start = upcast(start)
return [(maximum(cast(ceil(cast((stop - start), 'float64')
/ step), 'int64'), 0),)]
......
......@@ -5101,7 +5101,7 @@ class TestARange(unittest.TestCase):
mode = 'FAST_RUN'
mode = compile.mode.get_mode(mode).excluding('fusion')
f = function([start, stop, step], out.shape, mode=mode)
assert len(f.maker.fgraph.toposort()) == 8
assert len(f.maker.fgraph.toposort()) == 9
if config.cast_policy == 'custom':
assert out.dtype == start.type.dtype
......@@ -5122,7 +5122,7 @@ class TestARange(unittest.TestCase):
out = arange(start, stop, 1)
f = function([start, stop], out.shape, mode=mode)
assert len(f.maker.fgraph.toposort()) == 4
assert len(f.maker.fgraph.toposort()) == 5
#4 [Elemwise{sub,no_inplace}(stop, start), Elemwise{Cast{int64}}(Elemwise{sub,no_inplace}.0), Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)](Elemwise{Cast{int64}}.0, 0), MakeVector(Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)].0)]
if config.cast_policy == 'custom':
assert out.dtype == start.type.dtype
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论