提交 165aa19f authored 作者: Vybhav's avatar Vybhav 提交者: Ricardo Vieira

Fix scalar Pow.impl returning complex for negative base with fractional exponent

上级 30558ba5
......@@ -2282,7 +2282,7 @@ class Pow(BinaryScalarOp):
nfunc_spec = ("power", 2, 1)
def impl(self, x, y):
return x**y
return np.power(x, y)
def c_code(self, node, name, inputs, outputs, sub):
(x, y) = inputs
......
......@@ -34,6 +34,7 @@ from pytensor.scalar.basic import (
exp2,
expm1,
float32,
float64,
floats,
int8,
int32,
......@@ -47,6 +48,7 @@ from pytensor.scalar.basic import (
mul,
neg,
neq,
pow,
rad2deg,
reciprocal,
sin,
......@@ -530,3 +532,25 @@ def test_cast_to_complex(inp_type):
res_y = y.eval({x: np.array(1.0, dtype=inp_type.dtype)})
assert res_y == 1
assert res_y.dtype == "complex64"
@pytest.mark.parametrize("mode", [Mode(linker="py"), None])
def test_pow_negative_base_fractional_exponent(mode):
x = float64("x")
y = float64("y")
f = pytensor.function([x, y], pow(x, y), mode=mode)
# Positive base works normally
assert f(2.0, 3.0) == 8.0
# Negative base with integer exponent works
assert f(-2.0, 3.0) == -8.0
# Negative base with fractional exponent returns nan, not complex
result = f(-1.0, 0.01)
if not isinstance(f.maker.linker, NumbaLinker):
# Numba doesn't return numpy scalars
assert isinstance(result, np.floating), (
f"Expected numpy float, got {type(result)}: {result}"
)
assert np.isnan(result), f"Expected nan, got {result}"
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论