提交 d64cceaf authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Work-around numba pow failure was not robust enough

上级 695574b9
...@@ -170,18 +170,17 @@ def {binary_op_name}({input_signature}): ...@@ -170,18 +170,17 @@ def {binary_op_name}({input_signature}):
@register_funcify_and_cache_key(Pow) @register_funcify_and_cache_key(Pow)
def numba_funcify_Pow(op, node, **kwargs): def numba_funcify_Pow(op, node, **kwargs):
pow_dtype = node.inputs[1].type.dtype pow_dtype = node.inputs[1].type.dtype
if pow_dtype.startswith("int"):
# Numba power fails when exponents are non 64-bit discrete integers and fasthmath=True
# https://github.com/numba/numba/issues/9554
def pow(x, y):
return x ** np.asarray(y, dtype=np.int64).item()
else:
def pow(x, y): def pow(x, y):
return x**y return x**y
return numba_basic.numba_njit(pow), scalar_op_cache_key(op) # Numba power fails when exponents are discrete integers and fasthmath=True
# https://github.com/numba/numba/issues/9554
fastmath = False if np.dtype(pow_dtype).kind in "ibu" else None
return numba_basic.numba_njit(pow, fastmath=fastmath), scalar_op_cache_key(
op, cache_version=1
)
@register_funcify_and_cache_key(Add) @register_funcify_and_cache_key(Add)
......
...@@ -183,13 +183,23 @@ def test_Softplus(dtype): ...@@ -183,13 +183,23 @@ def test_Softplus(dtype):
) )
def test_discrete_power(): @pytest.mark.parametrize(
"test_base",
[np.bool(True), np.int16(3), np.uint16(3), np.float32(0.5), np.float64(0.5)],
)
@pytest.mark.parametrize(
"test_exponent",
[np.bool(True), np.int16(2), np.uint16(2), np.float32(2.0), np.float64(2.0)],
)
def test_power_fastmath_bug(test_base, test_exponent):
# Test we don't fail to compile power with discrete exponents due to https://github.com/numba/numba/issues/9554 # Test we don't fail to compile power with discrete exponents due to https://github.com/numba/numba/issues/9554
x = pt.scalar("x", dtype="float64") base = pt.scalar("base", dtype=test_base.dtype)
exponent = pt.scalar("exponent", dtype="int8") exponent = pt.scalar("exponent", dtype=test_exponent.dtype)
out = pt.power(x, exponent) out = pt.power(base, exponent)
compare_numba_and_py( compare_numba_and_py(
[x, exponent], [out], [np.array(0.5), np.array(2, dtype="int8")] [base, exponent],
[out],
[test_base, test_exponent],
) )
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论