Commit 2b598f80 authored by Frederic

Fix ceil_intdiv output dtype when all input dtypes are uint*.

Also rewrote the comments and docstring with Olivier D. to make them clearer.
Parent 6c0f4709
@@ -2824,10 +2824,23 @@ def int_div(a, b):
 def ceil_intdiv(a, b):
-    """ return ceil(a/b) when a and b are int """
-    # Is it faster to cast to float when this don't loose precission?
-    # return cast(cast(a, scalar.upcast(a, 'float32')) / b, scal.upcast(a, b))
-    return int_div(a, b) + neq(a % b, 0)
+    """
+    Safely compute ceil(float_division(a, b)).
+    Works for all dtypes, but mostly useful when a and b are int.
+    """
+    # If a and b are int with not many significant bits, we could
+    # cast them to float to avoid doing the modulo. We do not know if this
+    # is faster or not. But this is not safe for int64 as the cast will
+    # lose precision.
+    # e.g.: cast(cast(a, scalar.upcast(a, 'float32')) / b, scal.upcast(a, b))
+    # We cast for the case when a and b are uint*. Otherwise neq will
+    # force their upcast to int.
+    div = int_div(a, b)
+    ret = cast(neq(a % b, 0), div.dtype) + div
+    assert ret.dtype == scal.upcast(a.dtype, b.dtype)
+    return ret
 def mod_check(x, y):
......
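Note on the "not safe for int64" comment: the following is a minimal plain-Python sketch, not code from this commit, of why casting to float before dividing can give a wrong ceiling for large integers. The function names are hypothetical, and Python's float is a 64-bit double rather than the float32 upcast mentioned in the comment, so this only illustrates the same kind of precision loss.

```python
import math


def ceil_intdiv_exact(a, b):
    # Integer-only ceiling division: floor division plus 1 when there
    # is a non-zero remainder (the same shape as the expression in the diff).
    return a // b + (a % b != 0)


def ceil_intdiv_via_float(a, b):
    # Hypothetical float-based variant: cast both operands first,
    # then divide and round up.
    return math.ceil(float(a) / float(b))


# Small operands: both approaches agree.
small_exact = ceil_intdiv_exact(7, 2)
small_float = ceil_intdiv_via_float(7, 2)

# 2**53 + 1 fits in an int64 but cannot be represented exactly as a
# 64-bit float, so the cast silently changes the numerator.
big = 2**53 + 1
big_exact = ceil_intdiv_exact(big, 1)
big_float = ceil_intdiv_via_float(big, 1)

print(small_exact, small_float)
print(big_exact, big_float, big_exact == big_float)
```

The integer-only version costs a modulo but stays exact over the whole integer range, which matches the trade-off described in the new comment.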
@@ -627,7 +627,11 @@ TrueDivInplaceTester = makeBroadcastTester(op = inplace.true_div_inplace,
 CeilIntDivTester = makeBroadcastTester(
     op=tensor.ceil_intdiv,
     expected=lambda x, y: check_floatX((x, y), (x // y) + ((x % y) != 0)),
-    good=_good_broadcast_div_mod_normal_float_no_complex,
+    good=copymod(_good_broadcast_div_mod_normal_float_no_complex,
+                 integer=(randint(2, 3), randint_nonzero(2, 3)),
+                 uinteger=(randint(2, 3).astype("uint8"),
+                           randint_nonzero(2, 3).astype("uint8")),
+                 ),
     # As we implement this function with neq, the gradient returned is always 0.
     # grad=_grad_broadcast_div_mod_normal,
     # grad_rtol=div_grad_rtol,
......
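Note on the uint* fix: the dtype change can be reproduced outside Theano with a small NumPy sketch. This is an illustration under stated assumptions, not the library's implementation: `ceil_intdiv_np` and its `cast_flag` parameter are hypothetical, and the comparison result is forced to int8 to imitate a `neq` whose output is a signed integer, which is how the comment "neq will force their upcast to int" reads. Plain NumPy comparisons return bool and would not show the problem on their own.

```python
import numpy as np


def ceil_intdiv_np(a, b, cast_flag=True):
    # Floor division keeps the unsigned dtype of the inputs.
    div = a // b
    # Assumption: imitate a comparison op that returns int8 instead of bool.
    flag = np.not_equal(a % b, 0).astype(np.int8)
    if cast_flag:
        # The fix in the diff: cast the remainder flag to the dtype of
        # the quotient before adding, so no signed upcast is triggered.
        flag = flag.astype(div.dtype)
    return flag + div


a = np.array([7, 8, 9], dtype=np.uint8)
b = np.array([2, 4, 5], dtype=np.uint8)

fixed = ceil_intdiv_np(a, b)                     # flag cast to uint8 first
unfixed = ceil_intdiv_np(a, b, cast_flag=False)  # int8 + uint8 promotes

print(fixed, fixed.dtype)
print(unfixed, unfixed.dtype)
```

Both variants return the same values; only the result dtype differs, which is what the new `assert ret.dtype == scal.upcast(a.dtype, b.dtype)` line in the diff guards against.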