Unverified 提交 370b172c authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: GitHub

Numba Pow: Fix failure with discrete integer exponents (#1758)

上级 ae499a49
...@@ -22,6 +22,7 @@ from pytensor.scalar.basic import ( ...@@ -22,6 +22,7 @@ from pytensor.scalar.basic import (
Composite, Composite,
Identity, Identity,
Mul, Mul,
Pow,
Reciprocal, Reciprocal,
ScalarOp, ScalarOp,
Second, Second,
...@@ -165,6 +166,23 @@ def {binary_op_name}({input_signature}): ...@@ -165,6 +166,23 @@ def {binary_op_name}({input_signature}):
return nary_fn return nary_fn
@register_funcify_and_cache_key(Pow)
def numba_funcify_Pow(op, node, **kwargs):
pow_dtype = node.inputs[1].type.dtype
if pow_dtype.startswith("int"):
# Numba power fails when exponents are non 64-bit discrete integers and fasthmath=True
# https://github.com/numba/numba/issues/9554
def pow(x, y):
return x ** np.asarray(y, dtype=np.int64).item()
else:
def pow(x, y):
return x**y
return numba_basic.numba_njit(pow), scalar_op_cache_key(op)
@register_funcify_and_cache_key(Add) @register_funcify_and_cache_key(Add)
def numba_funcify_Add(op, node, **kwargs): def numba_funcify_Add(op, node, **kwargs):
nary_add_fn = binary_to_nary_func(node.inputs, "add", "+") nary_add_fn = binary_to_nary_func(node.inputs, "add", "+")
......
...@@ -189,6 +189,16 @@ def test_Softplus(dtype): ...@@ -189,6 +189,16 @@ def test_Softplus(dtype):
) )
def test_discrete_power():
# Test we don't fail to compile power with discrete exponents due to https://github.com/numba/numba/issues/9554
x = pt.scalar("x", dtype="float64")
exponent = pt.scalar("exponent", dtype="int8")
out = pt.power(x, exponent)
compare_numba_and_py(
[x, exponent], [out], [np.array(0.5), np.array(2, dtype="int8")]
)
def test_cython_obj_mode_fallback(): def test_cython_obj_mode_fallback():
"""Test that unsupported cython signatures fallback to obj-mode""" """Test that unsupported cython signatures fallback to obj-mode"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论