提交 9a15b2ef authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Numba QR: Support complex dtype inputs

上级 61380247
...@@ -3,6 +3,7 @@ import ctypes ...@@ -3,6 +3,7 @@ import ctypes
import numpy as np import numpy as np
from numba.core import cgutils, types from numba.core import cgutils, types
from numba.core.extending import get_cython_function_address, intrinsic from numba.core.extending import get_cython_function_address, intrinsic
from numba.core.types import Complex
from numba.np.linalg import ensure_lapack, get_blas_kind from numba.np.linalg import ensure_lapack, get_blas_kind
...@@ -486,8 +487,7 @@ class _LAPACK: ...@@ -486,8 +487,7 @@ class _LAPACK:
Used in QR decomposition with pivoting. Used in QR decomposition with pivoting.
""" """
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "geqp3") lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "geqp3")
functype = ctypes.CFUNCTYPE( ctype_args = (
None,
_ptr_int, # M _ptr_int, # M
_ptr_int, # N _ptr_int, # N
float_pointer, # A float_pointer, # A
...@@ -496,8 +496,20 @@ class _LAPACK: ...@@ -496,8 +496,20 @@ class _LAPACK:
float_pointer, # TAU float_pointer, # TAU
float_pointer, # WORK float_pointer, # WORK
_ptr_int, # LWORK _ptr_int, # LWORK
)
if isinstance(dtype, Complex):
ctype_args = (
*ctype_args,
float_pointer, # RWORK)
)
functype = ctypes.CFUNCTYPE(
None,
*ctype_args,
_ptr_int, # INFO _ptr_int, # INFO
) )
return functype(lapack_ptr) return functype(lapack_ptr)
@classmethod @classmethod
......
...@@ -42,7 +42,6 @@ from pytensor.tensor.slinalg import ( ...@@ -42,7 +42,6 @@ from pytensor.tensor.slinalg import (
Solve, Solve,
SolveTriangular, SolveTriangular,
) )
from pytensor.tensor.type import complex_dtypes, integer_dtypes
@numba_funcify.register(Cholesky) @numba_funcify.register(Cholesky)
...@@ -418,12 +417,12 @@ def numba_funcify_QR(op, node, **kwargs): ...@@ -418,12 +417,12 @@ def numba_funcify_QR(op, node, **kwargs):
pivoting = op.pivoting pivoting = op.pivoting
overwrite_a = op.overwrite_a overwrite_a = op.overwrite_a
dtype = node.inputs[0].dtype in_dtype = node.inputs[0].type.numpy_dtype
if dtype in complex_dtypes: integer_input = in_dtype.kind in "ibu"
return generate_fallback_impl(op, node=node, **kwargs) if integer_input and config.compiler_verbose:
print("QR requires casting discrete input to float") # noqa: T201
integer_input = dtype in integer_dtypes out_dtype = node.outputs[0].type.numpy_dtype
in_dtype = config.floatX if integer_input else dtype
@numba_basic.numba_njit @numba_basic.numba_njit
def qr(a): def qr(a):
...@@ -434,7 +433,7 @@ def numba_funcify_QR(op, node, **kwargs): ...@@ -434,7 +433,7 @@ def numba_funcify_QR(op, node, **kwargs):
) )
if integer_input: if integer_input:
a = a.astype(in_dtype) a = a.astype(out_dtype)
if (mode == "full" or mode == "economic") and pivoting: if (mode == "full" or mode == "economic") and pivoting:
Q, R, P = _qr_full_pivot( Q, R, P = _qr_full_pivot(
......
...@@ -1824,7 +1824,10 @@ class QR(Op): ...@@ -1824,7 +1824,10 @@ class QR(Op):
K = None K = None
in_dtype = x.type.numpy_dtype in_dtype = x.type.numpy_dtype
out_dtype = np.dtype(f"f{in_dtype.itemsize}") if in_dtype.kind in "ibu":
out_dtype = "float64" if in_dtype.itemsize > 2 else "float32"
else:
out_dtype = "float64" if in_dtype.itemsize > 4 else "float32"
match self.mode: match self.mode:
case "full": case "full":
......
...@@ -718,17 +718,21 @@ class TestDecompositions: ...@@ -718,17 +718,21 @@ class TestDecompositions:
ids=["economic", "full_pivot", "r", "raw_pivot"], ids=["economic", "full_pivot", "r", "raw_pivot"],
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
"overwrite_a", [True, False], ids=["overwrite_a", "no_overwrite"] "overwrite_a", [False, True], ids=["overwrite_a", "no_overwrite"]
) )
def test_qr(self, mode, pivoting, overwrite_a): @pytest.mark.parametrize("complex", (False, True))
def test_qr(self, mode, pivoting, overwrite_a, complex):
shape = (5, 5) shape = (5, 5)
rng = np.random.default_rng() rng = np.random.default_rng()
A = pt.tensor( A = pt.tensor(
"A", "A",
shape=shape, shape=shape,
dtype=config.floatX, dtype="complex128" if complex else "float64",
) )
A_val = rng.normal(size=shape).astype(config.floatX) if complex:
A_val = rng.normal(size=(*shape, 2)).view(dtype=A.dtype).squeeze(-1)
else:
A_val = rng.normal(size=shape).astype(A.dtype)
qr_outputs = pt.linalg.qr(A, mode=mode, pivoting=pivoting) qr_outputs = pt.linalg.qr(A, mode=mode, pivoting=pivoting)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论