提交 1efa92fd authored 作者: jessegrabowski's avatar jessegrabowski 提交者: Ricardo Vieira

allow complex inputs to numba cholesky

上级 6499a2c1
import numpy as np
from numba.core.extending import overload
from numba.core.types import Complex, Float
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from numba.types import Float
from scipy import linalg
from pytensor.link.numba.dispatch.linalg._LAPACK import (
......@@ -19,7 +19,7 @@ def _cholesky(a, lower=False, overwrite_a=False):
@overload(_cholesky)
def cholesky_impl(A, lower=0, overwrite_a=False):
ensure_lapack()
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="cholesky")
_check_linalg_matrix(A, ndim=2, dtype=(Float, Complex), func_name="cholesky")
dtype = A.dtype
numba_potrf = _LAPACK().numba_xpotrf(dtype)
......@@ -33,7 +33,9 @@ def cholesky_impl(A, lower=0, overwrite_a=False):
if overwrite_a and A.flags.f_contiguous:
A_copy = A
elif overwrite_a and A.flags.c_contiguous:
# We can work on the transpose of A directly
# c_contiguous A reinterpreted as f_contiguous is A^T.
# potrf(A^T, UPLO='U') produces U where U.T == L (the correct lower factor),
# even for complex Hermitian matrices. The .T return corrects the result.
A_copy = A.T
transposed = True
lower = not lower
......
......@@ -465,12 +465,20 @@ class TestDecompositions:
@pytest.mark.parametrize(
"overwrite_a", [False, True], ids=["no_overwrite", "overwrite_a"]
)
def test_cholesky(self, lower: bool, overwrite_a: bool):
cov = pt.matrix("cov")
@pytest.mark.parametrize("is_complex", [True, False], ids=["complex", "real"])
def test_cholesky(self, lower: bool, overwrite_a: bool, is_complex: bool):
complex_dtype = "complex64" if floatX.endswith("32") else "complex128"
dtype = complex_dtype if is_complex else floatX
cov = pt.matrix("cov", dtype=dtype)
chol = pt.linalg.cholesky(cov, lower=lower)
x = np.array([0.1, 0.2, 0.3]).astype(floatX)
val = np.eye(3).astype(floatX) + x[None, :] * x[:, None]
rng = np.random.default_rng(42)
x = rng.normal(size=(3, 3))
if is_complex:
x = x + 1j * rng.normal(size=(3, 3))
x = x.astype(dtype)
val = np.eye(3, dtype=dtype) + x @ x.conj().T
fn, res = compare_numba_and_py(
[In(cov, mutable=overwrite_a)],
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论