提交 d92c3675 authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: Ricardo Vieira

fix(numba): cholesky did not set off-diag entries to zero

上级 75a9fd2a
......@@ -310,8 +310,11 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
def _cholesky(a, lower=False, overwrite_a=False, check_finite=True):
return linalg.cholesky(
a, lower=lower, overwrite_a=overwrite_a, check_finite=check_finite
return (
linalg.cholesky(
a, lower=lower, overwrite_a=overwrite_a, check_finite=check_finite
),
0,
)
......@@ -346,6 +349,15 @@ def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
INFO,
)
if lower:
for j in range(1, _N):
for i in range(j):
A_copy[i, j] = 0.0
else:
for j in range(_N):
for i in range(j + 1, _N):
A_copy[i, j] = 0.0
return A_copy, int_ptr_to_val(INFO)
return impl
......
......@@ -6,10 +6,8 @@ import pytest
import pytensor
import pytensor.tensor as pt
from pytensor import config
from pytensor.compile import SharedVariable
from pytensor.graph import Constant, FunctionGraph
from pytensor.graph import FunctionGraph
from tests.link.numba.test_basic import compare_numba_and_py
from tests.tensor.test_extra_ops import set_test_value
numba = pytest.importorskip("numba")
......@@ -109,23 +107,22 @@ def test_solve_triangular_raises_on_nan_inf(value):
@pytest.mark.parametrize("lower", [True, False], ids=["lower=True", "lower=False"])
def test_numba_Cholesky(lower):
x = set_test_value(
pt.tensor(dtype=config.floatX, shape=(3, 3)),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype(config.floatX)),
)
@pytest.mark.parametrize("trans", [True, False], ids=["trans=True", "trans=False"])
def test_numba_Cholesky(lower, trans):
cov = pt.matrix("cov")
g = pt.linalg.cholesky(x, lower=lower)
g_fg = FunctionGraph(outputs=[g])
if trans:
cov_ = cov.T
else:
cov_ = cov
chol = pt.linalg.cholesky(cov_, lower=lower)
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
)
fg = FunctionGraph(outputs=[chol])
x = np.array([0.1, 0.2, 0.3])
val = np.eye(3) + x[None, :] * x[:, None]
compare_numba_and_py(fg, [val])
def test_numba_Cholesky_raises_on_nan_input():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论