提交 d92c3675 authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: Ricardo Vieira

fix(numba): cholesky did not set off-diag entries to zero

上级 75a9fd2a
...@@ -310,8 +310,11 @@ def numba_funcify_SolveTriangular(op, node, **kwargs): ...@@ -310,8 +310,11 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
def _cholesky(a, lower=False, overwrite_a=False, check_finite=True): def _cholesky(a, lower=False, overwrite_a=False, check_finite=True):
return linalg.cholesky( return (
a, lower=lower, overwrite_a=overwrite_a, check_finite=check_finite linalg.cholesky(
a, lower=lower, overwrite_a=overwrite_a, check_finite=check_finite
),
0,
) )
...@@ -346,6 +349,15 @@ def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True): ...@@ -346,6 +349,15 @@ def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
INFO, INFO,
) )
if lower:
for j in range(1, _N):
for i in range(j):
A_copy[i, j] = 0.0
else:
for j in range(_N):
for i in range(j + 1, _N):
A_copy[i, j] = 0.0
return A_copy, int_ptr_to_val(INFO) return A_copy, int_ptr_to_val(INFO)
return impl return impl
......
...@@ -6,10 +6,8 @@ import pytest ...@@ -6,10 +6,8 @@ import pytest
import pytensor import pytensor
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor import config from pytensor import config
from pytensor.compile import SharedVariable from pytensor.graph import FunctionGraph
from pytensor.graph import Constant, FunctionGraph
from tests.link.numba.test_basic import compare_numba_and_py from tests.link.numba.test_basic import compare_numba_and_py
from tests.tensor.test_extra_ops import set_test_value
numba = pytest.importorskip("numba") numba = pytest.importorskip("numba")
...@@ -109,23 +107,22 @@ def test_solve_triangular_raises_on_nan_inf(value): ...@@ -109,23 +107,22 @@ def test_solve_triangular_raises_on_nan_inf(value):
@pytest.mark.parametrize("lower", [True, False], ids=["lower=True", "lower=False"]) @pytest.mark.parametrize("lower", [True, False], ids=["lower=True", "lower=False"])
def test_numba_Cholesky(lower): @pytest.mark.parametrize("trans", [True, False], ids=["trans=True", "trans=False"])
x = set_test_value( def test_numba_Cholesky(lower, trans):
pt.tensor(dtype=config.floatX, shape=(3, 3)), cov = pt.matrix("cov")
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype(config.floatX)),
)
g = pt.linalg.cholesky(x, lower=lower) if trans:
g_fg = FunctionGraph(outputs=[g]) cov_ = cov.T
else:
cov_ = cov
chol = pt.linalg.cholesky(cov_, lower=lower)
compare_numba_and_py( fg = FunctionGraph(outputs=[chol])
g_fg,
[ x = np.array([0.1, 0.2, 0.3])
i.tag.test_value val = np.eye(3) + x[None, :] * x[:, None]
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant) compare_numba_and_py(fg, [val])
],
)
def test_numba_Cholesky_raises_on_nan_input(): def test_numba_Cholesky_raises_on_nan_input():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论