提交 9a15b2ef 作者: Ricardo Vieira 提交者: Ricardo Vieira

Numba QR: Support complex dtype inputs

上级 61380247
...@@ -3,6 +3,7 @@ import ctypes ...@@ -3,6 +3,7 @@ import ctypes
import numpy as np import numpy as np
from numba.core import cgutils, types from numba.core import cgutils, types
from numba.core.extending import get_cython_function_address, intrinsic from numba.core.extending import get_cython_function_address, intrinsic
from numba.core.types import Complex
from numba.np.linalg import ensure_lapack, get_blas_kind from numba.np.linalg import ensure_lapack, get_blas_kind
...@@ -486,8 +487,7 @@ class _LAPACK: ...@@ -486,8 +487,7 @@ class _LAPACK:
Used in QR decomposition with pivoting. Used in QR decomposition with pivoting.
""" """
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "geqp3") lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "geqp3")
functype = ctypes.CFUNCTYPE( ctype_args = (
None,
_ptr_int, # M _ptr_int, # M
_ptr_int, # N _ptr_int, # N
float_pointer, # A float_pointer, # A
...@@ -496,8 +496,20 @@ class _LAPACK: ...@@ -496,8 +496,20 @@ class _LAPACK:
float_pointer, # TAU float_pointer, # TAU
float_pointer, # WORK float_pointer, # WORK
_ptr_int, # LWORK _ptr_int, # LWORK
)
if isinstance(dtype, Complex):
ctype_args = (
*ctype_args,
float_pointer, # RWORK)
)
functype = ctypes.CFUNCTYPE(
None,
*ctype_args,
_ptr_int, # INFO _ptr_int, # INFO
) )
return functype(lapack_ptr) return functype(lapack_ptr)
@classmethod @classmethod
......
...@@ -2,6 +2,7 @@ from typing import Literal ...@@ -2,6 +2,7 @@ from typing import Literal
import numpy as np import numpy as np
from numba.core.extending import overload from numba.core.extending import overload
from numba.core.types import Complex, Float
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from scipy.linalg import get_lapack_funcs, qr from scipy.linalg import get_lapack_funcs, qr
...@@ -11,6 +12,7 @@ from pytensor.link.numba.dispatch.linalg._LAPACK import ( ...@@ -11,6 +12,7 @@ from pytensor.link.numba.dispatch.linalg._LAPACK import (
int_ptr_to_val, int_ptr_to_val,
val_to_int_ptr, val_to_int_ptr,
) )
from pytensor.link.numba.dispatch.linalg.utils import _check_linalg_matrix
def _xgeqrf(A: np.ndarray, overwrite_a: bool, lwork: int): def _xgeqrf(A: np.ndarray, overwrite_a: bool, lwork: int):
...@@ -55,7 +57,7 @@ def xgeqrf_impl(A, overwrite_a, lwork): ...@@ -55,7 +57,7 @@ def xgeqrf_impl(A, overwrite_a, lwork):
geqrf( geqrf(
val_to_int_ptr(M), val_to_int_ptr(M),
val_to_int_ptr(N), val_to_int_ptr(N),
A_copy.view(w_type).ctypes, A_copy.T.view(w_type).T.ctypes,
LDA, LDA,
TAU.view(w_type).ctypes, TAU.view(w_type).ctypes,
WORK.view(w_type).ctypes, WORK.view(w_type).ctypes,
...@@ -107,7 +109,7 @@ def xgeqp3_impl(A, overwrite_a, lwork): ...@@ -107,7 +109,7 @@ def xgeqp3_impl(A, overwrite_a, lwork):
geqp3( geqp3(
val_to_int_ptr(M), val_to_int_ptr(M),
val_to_int_ptr(N), val_to_int_ptr(N),
A_copy.view(w_type).ctypes, A_copy.T.view(w_type).T.ctypes,
LDA, LDA,
JPVT.ctypes, JPVT.ctypes,
TAU.view(w_type).ctypes, TAU.view(w_type).ctypes,
...@@ -160,7 +162,7 @@ def xorgqr_impl(A, tau, overwrite_a, lwork): ...@@ -160,7 +162,7 @@ def xorgqr_impl(A, tau, overwrite_a, lwork):
val_to_int_ptr(M), val_to_int_ptr(M),
val_to_int_ptr(N), val_to_int_ptr(N),
val_to_int_ptr(K), val_to_int_ptr(K),
A_copy.view(w_type).ctypes, A_copy.T.view(w_type).T.ctypes,
LDA, LDA,
tau.view(w_type).ctypes, tau.view(w_type).ctypes,
WORK.view(w_type).ctypes, WORK.view(w_type).ctypes,
...@@ -184,6 +186,7 @@ def _xungqr(A: np.ndarray, tau: np.ndarray, overwrite_a: bool, lwork: int): ...@@ -184,6 +186,7 @@ def _xungqr(A: np.ndarray, tau: np.ndarray, overwrite_a: bool, lwork: int):
@overload(_xungqr) @overload(_xungqr)
def xungqr_impl(A, tau, overwrite_a, lwork): def xungqr_impl(A, tau, overwrite_a, lwork):
ensure_lapack() ensure_lapack()
_check_linalg_matrix(A, ndim=2, dtype=(Float, Complex), func_name="qr")
dtype = A.dtype dtype = A.dtype
w_type = _get_underlying_float(dtype) w_type = _get_underlying_float(dtype)
ungqr = _LAPACK().numba_xungqr(dtype) ungqr = _LAPACK().numba_xungqr(dtype)
...@@ -211,7 +214,7 @@ def xungqr_impl(A, tau, overwrite_a, lwork): ...@@ -211,7 +214,7 @@ def xungqr_impl(A, tau, overwrite_a, lwork):
val_to_int_ptr(M), val_to_int_ptr(M),
val_to_int_ptr(N), val_to_int_ptr(N),
val_to_int_ptr(K), val_to_int_ptr(K),
A_copy.view(w_type).ctypes, A_copy.T.view(w_type).T.ctypes,
LDA, LDA,
tau.view(w_type).ctypes, tau.view(w_type).ctypes,
WORK.view(w_type).ctypes, WORK.view(w_type).ctypes,
...@@ -378,10 +381,18 @@ def qr_full_pivot_impl( ...@@ -378,10 +381,18 @@ def qr_full_pivot_impl(
x, mode="full", pivoting=True, overwrite_a=False, check_finite=False, lwork=None x, mode="full", pivoting=True, overwrite_a=False, check_finite=False, lwork=None
): ):
ensure_lapack() ensure_lapack()
_check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr")
dtype = x.dtype dtype = x.dtype
is_complex = isinstance(dtype, Complex)
w_type = _get_underlying_float(dtype) w_type = _get_underlying_float(dtype)
geqp3 = _LAPACK().numba_xgeqp3(dtype) geqp3 = _LAPACK().numba_xgeqp3(dtype)
orgqr = _LAPACK().numba_xorgqr(dtype) orgqr = (
_LAPACK().numba_xorgqr(dtype)
if isinstance(dtype, Float)
else _LAPACK().numba_xungqr(dtype)
)
def impl( def impl(
x, x,
...@@ -403,16 +414,32 @@ def qr_full_pivot_impl( ...@@ -403,16 +414,32 @@ def qr_full_pivot_impl(
LDA = val_to_int_ptr(M) LDA = val_to_int_ptr(M)
TAU = np.empty(K, dtype=dtype) TAU = np.empty(K, dtype=dtype)
JPVT = np.zeros(N, dtype=np.int32) JPVT = np.zeros(N, dtype=np.int32)
if is_complex:
RWORK = np.empty(2 * N, dtype=w_type)
if lwork is None: if lwork is None:
lwork = -1 lwork = -1
if lwork == -1: if lwork == -1:
WORK = np.empty(1, dtype=dtype) WORK = np.empty(1, dtype=dtype)
if is_complex:
geqp3(
val_to_int_ptr(M),
val_to_int_ptr(N),
x_copy.T.view(w_type).T.ctypes,
LDA,
JPVT.ctypes,
TAU.view(w_type).ctypes,
WORK.view(w_type).ctypes,
val_to_int_ptr(-1), # LWORK
RWORK.ctypes,
val_to_int_ptr(1), # INFO
)
else:
geqp3( geqp3(
val_to_int_ptr(M), val_to_int_ptr(M),
val_to_int_ptr(N), val_to_int_ptr(N),
x_copy.view(w_type).ctypes, x_copy.T.view(w_type).T.ctypes,
LDA, LDA,
JPVT.ctypes, JPVT.ctypes,
TAU.view(w_type).ctypes, TAU.view(w_type).ctypes,
...@@ -420,17 +447,31 @@ def qr_full_pivot_impl( ...@@ -420,17 +447,31 @@ def qr_full_pivot_impl(
val_to_int_ptr(-1), val_to_int_ptr(-1),
val_to_int_ptr(1), val_to_int_ptr(1),
) )
lwork_val = int(WORK.item()) lwork_val = int(WORK.item().real)
else: else:
lwork_val = lwork lwork_val = lwork
WORK = np.empty(lwork_val, dtype=dtype) WORK = np.empty(lwork_val, dtype=dtype)
INFO = val_to_int_ptr(1) INFO = val_to_int_ptr(1)
if is_complex:
geqp3(
val_to_int_ptr(M),
val_to_int_ptr(N),
x_copy.T.view(w_type).T.ctypes,
LDA,
JPVT.ctypes,
TAU.view(w_type).ctypes,
WORK.view(w_type).ctypes,
val_to_int_ptr(lwork_val),
RWORK.ctypes,
INFO,
)
else:
geqp3( geqp3(
val_to_int_ptr(M), val_to_int_ptr(M),
val_to_int_ptr(N), val_to_int_ptr(N),
x_copy.view(w_type).ctypes, x_copy.T.view(w_type).T.ctypes,
LDA, LDA,
JPVT.ctypes, JPVT.ctypes,
TAU.view(w_type).ctypes, TAU.view(w_type).ctypes,
...@@ -460,14 +501,14 @@ def qr_full_pivot_impl( ...@@ -460,14 +501,14 @@ def qr_full_pivot_impl(
val_to_int_ptr(M), val_to_int_ptr(M),
val_to_int_ptr(Q_in.shape[1]), val_to_int_ptr(Q_in.shape[1]),
val_to_int_ptr(K), val_to_int_ptr(K),
Q_in.view(w_type).ctypes, Q_in.T.view(w_type).T.ctypes,
val_to_int_ptr(M), val_to_int_ptr(M),
TAU.view(w_type).ctypes, TAU.view(w_type).ctypes,
WORKQ.view(w_type).ctypes, WORKQ.view(w_type).ctypes,
val_to_int_ptr(-1), val_to_int_ptr(-1),
val_to_int_ptr(1), val_to_int_ptr(1),
) )
lwork_q = int(WORKQ.item()) lwork_q = int(WORKQ.item().real)
else: else:
lwork_q = lwork lwork_q = lwork
...@@ -478,7 +519,7 @@ def qr_full_pivot_impl( ...@@ -478,7 +519,7 @@ def qr_full_pivot_impl(
val_to_int_ptr(M), val_to_int_ptr(M),
val_to_int_ptr(Q_in.shape[1]), val_to_int_ptr(Q_in.shape[1]),
val_to_int_ptr(K), val_to_int_ptr(K),
Q_in.view(w_type).ctypes, Q_in.T.view(w_type).T.ctypes,
val_to_int_ptr(M), val_to_int_ptr(M),
TAU.view(w_type).ctypes, TAU.view(w_type).ctypes,
WORKQ.view(w_type).ctypes, WORKQ.view(w_type).ctypes,
...@@ -495,10 +536,15 @@ def qr_full_no_pivot_impl( ...@@ -495,10 +536,15 @@ def qr_full_no_pivot_impl(
x, mode="full", pivoting=False, overwrite_a=False, check_finite=False, lwork=None x, mode="full", pivoting=False, overwrite_a=False, check_finite=False, lwork=None
): ):
ensure_lapack() ensure_lapack()
_check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr")
dtype = x.dtype dtype = x.dtype
w_type = _get_underlying_float(dtype) w_type = _get_underlying_float(dtype)
geqrf = _LAPACK().numba_xgeqrf(dtype) geqrf = _LAPACK().numba_xgeqrf(dtype)
orgqr = _LAPACK().numba_xorgqr(dtype) orgqr = (
_LAPACK().numba_xorgqr(dtype)
if isinstance(dtype, Float)
else _LAPACK().numba_xungqr(dtype)
)
def impl( def impl(
x, x,
...@@ -528,14 +574,14 @@ def qr_full_no_pivot_impl( ...@@ -528,14 +574,14 @@ def qr_full_no_pivot_impl(
geqrf( geqrf(
val_to_int_ptr(M), val_to_int_ptr(M),
val_to_int_ptr(N), val_to_int_ptr(N),
x_copy.view(w_type).ctypes, x_copy.T.view(w_type).T.ctypes,
LDA, LDA,
TAU.view(w_type).ctypes, TAU.view(w_type).ctypes,
WORK.view(w_type).ctypes, WORK.view(w_type).ctypes,
val_to_int_ptr(-1), val_to_int_ptr(-1),
val_to_int_ptr(1), val_to_int_ptr(1),
) )
lwork_val = int(WORK.item()) lwork_val = int(WORK.item().real)
else: else:
lwork_val = lwork lwork_val = lwork
...@@ -545,7 +591,7 @@ def qr_full_no_pivot_impl( ...@@ -545,7 +591,7 @@ def qr_full_no_pivot_impl(
geqrf( geqrf(
val_to_int_ptr(M), val_to_int_ptr(M),
val_to_int_ptr(N), val_to_int_ptr(N),
x_copy.view(w_type).ctypes, x_copy.T.view(w_type).T.ctypes,
LDA, LDA,
TAU.view(w_type).ctypes, TAU.view(w_type).ctypes,
WORK.view(w_type).ctypes, WORK.view(w_type).ctypes,
...@@ -573,14 +619,14 @@ def qr_full_no_pivot_impl( ...@@ -573,14 +619,14 @@ def qr_full_no_pivot_impl(
val_to_int_ptr(M), val_to_int_ptr(M),
val_to_int_ptr(Q_in.shape[1]), val_to_int_ptr(Q_in.shape[1]),
val_to_int_ptr(K), val_to_int_ptr(K),
Q_in.view(w_type).ctypes, Q_in.T.view(w_type).T.ctypes,
val_to_int_ptr(M), val_to_int_ptr(M),
TAU.view(w_type).ctypes, TAU.view(w_type).ctypes,
WORKQ.view(w_type).ctypes, WORKQ.view(w_type).ctypes,
val_to_int_ptr(-1), val_to_int_ptr(-1),
val_to_int_ptr(1), val_to_int_ptr(1),
) )
lwork_q = int(WORKQ.item()) lwork_q = int(WORKQ.real.item())
else: else:
lwork_q = lwork lwork_q = lwork
...@@ -591,7 +637,7 @@ def qr_full_no_pivot_impl( ...@@ -591,7 +637,7 @@ def qr_full_no_pivot_impl(
val_to_int_ptr(M), # M val_to_int_ptr(M), # M
val_to_int_ptr(Q_in.shape[1]), # N val_to_int_ptr(Q_in.shape[1]), # N
val_to_int_ptr(K), # K val_to_int_ptr(K), # K
Q_in.view(w_type).ctypes, # A Q_in.T.view(w_type).T.ctypes, # A
val_to_int_ptr(M), # LDA val_to_int_ptr(M), # LDA
TAU.view(w_type).ctypes, # TAU TAU.view(w_type).ctypes, # TAU
WORKQ.view(w_type).ctypes, # WORK WORKQ.view(w_type).ctypes, # WORK
...@@ -608,6 +654,7 @@ def qr_r_pivot_impl( ...@@ -608,6 +654,7 @@ def qr_r_pivot_impl(
x, mode="r", pivoting=True, overwrite_a=False, check_finite=False, lwork=None x, mode="r", pivoting=True, overwrite_a=False, check_finite=False, lwork=None
): ):
ensure_lapack() ensure_lapack()
_check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr")
dtype = x.dtype dtype = x.dtype
w_type = _get_underlying_float(dtype) w_type = _get_underlying_float(dtype)
geqp3 = _LAPACK().numba_xgeqp3(dtype) geqp3 = _LAPACK().numba_xgeqp3(dtype)
...@@ -640,7 +687,7 @@ def qr_r_pivot_impl( ...@@ -640,7 +687,7 @@ def qr_r_pivot_impl(
geqp3( geqp3(
val_to_int_ptr(M), val_to_int_ptr(M),
val_to_int_ptr(N), val_to_int_ptr(N),
x_copy.view(w_type).ctypes, x_copy.T.view(w_type).T.ctypes,
LDA, LDA,
JPVT.ctypes, JPVT.ctypes,
TAU.view(w_type).ctypes, TAU.view(w_type).ctypes,
...@@ -648,7 +695,7 @@ def qr_r_pivot_impl( ...@@ -648,7 +695,7 @@ def qr_r_pivot_impl(
val_to_int_ptr(-1), val_to_int_ptr(-1),
val_to_int_ptr(1), val_to_int_ptr(1),
) )
lwork_val = int(WORK.item()) lwork_val = int(WORK.item().real)
else: else:
lwork_val = lwork lwork_val = lwork
...@@ -658,7 +705,7 @@ def qr_r_pivot_impl( ...@@ -658,7 +705,7 @@ def qr_r_pivot_impl(
geqp3( geqp3(
val_to_int_ptr(M), val_to_int_ptr(M),
val_to_int_ptr(N), val_to_int_ptr(N),
x_copy.view(w_type).ctypes, x_copy.T.view(w_type).T.ctypes,
LDA, LDA,
JPVT.ctypes, JPVT.ctypes,
TAU.view(w_type).ctypes, TAU.view(w_type).ctypes,
...@@ -683,6 +730,7 @@ def qr_r_no_pivot_impl( ...@@ -683,6 +730,7 @@ def qr_r_no_pivot_impl(
x, mode="r", pivoting=False, overwrite_a=False, check_finite=False, lwork=None x, mode="r", pivoting=False, overwrite_a=False, check_finite=False, lwork=None
): ):
ensure_lapack() ensure_lapack()
_check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr")
dtype = x.dtype dtype = x.dtype
w_type = _get_underlying_float(dtype) w_type = _get_underlying_float(dtype)
geqrf = _LAPACK().numba_xgeqrf(dtype) geqrf = _LAPACK().numba_xgeqrf(dtype)
...@@ -714,14 +762,14 @@ def qr_r_no_pivot_impl( ...@@ -714,14 +762,14 @@ def qr_r_no_pivot_impl(
geqrf( geqrf(
val_to_int_ptr(M), val_to_int_ptr(M),
val_to_int_ptr(N), val_to_int_ptr(N),
x_copy.view(w_type).ctypes, x_copy.T.view(w_type).T.ctypes,
LDA, LDA,
TAU.view(w_type).ctypes, TAU.view(w_type).ctypes,
WORK.view(w_type).ctypes, WORK.view(w_type).ctypes,
val_to_int_ptr(-1), val_to_int_ptr(-1),
val_to_int_ptr(1), val_to_int_ptr(1),
) )
lwork_val = int(WORK.item()) lwork_val = int(WORK.item().real)
else: else:
lwork_val = lwork lwork_val = lwork
...@@ -731,7 +779,7 @@ def qr_r_no_pivot_impl( ...@@ -731,7 +779,7 @@ def qr_r_no_pivot_impl(
geqrf( geqrf(
val_to_int_ptr(M), val_to_int_ptr(M),
val_to_int_ptr(N), val_to_int_ptr(N),
x_copy.view(w_type).ctypes, x_copy.T.view(w_type).T.ctypes,
LDA, LDA,
TAU.view(w_type).ctypes, TAU.view(w_type).ctypes,
WORK.view(w_type).ctypes, WORK.view(w_type).ctypes,
...@@ -755,6 +803,7 @@ def qr_raw_no_pivot_impl( ...@@ -755,6 +803,7 @@ def qr_raw_no_pivot_impl(
x, mode="raw", pivoting=False, overwrite_a=False, check_finite=False, lwork=None x, mode="raw", pivoting=False, overwrite_a=False, check_finite=False, lwork=None
): ):
ensure_lapack() ensure_lapack()
_check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr")
dtype = x.dtype dtype = x.dtype
w_type = _get_underlying_float(dtype) w_type = _get_underlying_float(dtype)
geqrf = _LAPACK().numba_xgeqrf(dtype) geqrf = _LAPACK().numba_xgeqrf(dtype)
...@@ -786,14 +835,14 @@ def qr_raw_no_pivot_impl( ...@@ -786,14 +835,14 @@ def qr_raw_no_pivot_impl(
geqrf( geqrf(
val_to_int_ptr(M), val_to_int_ptr(M),
val_to_int_ptr(N), val_to_int_ptr(N),
x_copy.view(w_type).ctypes, x_copy.T.view(w_type).T.ctypes,
LDA, LDA,
TAU.view(w_type).ctypes, TAU.view(w_type).ctypes,
WORK.view(w_type).ctypes, WORK.view(w_type).ctypes,
val_to_int_ptr(-1), val_to_int_ptr(-1),
val_to_int_ptr(1), val_to_int_ptr(1),
) )
lwork_val = int(WORK.item()) lwork_val = int(WORK.item().real)
else: else:
lwork_val = lwork lwork_val = lwork
...@@ -803,7 +852,7 @@ def qr_raw_no_pivot_impl( ...@@ -803,7 +852,7 @@ def qr_raw_no_pivot_impl(
geqrf( geqrf(
val_to_int_ptr(M), val_to_int_ptr(M),
val_to_int_ptr(N), val_to_int_ptr(N),
x_copy.view(w_type).ctypes, x_copy.T.view(w_type).T.ctypes,
LDA, LDA,
TAU.view(w_type).ctypes, TAU.view(w_type).ctypes,
WORK.view(w_type).ctypes, WORK.view(w_type).ctypes,
...@@ -826,7 +875,11 @@ def qr_raw_pivot_impl( ...@@ -826,7 +875,11 @@ def qr_raw_pivot_impl(
x, mode="raw", pivoting=True, overwrite_a=False, check_finite=False, lwork=None x, mode="raw", pivoting=True, overwrite_a=False, check_finite=False, lwork=None
): ):
ensure_lapack() ensure_lapack()
_check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr")
dtype = x.dtype dtype = x.dtype
is_complex = isinstance(dtype, Complex)
w_type = _get_underlying_float(dtype) w_type = _get_underlying_float(dtype)
geqp3 = _LAPACK().numba_xgeqp3(dtype) geqp3 = _LAPACK().numba_xgeqp3(dtype)
...@@ -850,15 +903,31 @@ def qr_raw_pivot_impl( ...@@ -850,15 +903,31 @@ def qr_raw_pivot_impl(
K = min(M, N) K = min(M, N)
TAU = np.empty(K, dtype=dtype) TAU = np.empty(K, dtype=dtype)
JPVT = np.zeros(N, dtype=np.int32) JPVT = np.zeros(N, dtype=np.int32)
if is_complex:
RWORK = np.empty(2 * N, dtype=w_type)
if lwork is None: if lwork is None:
lwork = -1 lwork = -1
if lwork == -1: if lwork == -1:
WORK = np.empty(1, dtype=dtype) WORK = np.empty(1, dtype=dtype)
if is_complex:
geqp3( geqp3(
val_to_int_ptr(M), val_to_int_ptr(M),
val_to_int_ptr(N), val_to_int_ptr(N),
x_copy.view(w_type).ctypes, x_copy.T.view(w_type).T.ctypes,
LDA,
JPVT.ctypes,
TAU.view(w_type).ctypes,
WORK.view(w_type).ctypes,
val_to_int_ptr(-1), # LWORK
RWORK.ctypes,
val_to_int_ptr(1), # INFO
)
else:
geqp3(
val_to_int_ptr(M),
val_to_int_ptr(N),
x_copy.T.view(w_type).T.ctypes,
LDA, LDA,
JPVT.ctypes, JPVT.ctypes,
TAU.view(w_type).ctypes, TAU.view(w_type).ctypes,
...@@ -866,17 +935,31 @@ def qr_raw_pivot_impl( ...@@ -866,17 +935,31 @@ def qr_raw_pivot_impl(
val_to_int_ptr(-1), val_to_int_ptr(-1),
val_to_int_ptr(1), val_to_int_ptr(1),
) )
lwork_val = int(WORK.item()) lwork_val = int(WORK.item().real)
else: else:
lwork_val = lwork lwork_val = lwork
WORK = np.empty(lwork_val, dtype=dtype) WORK = np.empty(lwork_val, dtype=dtype)
INFO = val_to_int_ptr(1) INFO = val_to_int_ptr(1)
if is_complex:
geqp3(
val_to_int_ptr(M),
val_to_int_ptr(N),
x_copy.T.view(w_type).T.ctypes,
LDA,
JPVT.ctypes,
TAU.view(w_type).ctypes,
WORK.view(w_type).ctypes,
val_to_int_ptr(lwork_val),
RWORK.ctypes,
INFO,
)
else:
geqp3( geqp3(
val_to_int_ptr(M), val_to_int_ptr(M),
val_to_int_ptr(N), val_to_int_ptr(N),
x_copy.view(w_type).ctypes, x_copy.T.view(w_type).T.ctypes,
LDA, LDA,
JPVT.ctypes, JPVT.ctypes,
TAU.view(w_type).ctypes, TAU.view(w_type).ctypes,
......
...@@ -42,7 +42,6 @@ from pytensor.tensor.slinalg import ( ...@@ -42,7 +42,6 @@ from pytensor.tensor.slinalg import (
Solve, Solve,
SolveTriangular, SolveTriangular,
) )
from pytensor.tensor.type import complex_dtypes, integer_dtypes
@numba_funcify.register(Cholesky) @numba_funcify.register(Cholesky)
...@@ -418,12 +417,12 @@ def numba_funcify_QR(op, node, **kwargs): ...@@ -418,12 +417,12 @@ def numba_funcify_QR(op, node, **kwargs):
pivoting = op.pivoting pivoting = op.pivoting
overwrite_a = op.overwrite_a overwrite_a = op.overwrite_a
dtype = node.inputs[0].dtype in_dtype = node.inputs[0].type.numpy_dtype
if dtype in complex_dtypes: integer_input = in_dtype.kind in "ibu"
return generate_fallback_impl(op, node=node, **kwargs) if integer_input and config.compiler_verbose:
print("QR requires casting discrete input to float") # noqa: T201
integer_input = dtype in integer_dtypes out_dtype = node.outputs[0].type.numpy_dtype
in_dtype = config.floatX if integer_input else dtype
@numba_basic.numba_njit @numba_basic.numba_njit
def qr(a): def qr(a):
...@@ -434,7 +433,7 @@ def numba_funcify_QR(op, node, **kwargs): ...@@ -434,7 +433,7 @@ def numba_funcify_QR(op, node, **kwargs):
) )
if integer_input: if integer_input:
a = a.astype(in_dtype) a = a.astype(out_dtype)
if (mode == "full" or mode == "economic") and pivoting: if (mode == "full" or mode == "economic") and pivoting:
Q, R, P = _qr_full_pivot( Q, R, P = _qr_full_pivot(
......
...@@ -1824,7 +1824,10 @@ class QR(Op): ...@@ -1824,7 +1824,10 @@ class QR(Op):
K = None K = None
in_dtype = x.type.numpy_dtype in_dtype = x.type.numpy_dtype
out_dtype = np.dtype(f"f{in_dtype.itemsize}") if in_dtype.kind in "ibu":
out_dtype = "float64" if in_dtype.itemsize > 2 else "float32"
else:
out_dtype = "float64" if in_dtype.itemsize > 4 else "float32"
match self.mode: match self.mode:
case "full": case "full":
......
...@@ -718,17 +718,21 @@ class TestDecompositions: ...@@ -718,17 +718,21 @@ class TestDecompositions:
ids=["economic", "full_pivot", "r", "raw_pivot"], ids=["economic", "full_pivot", "r", "raw_pivot"],
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
"overwrite_a", [True, False], ids=["overwrite_a", "no_overwrite"] "overwrite_a", [False, True], ids=["overwrite_a", "no_overwrite"]
) )
def test_qr(self, mode, pivoting, overwrite_a): @pytest.mark.parametrize("complex", (False, True))
def test_qr(self, mode, pivoting, overwrite_a, complex):
shape = (5, 5) shape = (5, 5)
rng = np.random.default_rng() rng = np.random.default_rng()
A = pt.tensor( A = pt.tensor(
"A", "A",
shape=shape, shape=shape,
dtype=config.floatX, dtype="complex128" if complex else "float64",
) )
A_val = rng.normal(size=shape).astype(config.floatX) if complex:
A_val = rng.normal(size=(*shape, 2)).view(dtype=A.dtype).squeeze(-1)
else:
A_val = rng.normal(size=shape).astype(A.dtype)
qr_outputs = pt.linalg.qr(A, mode=mode, pivoting=pivoting) qr_outputs = pt.linalg.qr(A, mode=mode, pivoting=pivoting)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论