提交 1efa92fd 作者: jessegrabowski 提交者: Ricardo Vieira

allow complex inputs to numba cholesky

上级 6499a2c1
import numpy as np import numpy as np
from numba.core.extending import overload from numba.core.extending import overload
from numba.core.types import Complex, Float
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from numba.types import Float
from scipy import linalg from scipy import linalg
from pytensor.link.numba.dispatch.linalg._LAPACK import ( from pytensor.link.numba.dispatch.linalg._LAPACK import (
...@@ -19,7 +19,7 @@ def _cholesky(a, lower=False, overwrite_a=False): ...@@ -19,7 +19,7 @@ def _cholesky(a, lower=False, overwrite_a=False):
@overload(_cholesky) @overload(_cholesky)
def cholesky_impl(A, lower=0, overwrite_a=False): def cholesky_impl(A, lower=0, overwrite_a=False):
ensure_lapack() ensure_lapack()
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="cholesky") _check_linalg_matrix(A, ndim=2, dtype=(Float, Complex), func_name="cholesky")
dtype = A.dtype dtype = A.dtype
numba_potrf = _LAPACK().numba_xpotrf(dtype) numba_potrf = _LAPACK().numba_xpotrf(dtype)
...@@ -33,7 +33,9 @@ def cholesky_impl(A, lower=0, overwrite_a=False): ...@@ -33,7 +33,9 @@ def cholesky_impl(A, lower=0, overwrite_a=False):
if overwrite_a and A.flags.f_contiguous: if overwrite_a and A.flags.f_contiguous:
A_copy = A A_copy = A
elif overwrite_a and A.flags.c_contiguous: elif overwrite_a and A.flags.c_contiguous:
# We can work on the transpose of A directly # c_contiguous A reinterpreted as f_contiguous is A^T.
# potrf(A^T, UPLO='U') produces U where U.T == L (the correct lower factor),
# even for complex Hermitian matrices. The .T return corrects the result.
A_copy = A.T A_copy = A.T
transposed = True transposed = True
lower = not lower lower = not lower
......
...@@ -465,12 +465,20 @@ class TestDecompositions: ...@@ -465,12 +465,20 @@ class TestDecompositions:
@pytest.mark.parametrize( @pytest.mark.parametrize(
"overwrite_a", [False, True], ids=["no_overwrite", "overwrite_a"] "overwrite_a", [False, True], ids=["no_overwrite", "overwrite_a"]
) )
def test_cholesky(self, lower: bool, overwrite_a: bool): @pytest.mark.parametrize("is_complex", [True, False], ids=["complex", "real"])
cov = pt.matrix("cov") def test_cholesky(self, lower: bool, overwrite_a: bool, is_complex: bool):
complex_dtype = "complex64" if floatX.endswith("32") else "complex128"
dtype = complex_dtype if is_complex else floatX
cov = pt.matrix("cov", dtype=dtype)
chol = pt.linalg.cholesky(cov, lower=lower) chol = pt.linalg.cholesky(cov, lower=lower)
x = np.array([0.1, 0.2, 0.3]).astype(floatX) rng = np.random.default_rng(42)
val = np.eye(3).astype(floatX) + x[None, :] * x[:, None] x = rng.normal(size=(3, 3))
if is_complex:
x = x + 1j * rng.normal(size=(3, 3))
x = x.astype(dtype)
val = np.eye(3, dtype=dtype) + x @ x.conj().T
fn, res = compare_numba_and_py( fn, res = compare_numba_and_py(
[In(cov, mutable=overwrite_a)], [In(cov, mutable=overwrite_a)],
......
Markdown 格式
0%
您即将向此讨论添加 0 人。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论