提交 db673f0d authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix mixed dtype bug in gammaincc_grad

上级 53b00ea6
......@@ -854,7 +854,7 @@ def gammaincc_grad(k, x, skip_loops=constant(False, dtype="bool")):
log_s = np.array(0.0, dtype=dtype)
s_sign = np.array(1, dtype="int8")
n = np.array(1, dtype="int32")
log_delta = log_s - 2 * log(k)
log_delta = log_s - 2 * log(k).astype(dtype)
def inner_loop_b(sum_b, log_s, s_sign, log_delta, n, k, log_x):
delta = exp(log_delta)
......
import itertools
import numpy as np
import pytest
import scipy.special as sp
import pytensor.tensor as at
from pytensor import function
from pytensor.compile.mode import Mode
from pytensor.graph import ancestors
from pytensor.graph.fg import FunctionGraph
from pytensor.link.c.basic import CLinker
from pytensor.scalar import ScalarLoop, float32, float64, int32
from pytensor.scalar.math import (
betainc,
betainc_grad,
......@@ -13,6 +18,7 @@ from pytensor.scalar.math import (
gammaincc,
gammal,
gammau,
hyp2f1,
)
from tests.link.test_link import make_function
......@@ -89,3 +95,32 @@ def test_betainc_derivative_nan():
assert np.isnan(test_func(1, 1, 2))
assert np.isnan(test_func(1, -1, 1))
assert np.isnan(test_func(1, 1, -1))
@pytest.mark.parametrize(
"op, scalar_loop_grads",
[
(gammainc, [0]),
(gammaincc, [0]),
(betainc, [0, 1]),
(hyp2f1, [0, 1, 2]),
],
)
def test_scalarloop_grad_mixed_dtypes(op, scalar_loop_grads):
nin = op.nin
for types in itertools.product((float32, float64, int32), repeat=nin):
inputs = [type() for type in types]
out = op(*inputs)
wrt = [
inp
for idx, inp in enumerate(inputs)
if idx in scalar_loop_grads and inp.type.dtype.startswith("float")
]
if not wrt:
continue
# The ScalarLoop in the graph will fail if the input types are different from the updates
grad = at.grad(out, wrt=wrt)
assert any(
(var.owner and isinstance(var.owner.op, ScalarLoop))
for var in ancestors(grad)
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论