提交 9a15b2ef authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Numba QR: Support complex dtype inputs

上级 61380247
......@@ -3,6 +3,7 @@ import ctypes
import numpy as np
from numba.core import cgutils, types
from numba.core.extending import get_cython_function_address, intrinsic
from numba.core.types import Complex
from numba.np.linalg import ensure_lapack, get_blas_kind
......@@ -486,8 +487,7 @@ class _LAPACK:
Used in QR decomposition with pivoting.
"""
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "geqp3")
functype = ctypes.CFUNCTYPE(
None,
ctype_args = (
_ptr_int, # M
_ptr_int, # N
float_pointer, # A
......@@ -496,8 +496,20 @@ class _LAPACK:
float_pointer, # TAU
float_pointer, # WORK
_ptr_int, # LWORK
)
if isinstance(dtype, Complex):
ctype_args = (
*ctype_args,
float_pointer, # RWORK)
)
functype = ctypes.CFUNCTYPE(
None,
*ctype_args,
_ptr_int, # INFO
)
return functype(lapack_ptr)
@classmethod
......
......@@ -42,7 +42,6 @@ from pytensor.tensor.slinalg import (
Solve,
SolveTriangular,
)
from pytensor.tensor.type import complex_dtypes, integer_dtypes
@numba_funcify.register(Cholesky)
......@@ -418,12 +417,12 @@ def numba_funcify_QR(op, node, **kwargs):
pivoting = op.pivoting
overwrite_a = op.overwrite_a
dtype = node.inputs[0].dtype
if dtype in complex_dtypes:
return generate_fallback_impl(op, node=node, **kwargs)
in_dtype = node.inputs[0].type.numpy_dtype
integer_input = in_dtype.kind in "ibu"
if integer_input and config.compiler_verbose:
print("QR requires casting discrete input to float") # noqa: T201
integer_input = dtype in integer_dtypes
in_dtype = config.floatX if integer_input else dtype
out_dtype = node.outputs[0].type.numpy_dtype
@numba_basic.numba_njit
def qr(a):
......@@ -434,7 +433,7 @@ def numba_funcify_QR(op, node, **kwargs):
)
if integer_input:
a = a.astype(in_dtype)
a = a.astype(out_dtype)
if (mode == "full" or mode == "economic") and pivoting:
Q, R, P = _qr_full_pivot(
......
......@@ -1824,7 +1824,10 @@ class QR(Op):
K = None
in_dtype = x.type.numpy_dtype
out_dtype = np.dtype(f"f{in_dtype.itemsize}")
if in_dtype.kind in "ibu":
out_dtype = "float64" if in_dtype.itemsize > 2 else "float32"
else:
out_dtype = "float64" if in_dtype.itemsize > 4 else "float32"
match self.mode:
case "full":
......
......@@ -718,17 +718,21 @@ class TestDecompositions:
ids=["economic", "full_pivot", "r", "raw_pivot"],
)
@pytest.mark.parametrize(
"overwrite_a", [True, False], ids=["overwrite_a", "no_overwrite"]
"overwrite_a", [False, True], ids=["overwrite_a", "no_overwrite"]
)
def test_qr(self, mode, pivoting, overwrite_a):
@pytest.mark.parametrize("complex", (False, True))
def test_qr(self, mode, pivoting, overwrite_a, complex):
shape = (5, 5)
rng = np.random.default_rng()
A = pt.tensor(
"A",
shape=shape,
dtype=config.floatX,
dtype="complex128" if complex else "float64",
)
A_val = rng.normal(size=shape).astype(config.floatX)
if complex:
A_val = rng.normal(size=(*shape, 2)).view(dtype=A.dtype).squeeze(-1)
else:
A_val = rng.normal(size=shape).astype(A.dtype)
qr_outputs = pt.linalg.qr(A, mode=mode, pivoting=pivoting)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论