提交 9a15b2ef authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Numba QR: Support complex dtype inputs

上级 61380247
......@@ -3,6 +3,7 @@ import ctypes
import numpy as np
from numba.core import cgutils, types
from numba.core.extending import get_cython_function_address, intrinsic
from numba.core.types import Complex
from numba.np.linalg import ensure_lapack, get_blas_kind
......@@ -486,8 +487,7 @@ class _LAPACK:
Used in QR decomposition with pivoting.
"""
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "geqp3")
functype = ctypes.CFUNCTYPE(
None,
ctype_args = (
_ptr_int, # M
_ptr_int, # N
float_pointer, # A
......@@ -496,8 +496,20 @@ class _LAPACK:
float_pointer, # TAU
float_pointer, # WORK
_ptr_int, # LWORK
)
if isinstance(dtype, Complex):
ctype_args = (
*ctype_args,
float_pointer, # RWORK)
)
functype = ctypes.CFUNCTYPE(
None,
*ctype_args,
_ptr_int, # INFO
)
return functype(lapack_ptr)
@classmethod
......
......@@ -2,6 +2,7 @@ from typing import Literal
import numpy as np
from numba.core.extending import overload
from numba.core.types import Complex, Float
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from scipy.linalg import get_lapack_funcs, qr
......@@ -11,6 +12,7 @@ from pytensor.link.numba.dispatch.linalg._LAPACK import (
int_ptr_to_val,
val_to_int_ptr,
)
from pytensor.link.numba.dispatch.linalg.utils import _check_linalg_matrix
def _xgeqrf(A: np.ndarray, overwrite_a: bool, lwork: int):
......@@ -55,7 +57,7 @@ def xgeqrf_impl(A, overwrite_a, lwork):
geqrf(
val_to_int_ptr(M),
val_to_int_ptr(N),
A_copy.view(w_type).ctypes,
A_copy.T.view(w_type).T.ctypes,
LDA,
TAU.view(w_type).ctypes,
WORK.view(w_type).ctypes,
......@@ -107,7 +109,7 @@ def xgeqp3_impl(A, overwrite_a, lwork):
geqp3(
val_to_int_ptr(M),
val_to_int_ptr(N),
A_copy.view(w_type).ctypes,
A_copy.T.view(w_type).T.ctypes,
LDA,
JPVT.ctypes,
TAU.view(w_type).ctypes,
......@@ -160,7 +162,7 @@ def xorgqr_impl(A, tau, overwrite_a, lwork):
val_to_int_ptr(M),
val_to_int_ptr(N),
val_to_int_ptr(K),
A_copy.view(w_type).ctypes,
A_copy.T.view(w_type).T.ctypes,
LDA,
tau.view(w_type).ctypes,
WORK.view(w_type).ctypes,
......@@ -184,6 +186,7 @@ def _xungqr(A: np.ndarray, tau: np.ndarray, overwrite_a: bool, lwork: int):
@overload(_xungqr)
def xungqr_impl(A, tau, overwrite_a, lwork):
ensure_lapack()
_check_linalg_matrix(A, ndim=2, dtype=(Float, Complex), func_name="qr")
dtype = A.dtype
w_type = _get_underlying_float(dtype)
ungqr = _LAPACK().numba_xungqr(dtype)
......@@ -211,7 +214,7 @@ def xungqr_impl(A, tau, overwrite_a, lwork):
val_to_int_ptr(M),
val_to_int_ptr(N),
val_to_int_ptr(K),
A_copy.view(w_type).ctypes,
A_copy.T.view(w_type).T.ctypes,
LDA,
tau.view(w_type).ctypes,
WORK.view(w_type).ctypes,
......@@ -378,10 +381,18 @@ def qr_full_pivot_impl(
x, mode="full", pivoting=True, overwrite_a=False, check_finite=False, lwork=None
):
ensure_lapack()
_check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr")
dtype = x.dtype
is_complex = isinstance(dtype, Complex)
w_type = _get_underlying_float(dtype)
geqp3 = _LAPACK().numba_xgeqp3(dtype)
orgqr = _LAPACK().numba_xorgqr(dtype)
orgqr = (
_LAPACK().numba_xorgqr(dtype)
if isinstance(dtype, Float)
else _LAPACK().numba_xungqr(dtype)
)
def impl(
x,
......@@ -403,41 +414,71 @@ def qr_full_pivot_impl(
LDA = val_to_int_ptr(M)
TAU = np.empty(K, dtype=dtype)
JPVT = np.zeros(N, dtype=np.int32)
if is_complex:
RWORK = np.empty(2 * N, dtype=w_type)
if lwork is None:
lwork = -1
if lwork == -1:
WORK = np.empty(1, dtype=dtype)
if is_complex:
geqp3(
val_to_int_ptr(M),
val_to_int_ptr(N),
x_copy.T.view(w_type).T.ctypes,
LDA,
JPVT.ctypes,
TAU.view(w_type).ctypes,
WORK.view(w_type).ctypes,
val_to_int_ptr(-1), # LWORK
RWORK.ctypes,
val_to_int_ptr(1), # INFO
)
else:
geqp3(
val_to_int_ptr(M),
val_to_int_ptr(N),
x_copy.T.view(w_type).T.ctypes,
LDA,
JPVT.ctypes,
TAU.view(w_type).ctypes,
WORK.view(w_type).ctypes,
val_to_int_ptr(-1),
val_to_int_ptr(1),
)
lwork_val = int(WORK.item().real)
else:
lwork_val = lwork
WORK = np.empty(lwork_val, dtype=dtype)
INFO = val_to_int_ptr(1)
if is_complex:
geqp3(
val_to_int_ptr(M),
val_to_int_ptr(N),
x_copy.view(w_type).ctypes,
x_copy.T.view(w_type).T.ctypes,
LDA,
JPVT.ctypes,
TAU.view(w_type).ctypes,
WORK.view(w_type).ctypes,
val_to_int_ptr(-1),
val_to_int_ptr(1),
val_to_int_ptr(lwork_val),
RWORK.ctypes,
INFO,
)
lwork_val = int(WORK.item())
else:
lwork_val = lwork
WORK = np.empty(lwork_val, dtype=dtype)
INFO = val_to_int_ptr(1)
geqp3(
val_to_int_ptr(M),
val_to_int_ptr(N),
x_copy.view(w_type).ctypes,
LDA,
JPVT.ctypes,
TAU.view(w_type).ctypes,
WORK.view(w_type).ctypes,
val_to_int_ptr(lwork_val),
INFO,
)
geqp3(
val_to_int_ptr(M),
val_to_int_ptr(N),
x_copy.T.view(w_type).T.ctypes,
LDA,
JPVT.ctypes,
TAU.view(w_type).ctypes,
WORK.view(w_type).ctypes,
val_to_int_ptr(lwork_val),
INFO,
)
JPVT = (JPVT - 1).astype(np.int32)
if mode == "full" or M < N:
......@@ -460,14 +501,14 @@ def qr_full_pivot_impl(
val_to_int_ptr(M),
val_to_int_ptr(Q_in.shape[1]),
val_to_int_ptr(K),
Q_in.view(w_type).ctypes,
Q_in.T.view(w_type).T.ctypes,
val_to_int_ptr(M),
TAU.view(w_type).ctypes,
WORKQ.view(w_type).ctypes,
val_to_int_ptr(-1),
val_to_int_ptr(1),
)
lwork_q = int(WORKQ.item())
lwork_q = int(WORKQ.item().real)
else:
lwork_q = lwork
......@@ -478,7 +519,7 @@ def qr_full_pivot_impl(
val_to_int_ptr(M),
val_to_int_ptr(Q_in.shape[1]),
val_to_int_ptr(K),
Q_in.view(w_type).ctypes,
Q_in.T.view(w_type).T.ctypes,
val_to_int_ptr(M),
TAU.view(w_type).ctypes,
WORKQ.view(w_type).ctypes,
......@@ -495,10 +536,15 @@ def qr_full_no_pivot_impl(
x, mode="full", pivoting=False, overwrite_a=False, check_finite=False, lwork=None
):
ensure_lapack()
_check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr")
dtype = x.dtype
w_type = _get_underlying_float(dtype)
geqrf = _LAPACK().numba_xgeqrf(dtype)
orgqr = _LAPACK().numba_xorgqr(dtype)
orgqr = (
_LAPACK().numba_xorgqr(dtype)
if isinstance(dtype, Float)
else _LAPACK().numba_xungqr(dtype)
)
def impl(
x,
......@@ -528,14 +574,14 @@ def qr_full_no_pivot_impl(
geqrf(
val_to_int_ptr(M),
val_to_int_ptr(N),
x_copy.view(w_type).ctypes,
x_copy.T.view(w_type).T.ctypes,
LDA,
TAU.view(w_type).ctypes,
WORK.view(w_type).ctypes,
val_to_int_ptr(-1),
val_to_int_ptr(1),
)
lwork_val = int(WORK.item())
lwork_val = int(WORK.item().real)
else:
lwork_val = lwork
......@@ -545,7 +591,7 @@ def qr_full_no_pivot_impl(
geqrf(
val_to_int_ptr(M),
val_to_int_ptr(N),
x_copy.view(w_type).ctypes,
x_copy.T.view(w_type).T.ctypes,
LDA,
TAU.view(w_type).ctypes,
WORK.view(w_type).ctypes,
......@@ -573,14 +619,14 @@ def qr_full_no_pivot_impl(
val_to_int_ptr(M),
val_to_int_ptr(Q_in.shape[1]),
val_to_int_ptr(K),
Q_in.view(w_type).ctypes,
Q_in.T.view(w_type).T.ctypes,
val_to_int_ptr(M),
TAU.view(w_type).ctypes,
WORKQ.view(w_type).ctypes,
val_to_int_ptr(-1),
val_to_int_ptr(1),
)
lwork_q = int(WORKQ.item())
lwork_q = int(WORKQ.real.item())
else:
lwork_q = lwork
......@@ -591,7 +637,7 @@ def qr_full_no_pivot_impl(
val_to_int_ptr(M), # M
val_to_int_ptr(Q_in.shape[1]), # N
val_to_int_ptr(K), # K
Q_in.view(w_type).ctypes, # A
Q_in.T.view(w_type).T.ctypes, # A
val_to_int_ptr(M), # LDA
TAU.view(w_type).ctypes, # TAU
WORKQ.view(w_type).ctypes, # WORK
......@@ -608,6 +654,7 @@ def qr_r_pivot_impl(
x, mode="r", pivoting=True, overwrite_a=False, check_finite=False, lwork=None
):
ensure_lapack()
_check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr")
dtype = x.dtype
w_type = _get_underlying_float(dtype)
geqp3 = _LAPACK().numba_xgeqp3(dtype)
......@@ -640,7 +687,7 @@ def qr_r_pivot_impl(
geqp3(
val_to_int_ptr(M),
val_to_int_ptr(N),
x_copy.view(w_type).ctypes,
x_copy.T.view(w_type).T.ctypes,
LDA,
JPVT.ctypes,
TAU.view(w_type).ctypes,
......@@ -648,7 +695,7 @@ def qr_r_pivot_impl(
val_to_int_ptr(-1),
val_to_int_ptr(1),
)
lwork_val = int(WORK.item())
lwork_val = int(WORK.item().real)
else:
lwork_val = lwork
......@@ -658,7 +705,7 @@ def qr_r_pivot_impl(
geqp3(
val_to_int_ptr(M),
val_to_int_ptr(N),
x_copy.view(w_type).ctypes,
x_copy.T.view(w_type).T.ctypes,
LDA,
JPVT.ctypes,
TAU.view(w_type).ctypes,
......@@ -683,6 +730,7 @@ def qr_r_no_pivot_impl(
x, mode="r", pivoting=False, overwrite_a=False, check_finite=False, lwork=None
):
ensure_lapack()
_check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr")
dtype = x.dtype
w_type = _get_underlying_float(dtype)
geqrf = _LAPACK().numba_xgeqrf(dtype)
......@@ -714,14 +762,14 @@ def qr_r_no_pivot_impl(
geqrf(
val_to_int_ptr(M),
val_to_int_ptr(N),
x_copy.view(w_type).ctypes,
x_copy.T.view(w_type).T.ctypes,
LDA,
TAU.view(w_type).ctypes,
WORK.view(w_type).ctypes,
val_to_int_ptr(-1),
val_to_int_ptr(1),
)
lwork_val = int(WORK.item())
lwork_val = int(WORK.item().real)
else:
lwork_val = lwork
......@@ -731,7 +779,7 @@ def qr_r_no_pivot_impl(
geqrf(
val_to_int_ptr(M),
val_to_int_ptr(N),
x_copy.view(w_type).ctypes,
x_copy.T.view(w_type).T.ctypes,
LDA,
TAU.view(w_type).ctypes,
WORK.view(w_type).ctypes,
......@@ -755,6 +803,7 @@ def qr_raw_no_pivot_impl(
x, mode="raw", pivoting=False, overwrite_a=False, check_finite=False, lwork=None
):
ensure_lapack()
_check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr")
dtype = x.dtype
w_type = _get_underlying_float(dtype)
geqrf = _LAPACK().numba_xgeqrf(dtype)
......@@ -786,14 +835,14 @@ def qr_raw_no_pivot_impl(
geqrf(
val_to_int_ptr(M),
val_to_int_ptr(N),
x_copy.view(w_type).ctypes,
x_copy.T.view(w_type).T.ctypes,
LDA,
TAU.view(w_type).ctypes,
WORK.view(w_type).ctypes,
val_to_int_ptr(-1),
val_to_int_ptr(1),
)
lwork_val = int(WORK.item())
lwork_val = int(WORK.item().real)
else:
lwork_val = lwork
......@@ -803,7 +852,7 @@ def qr_raw_no_pivot_impl(
geqrf(
val_to_int_ptr(M),
val_to_int_ptr(N),
x_copy.view(w_type).ctypes,
x_copy.T.view(w_type).T.ctypes,
LDA,
TAU.view(w_type).ctypes,
WORK.view(w_type).ctypes,
......@@ -826,7 +875,11 @@ def qr_raw_pivot_impl(
x, mode="raw", pivoting=True, overwrite_a=False, check_finite=False, lwork=None
):
ensure_lapack()
_check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr")
dtype = x.dtype
is_complex = isinstance(dtype, Complex)
w_type = _get_underlying_float(dtype)
geqp3 = _LAPACK().numba_xgeqp3(dtype)
......@@ -850,40 +903,70 @@ def qr_raw_pivot_impl(
K = min(M, N)
TAU = np.empty(K, dtype=dtype)
JPVT = np.zeros(N, dtype=np.int32)
if is_complex:
RWORK = np.empty(2 * N, dtype=w_type)
if lwork is None:
lwork = -1
if lwork == -1:
WORK = np.empty(1, dtype=dtype)
if is_complex:
geqp3(
val_to_int_ptr(M),
val_to_int_ptr(N),
x_copy.T.view(w_type).T.ctypes,
LDA,
JPVT.ctypes,
TAU.view(w_type).ctypes,
WORK.view(w_type).ctypes,
val_to_int_ptr(-1), # LWORK
RWORK.ctypes,
val_to_int_ptr(1), # INFO
)
else:
geqp3(
val_to_int_ptr(M),
val_to_int_ptr(N),
x_copy.T.view(w_type).T.ctypes,
LDA,
JPVT.ctypes,
TAU.view(w_type).ctypes,
WORK.view(w_type).ctypes,
val_to_int_ptr(-1),
val_to_int_ptr(1),
)
lwork_val = int(WORK.item().real)
else:
lwork_val = lwork
WORK = np.empty(lwork_val, dtype=dtype)
INFO = val_to_int_ptr(1)
if is_complex:
geqp3(
val_to_int_ptr(M),
val_to_int_ptr(N),
x_copy.view(w_type).ctypes,
x_copy.T.view(w_type).T.ctypes,
LDA,
JPVT.ctypes,
TAU.view(w_type).ctypes,
WORK.view(w_type).ctypes,
val_to_int_ptr(-1),
val_to_int_ptr(1),
val_to_int_ptr(lwork_val),
RWORK.ctypes,
INFO,
)
lwork_val = int(WORK.item())
else:
lwork_val = lwork
WORK = np.empty(lwork_val, dtype=dtype)
INFO = val_to_int_ptr(1)
geqp3(
val_to_int_ptr(M),
val_to_int_ptr(N),
x_copy.view(w_type).ctypes,
LDA,
JPVT.ctypes,
TAU.view(w_type).ctypes,
WORK.view(w_type).ctypes,
val_to_int_ptr(lwork_val),
INFO,
)
geqp3(
val_to_int_ptr(M),
val_to_int_ptr(N),
x_copy.T.view(w_type).T.ctypes,
LDA,
JPVT.ctypes,
TAU.view(w_type).ctypes,
WORK.view(w_type).ctypes,
val_to_int_ptr(lwork_val),
INFO,
)
JPVT = (JPVT - 1).astype(np.int32)
......
......@@ -42,7 +42,6 @@ from pytensor.tensor.slinalg import (
Solve,
SolveTriangular,
)
from pytensor.tensor.type import complex_dtypes, integer_dtypes
@numba_funcify.register(Cholesky)
......@@ -418,12 +417,12 @@ def numba_funcify_QR(op, node, **kwargs):
pivoting = op.pivoting
overwrite_a = op.overwrite_a
dtype = node.inputs[0].dtype
if dtype in complex_dtypes:
return generate_fallback_impl(op, node=node, **kwargs)
in_dtype = node.inputs[0].type.numpy_dtype
integer_input = in_dtype.kind in "ibu"
if integer_input and config.compiler_verbose:
print("QR requires casting discrete input to float") # noqa: T201
integer_input = dtype in integer_dtypes
in_dtype = config.floatX if integer_input else dtype
out_dtype = node.outputs[0].type.numpy_dtype
@numba_basic.numba_njit
def qr(a):
......@@ -434,7 +433,7 @@ def numba_funcify_QR(op, node, **kwargs):
)
if integer_input:
a = a.astype(in_dtype)
a = a.astype(out_dtype)
if (mode == "full" or mode == "economic") and pivoting:
Q, R, P = _qr_full_pivot(
......
......@@ -1824,7 +1824,10 @@ class QR(Op):
K = None
in_dtype = x.type.numpy_dtype
out_dtype = np.dtype(f"f{in_dtype.itemsize}")
if in_dtype.kind in "ibu":
out_dtype = "float64" if in_dtype.itemsize > 2 else "float32"
else:
out_dtype = "float64" if in_dtype.itemsize > 4 else "float32"
match self.mode:
case "full":
......
......@@ -718,17 +718,21 @@ class TestDecompositions:
ids=["economic", "full_pivot", "r", "raw_pivot"],
)
@pytest.mark.parametrize(
"overwrite_a", [True, False], ids=["overwrite_a", "no_overwrite"]
"overwrite_a", [False, True], ids=["overwrite_a", "no_overwrite"]
)
def test_qr(self, mode, pivoting, overwrite_a):
@pytest.mark.parametrize("complex", (False, True))
def test_qr(self, mode, pivoting, overwrite_a, complex):
shape = (5, 5)
rng = np.random.default_rng()
A = pt.tensor(
"A",
shape=shape,
dtype=config.floatX,
dtype="complex128" if complex else "float64",
)
A_val = rng.normal(size=shape).astype(config.floatX)
if complex:
A_val = rng.normal(size=(*shape, 2)).view(dtype=A.dtype).squeeze(-1)
else:
A_val = rng.normal(size=shape).astype(A.dtype)
qr_outputs = pt.linalg.qr(A, mode=mode, pivoting=pivoting)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论