提交 e3290af0 作者: jessegrabowski 提交者: Ricardo Vieira

allow complex inputs to numba solve_triangular

上级 e3f1e040
import numpy as np
from numba.core import types
from numba.core.extending import overload
from numba.core.types import Float
from numba.core.types import Complex, Float
from numba.np.linalg import ensure_lapack
from scipy import linalg
......@@ -46,16 +45,15 @@ def _solve_triangular(
def solve_triangular_impl(A, B, trans, lower, unit_diagonal, overwrite_b):
ensure_lapack()
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve_triangular")
_check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="solve_triangular")
_check_linalg_matrix(
A, ndim=2, dtype=(Float, Complex), func_name="solve_triangular"
)
_check_linalg_matrix(
B, ndim=(1, 2), dtype=(Float, Complex), func_name="solve_triangular"
)
_check_dtypes_match((A, B), func_name="solve_triangular")
dtype = A.dtype
numba_trtrs = _LAPACK().numba_xtrtrs(dtype)
if isinstance(dtype, types.Complex):
# If you want to make this work with complex numbers make sure you handle the c_contiguous trick correctly
raise TypeError(
"This function is not expected to work with complex numbers yet"
)
def impl(A, B, trans, lower, unit_diagonal, overwrite_b):
_N = np.int32(A.shape[-1])
......@@ -66,8 +64,8 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, overwrite_b):
if A.flags.f_contiguous or (A.flags.c_contiguous and trans in (0, 1)):
A_f = A
if A.flags.c_contiguous:
# An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous
# Is this valid for complex matrices that were .conj().mT by PyTensor?
# A c_contiguous matrix reinterpreted as f_contiguous is A^T (plain transpose, no conjugation).
# An upper/lower triangular A^T is lower/upper triangular, so we flip lower.
lower = not lower
trans = 1 - trans
else:
......
......@@ -341,8 +341,6 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
A_dtype, b_dtype = (i.type.numpy_dtype for i in node.inputs)
out_dtype = node.outputs[0].type.numpy_dtype
if A_dtype.kind == "c" or b_dtype.kind == "c":
return generate_fallback_impl(op, node=node, **kwargs)
must_cast_A = A_dtype != out_dtype
if must_cast_A and config.compiler_verbose:
print("SolveTriangular requires casting first input `A`") # noqa: T201
......
......@@ -180,15 +180,10 @@ class TestSolves:
is_complex: bool,
overwrite_b: bool,
):
if is_complex:
# TODO: Complex raises ValueError: To change to a dtype of a different size, the last axis must be contiguous,
# why?
pytest.skip("Complex inputs currently not supported to solve_triangular")
complex_dtype = "complex64" if floatX.endswith("32") else "complex128"
dtype = complex_dtype if is_complex else floatX
def A_func(x):
complex_dtype = "complex64" if floatX.endswith("32") else "complex128"
dtype = complex_dtype if is_complex else floatX
x = x @ x.conj().T
x_tri = scipy.linalg.cholesky(x, lower=lower).astype(dtype)
......@@ -197,12 +192,17 @@ class TestSolves:
return x_tri
A = pt.matrix("A", dtype=floatX)
b = pt.tensor("b", shape=b_shape, dtype=floatX)
A = pt.matrix("A", dtype=dtype)
b = pt.tensor("b", shape=b_shape, dtype=dtype)
rng = np.random.default_rng(418)
A_val = A_func(rng.normal(size=(5, 5))).astype(floatX)
b_val = rng.normal(size=b_shape).astype(floatX)
A_base = rng.normal(size=(5, 5))
if is_complex:
A_base = A_base + 1j * rng.normal(size=(5, 5))
A_val = A_func(A_base).astype(dtype)
b_val = rng.normal(size=b_shape).astype(dtype)
if is_complex:
b_val = b_val + 1j * rng.normal(size=b_shape).astype(dtype)
X = pt.linalg.solve_triangular(
A,
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论