提交 e3290af0 作者: jessegrabowski 提交者: Ricardo Vieira

allow complex inputs to numba solve_triangular

上级 e3f1e040
import numpy as np import numpy as np
from numba.core import types
from numba.core.extending import overload from numba.core.extending import overload
from numba.core.types import Float from numba.core.types import Complex, Float
from numba.np.linalg import ensure_lapack from numba.np.linalg import ensure_lapack
from scipy import linalg from scipy import linalg
...@@ -46,16 +45,15 @@ def _solve_triangular( ...@@ -46,16 +45,15 @@ def _solve_triangular(
def solve_triangular_impl(A, B, trans, lower, unit_diagonal, overwrite_b): def solve_triangular_impl(A, B, trans, lower, unit_diagonal, overwrite_b):
ensure_lapack() ensure_lapack()
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve_triangular") _check_linalg_matrix(
_check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="solve_triangular") A, ndim=2, dtype=(Float, Complex), func_name="solve_triangular"
)
_check_linalg_matrix(
B, ndim=(1, 2), dtype=(Float, Complex), func_name="solve_triangular"
)
_check_dtypes_match((A, B), func_name="solve_triangular") _check_dtypes_match((A, B), func_name="solve_triangular")
dtype = A.dtype dtype = A.dtype
numba_trtrs = _LAPACK().numba_xtrtrs(dtype) numba_trtrs = _LAPACK().numba_xtrtrs(dtype)
if isinstance(dtype, types.Complex):
# If you want to make this work with complex numbers make sure you handle the c_contiguous trick correctly
raise TypeError(
"This function is not expected to work with complex numbers yet"
)
def impl(A, B, trans, lower, unit_diagonal, overwrite_b): def impl(A, B, trans, lower, unit_diagonal, overwrite_b):
_N = np.int32(A.shape[-1]) _N = np.int32(A.shape[-1])
...@@ -66,8 +64,8 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, overwrite_b): ...@@ -66,8 +64,8 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, overwrite_b):
if A.flags.f_contiguous or (A.flags.c_contiguous and trans in (0, 1)): if A.flags.f_contiguous or (A.flags.c_contiguous and trans in (0, 1)):
A_f = A A_f = A
if A.flags.c_contiguous: if A.flags.c_contiguous:
# An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous # A c_contiguous matrix reinterpreted as f_contiguous is A^T (plain transpose, no conjugation).
# Is this valid for complex matrices that were .conj().mT by PyTensor? # An upper/lower triangular A^T is lower/upper triangular, so we flip lower.
lower = not lower lower = not lower
trans = 1 - trans trans = 1 - trans
else: else:
......
...@@ -341,8 +341,6 @@ def numba_funcify_SolveTriangular(op, node, **kwargs): ...@@ -341,8 +341,6 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
A_dtype, b_dtype = (i.type.numpy_dtype for i in node.inputs) A_dtype, b_dtype = (i.type.numpy_dtype for i in node.inputs)
out_dtype = node.outputs[0].type.numpy_dtype out_dtype = node.outputs[0].type.numpy_dtype
if A_dtype.kind == "c" or b_dtype.kind == "c":
return generate_fallback_impl(op, node=node, **kwargs)
must_cast_A = A_dtype != out_dtype must_cast_A = A_dtype != out_dtype
if must_cast_A and config.compiler_verbose: if must_cast_A and config.compiler_verbose:
print("SolveTriangular requires casting first input `A`") # noqa: T201 print("SolveTriangular requires casting first input `A`") # noqa: T201
......
...@@ -180,15 +180,10 @@ class TestSolves: ...@@ -180,15 +180,10 @@ class TestSolves:
is_complex: bool, is_complex: bool,
overwrite_b: bool, overwrite_b: bool,
): ):
if is_complex: complex_dtype = "complex64" if floatX.endswith("32") else "complex128"
# TODO: Complex raises ValueError: To change to a dtype of a different size, the last axis must be contiguous, dtype = complex_dtype if is_complex else floatX
# why?
pytest.skip("Complex inputs currently not supported to solve_triangular")
def A_func(x): def A_func(x):
complex_dtype = "complex64" if floatX.endswith("32") else "complex128"
dtype = complex_dtype if is_complex else floatX
x = x @ x.conj().T x = x @ x.conj().T
x_tri = scipy.linalg.cholesky(x, lower=lower).astype(dtype) x_tri = scipy.linalg.cholesky(x, lower=lower).astype(dtype)
...@@ -197,12 +192,17 @@ class TestSolves: ...@@ -197,12 +192,17 @@ class TestSolves:
return x_tri return x_tri
A = pt.matrix("A", dtype=floatX) A = pt.matrix("A", dtype=dtype)
b = pt.tensor("b", shape=b_shape, dtype=floatX) b = pt.tensor("b", shape=b_shape, dtype=dtype)
rng = np.random.default_rng(418) rng = np.random.default_rng(418)
A_val = A_func(rng.normal(size=(5, 5))).astype(floatX) A_base = rng.normal(size=(5, 5))
b_val = rng.normal(size=b_shape).astype(floatX) if is_complex:
A_base = A_base + 1j * rng.normal(size=(5, 5))
A_val = A_func(A_base).astype(dtype)
b_val = rng.normal(size=b_shape).astype(dtype)
if is_complex:
b_val = b_val + 1j * rng.normal(size=b_shape).astype(dtype)
X = pt.linalg.solve_triangular( X = pt.linalg.solve_triangular(
A, A,
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论