提交 165aa19f authored 作者: Vybhav's avatar Vybhav 提交者: Ricardo Vieira

Fix scalar Pow.impl returning complex for negative base with fractional exponent

上级 30558ba5
...@@ -2282,7 +2282,7 @@ class Pow(BinaryScalarOp): ...@@ -2282,7 +2282,7 @@ class Pow(BinaryScalarOp):
nfunc_spec = ("power", 2, 1) nfunc_spec = ("power", 2, 1)
def impl(self, x, y): def impl(self, x, y):
return x**y return np.power(x, y)
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
(x, y) = inputs (x, y) = inputs
......
...@@ -34,6 +34,7 @@ from pytensor.scalar.basic import ( ...@@ -34,6 +34,7 @@ from pytensor.scalar.basic import (
exp2, exp2,
expm1, expm1,
float32, float32,
float64,
floats, floats,
int8, int8,
int32, int32,
...@@ -47,6 +48,7 @@ from pytensor.scalar.basic import ( ...@@ -47,6 +48,7 @@ from pytensor.scalar.basic import (
mul, mul,
neg, neg,
neq, neq,
pow,
rad2deg, rad2deg,
reciprocal, reciprocal,
sin, sin,
...@@ -530,3 +532,25 @@ def test_cast_to_complex(inp_type): ...@@ -530,3 +532,25 @@ def test_cast_to_complex(inp_type):
res_y = y.eval({x: np.array(1.0, dtype=inp_type.dtype)}) res_y = y.eval({x: np.array(1.0, dtype=inp_type.dtype)})
assert res_y == 1 assert res_y == 1
assert res_y.dtype == "complex64" assert res_y.dtype == "complex64"
@pytest.mark.parametrize("mode", [Mode(linker="py"), None])
def test_pow_negative_base_fractional_exponent(mode):
x = float64("x")
y = float64("y")
f = pytensor.function([x, y], pow(x, y), mode=mode)
# Positive base works normally
assert f(2.0, 3.0) == 8.0
# Negative base with integer exponent works
assert f(-2.0, 3.0) == -8.0
# Negative base with fractional exponent returns nan, not complex
result = f(-1.0, 0.01)
if not isinstance(f.maker.linker, NumbaLinker):
# Numba doesn't return numpy scalars
assert isinstance(result, np.floating), (
f"Expected numpy float, got {type(result)}: {result}"
)
assert np.isnan(result), f"Expected nan, got {result}"
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论