提交 6499a2c1 authored 作者: jessegrabowski's avatar jessegrabowski 提交者: Ricardo Vieira

allow complex inputs to numba solve

上级 b37dd14b
......@@ -2,7 +2,7 @@ from collections.abc import Callable
import numpy as np
from numba.core.extending import overload
from numba.core.types import Float
from numba.core.types import Complex, Float
from numba.np.linalg import ensure_lapack
from scipy import linalg
......@@ -47,8 +47,8 @@ def solve_gen_impl(
transposed: bool,
) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool], np.ndarray]:
ensure_lapack()
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve")
_check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="solve")
_check_linalg_matrix(A, ndim=2, dtype=(Float, Complex), func_name="solve")
_check_linalg_matrix(B, ndim=(1, 2), dtype=(Float, Complex), func_name="solve")
_check_dtypes_match((A, B), "solve")
def impl(
......
......@@ -2,7 +2,7 @@ from collections.abc import Callable
import numpy as np
from numba.core.extending import overload
from numba.core.types import Float
from numba.core.types import Complex, Float
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from scipy import linalg
......@@ -51,10 +51,11 @@ def solve_psd_impl(
transposed: bool,
) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool], np.ndarray]:
ensure_lapack()
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve")
_check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="solve")
_check_linalg_matrix(A, ndim=2, dtype=(Float, Complex), func_name="solve")
_check_linalg_matrix(B, ndim=(1, 2), dtype=(Float, Complex), func_name="solve")
_check_dtypes_match((A, B), func_name="solve")
numba_posv = _LAPACK().numba_xposv(A.dtype)
is_complex = isinstance(A.dtype, Complex)
def impl(
A: np.ndarray,
......@@ -67,10 +68,12 @@ def solve_psd_impl(
_solve_check_input_shapes(A, B)
_N = np.int32(A.shape[-1])
if overwrite_a and (A.flags.f_contiguous or A.flags.c_contiguous):
if overwrite_a and A.flags.f_contiguous:
A_copy = A
elif not is_complex and overwrite_a and A.flags.c_contiguous:
# For real symmetric matrices, c_contiguous A^T = A, so flipping lower is valid.
# Not valid for complex Hermitian where A^T = conj(A) != A.
A_copy = A
if A.flags.c_contiguous:
# An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous
lower = not lower
else:
A_copy = _copy_to_fortran_order(A)
......
......@@ -2,7 +2,7 @@ from collections.abc import Callable
import numpy as np
from numba.core.extending import overload
from numba.core.types import Float
from numba.core.types import Complex, Float
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from scipy import linalg
......@@ -51,8 +51,8 @@ def solve_symmetric_impl(
transposed: bool,
) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool], np.ndarray]:
ensure_lapack()
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve")
_check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="solve")
_check_linalg_matrix(A, ndim=2, dtype=(Float, Complex), func_name="solve")
_check_linalg_matrix(B, ndim=(1, 2), dtype=(Float, Complex), func_name="solve")
_check_dtypes_match((A, B), func_name="solve")
dtype = A.dtype
numba_sysv = _LAPACK().numba_xsysv(A.dtype)
......
......@@ -80,8 +80,6 @@ def numba_funcify_Cholesky(op, node, **kwargs):
overwrite_a = op.overwrite_a
inp_dtype = node.inputs[0].type.numpy_dtype
if inp_dtype.kind == "c":
return generate_fallback_impl(op, node=node, **kwargs)
discrete_inp = inp_dtype.kind in "ibu"
if discrete_inp and config.compiler_verbose:
print("Cholesky requires casting discrete input to float") # noqa: T201
......@@ -281,8 +279,8 @@ def numba_funcify_Solve(op, node, **kwargs):
A_dtype, b_dtype = (i.type.numpy_dtype for i in node.inputs)
out_dtype = node.outputs[0].type.numpy_dtype
if A_dtype.kind == "c" or b_dtype.kind == "c":
return generate_fallback_impl(op, node=node, **kwargs)
assume_a = op.assume_a
must_cast_A = A_dtype != out_dtype
if must_cast_A and config.compiler_verbose:
print("Solve requires casting first input `A`") # noqa: T201
......@@ -378,8 +376,6 @@ def numba_funcify_CholeskySolve(op, node, **kwargs):
c_dtype, b_dtype = (i.type.numpy_dtype for i in node.inputs)
out_dtype = node.outputs[0].type.numpy_dtype
if c_dtype.kind == "c" or b_dtype.kind == "c":
return generate_fallback_impl(op, node=node, **kwargs)
must_cast_c = c_dtype != out_dtype
if must_cast_c and config.compiler_verbose:
print("CholeskySolve requires casting first input `c`") # noqa: T201
......
......@@ -50,6 +50,7 @@ class TestSolves:
ids=["b_col_vec", "b_matrix", "b_vec"],
)
@pytest.mark.parametrize("assume_a", ["gen", "sym", "pos", "tridiagonal"], ids=str)
@pytest.mark.parametrize("is_complex", [True, False], ids=["complex", "real"])
def test_solve(
self,
b_shape: tuple[int],
......@@ -57,14 +58,18 @@ class TestSolves:
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
is_complex: bool,
):
if assume_a not in ("sym", "her", "pos", "tridiagonal") and not lower:
# Avoid redundant tests with lower=True and lower=False for non symmetric matrices
pytest.skip("Skipping redundant test already covered by lower=True")
complex_dtype = "complex64" if floatX.endswith("32") else "complex128"
dtype = complex_dtype if is_complex else floatX
def A_func(x):
if assume_a == "pos":
x = x @ x.T
x = x @ x.conj().T
x = np.tril(x) if lower else np.triu(x)
elif assume_a == "sym":
x = (x + x.T) / 2
......@@ -82,12 +87,17 @@ class TestSolves:
x[arange_n[:-1], arange_n[1:]] = np.diag(_x, k=1)
return x
A = pt.matrix("A", dtype=floatX)
b = pt.tensor("b", shape=b_shape, dtype=floatX)
A = pt.matrix("A", dtype=dtype)
b = pt.tensor("b", shape=b_shape, dtype=dtype)
rng = np.random.default_rng(418)
A_val = A_func(rng.normal(size=(5, 5))).astype(floatX)
b_val = rng.normal(size=b_shape).astype(floatX)
A_base = rng.normal(size=(5, 5))
if is_complex:
A_base = A_base + 1j * rng.normal(size=(5, 5))
A_val = A_func(A_base).astype(dtype)
b_val = rng.normal(size=b_shape).astype(dtype)
if is_complex:
b_val = b_val + 1j * rng.normal(size=b_shape).astype(dtype)
X = pt.linalg.solve(
A,
......@@ -139,8 +149,12 @@ class TestSolves:
b_val_c_contig = np.copy(b_val, order="C")
res_c_contig = f(A_val_c_contig, b_val_c_contig)
np.testing.assert_allclose(res_c_contig, res)
# We can destroy C-contiguous A arrays by inverting `tranpose/lower` at runtime
assert np.allclose(A_val_c_contig, A_val) == (not overwrite_a)
# We can destroy C-contiguous A arrays by inverting `transpose/lower` at runtime
# Complex posdef/hermitian can't use this trick (A^T = conj(A) != A for Hermitian)
can_destroy_c_contig_A = overwrite_a and not (
is_complex and assume_a in ("pos",)
)
assert np.allclose(A_val_c_contig, A_val) == (not can_destroy_c_contig_A)
# b vectors are always f_contiguous if also c_contiguous
assert np.allclose(b_val_c_contig, b_val) == (
not (overwrite_b and b_val_c_contig.flags.f_contiguous)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论