Commit 9f911e35 authored by jessegrabowski, committed by Jesse Grabowski

Add Numba dispatch for QZ

Parent: ad8dca48
...@@ -25,6 +25,16 @@ from pytensor.link.numba.dispatch.linalg.decomposition.qr import ( ...@@ -25,6 +25,16 @@ from pytensor.link.numba.dispatch.linalg.decomposition.qr import (
_qr_raw_no_pivot, _qr_raw_no_pivot,
_qr_raw_pivot, _qr_raw_pivot,
) )
from pytensor.link.numba.dispatch.linalg.decomposition.qz import (
_qz_complex_nosort_eig,
_qz_complex_nosort_noeig,
_qz_complex_sort_eig,
_qz_complex_sort_noeig,
_qz_real_nosort_eig,
_qz_real_nosort_noeig,
_qz_real_sort_eig,
_qz_real_sort_noeig,
)
from pytensor.link.numba.dispatch.linalg.decomposition.schur import ( from pytensor.link.numba.dispatch.linalg.decomposition.schur import (
schur_complex, schur_complex,
schur_real, schur_real,
...@@ -46,6 +56,7 @@ from pytensor.tensor._linalg.solve.linear_control import TRSYL ...@@ -46,6 +56,7 @@ from pytensor.tensor._linalg.solve.linear_control import TRSYL
from pytensor.tensor.slinalg import ( from pytensor.tensor.slinalg import (
LU, LU,
QR, QR,
QZ,
BlockDiagonal, BlockDiagonal,
Cholesky, Cholesky,
CholeskySolve, CholeskySolve,
...@@ -535,6 +546,94 @@ def numba_funcify_Schur(op, node, **kwargs): ...@@ -535,6 +546,94 @@ def numba_funcify_Schur(op, node, **kwargs):
return schur, cache_version return schur, cache_version
@register_funcify_default_op_cache_key(QZ)
def numba_funcify_QZ(op, node, **kwargs):
    """Dispatch the QZ (generalized Schur) Op to a Numba implementation.

    Selects one of eight pre-compiled kernels based on three compile-time
    flags (complex vs. real arithmetic, sorted vs. unsorted, eigenvalues
    returned or not) and wraps it with the input casting needed to match
    the output dtype.

    Returns
    -------
    (callable, int)
        The njit-compiled ``qz(a, b)`` function and its cache version.
    """
    complex_output = op.complex_output
    sort = op.sort
    return_eigenvalues = op.return_eigenvalues
    overwrite_a = op.overwrite_a
    overwrite_b = op.overwrite_b

    in_dtype_a = node.inputs[0].type.numpy_dtype
    in_dtype_b = node.inputs[1].type.numpy_dtype
    out_dtype = node.outputs[0].type.numpy_dtype

    integer_input_a = in_dtype_a.kind in "ibu"
    integer_input_b = in_dtype_b.kind in "ibu"
    complex_input = in_dtype_a.kind == "c" or in_dtype_b.kind == "c"
    # Real floating inputs with a requested complex output must be upcast.
    # NOTE(review): numpy dtype kinds are single characters and "d" is not a
    # kind, so "fd" is effectively just checking kind == "f".
    needs_complex_cast = (
        in_dtype_a.kind in "fd" or in_dtype_b.kind in "fd"
    ) and complex_output

    # A dtype conversion allocates a fresh array, so overwriting the
    # caller's buffers in place is impossible; disable it (and say so).
    if needs_complex_cast:
        overwrite_a = False
        overwrite_b = False
        if config.compiler_verbose:
            print(  # noqa: T201
                "QZ: disabling overwrite_a/b due to dtype conversion (casting prevents in-place operation)"
            )

    if (integer_input_a or integer_input_b) and config.compiler_verbose:
        print("QZ requires casting discrete input to float")  # noqa: T201

    use_complex = complex_input or complex_output
    use_sort = sort is not None

    # Map the three boolean flags to the matching pre-compiled kernel.
    qz_fn = {
        (True, True, True): _qz_complex_sort_eig,
        (True, True, False): _qz_complex_sort_noeig,
        (True, False, True): _qz_complex_nosort_eig,
        (True, False, False): _qz_complex_nosort_noeig,
        (False, True, True): _qz_real_sort_eig,
        (False, True, False): _qz_real_sort_noeig,
        (False, False, True): _qz_real_nosort_eig,
        (False, False, False): _qz_real_nosort_noeig,
    }[(use_complex, use_sort, return_eigenvalues)]

    # Integer inputs and real->complex upcasts need the same astype call;
    # Numba prunes the branch entirely when both closure flags are False.
    if use_sort:

        @numba_basic.numba_njit
        def qz(a, b):
            if integer_input_a or needs_complex_cast:
                a = a.astype(out_dtype)
            if integer_input_b or needs_complex_cast:
                b = b.astype(out_dtype)
            return qz_fn(a, b, sort, overwrite_a, overwrite_b)

    else:

        @numba_basic.numba_njit
        def qz(a, b):
            if integer_input_a or needs_complex_cast:
                a = a.astype(out_dtype)
            if integer_input_b or needs_complex_cast:
                b = b.astype(out_dtype)
            return qz_fn(a, b, overwrite_a, overwrite_b)

    cache_version = 1
    return qz, cache_version
@register_funcify_default_op_cache_key(TRSYL) @register_funcify_default_op_cache_key(TRSYL)
def numba_funcify_TRSYL(op, node, **kwargs): def numba_funcify_TRSYL(op, node, **kwargs):
in_dtype_a = node.inputs[0].type.numpy_dtype in_dtype_a = node.inputs[0].type.numpy_dtype
......
...@@ -20,6 +20,7 @@ from pytensor.tensor.slinalg import ( ...@@ -20,6 +20,7 @@ from pytensor.tensor.slinalg import (
lu, lu,
lu_factor, lu_factor,
lu_solve, lu_solve,
qz,
schur, schur,
solve, solve,
solve_triangular, solve_triangular,
...@@ -793,6 +794,103 @@ class TestDecompositions: ...@@ -793,6 +794,103 @@ class TestDecompositions:
np.testing.assert_allclose(Z_c, Z_res, atol=1e-6) np.testing.assert_allclose(Z_c, Z_res, atol=1e-6)
np.testing.assert_allclose(val_c_contig, A_val) np.testing.assert_allclose(val_c_contig, A_val)
@pytest.mark.parametrize(
    "output, input_type, sort, return_eigenvalues",
    [
        ("real", "real", None, False),
        ("complex", "real", "lhp", True),
        ("real", "complex", "ouc", False),
        ("complex", "complex", None, True),
        ("real", "real", "iuc", True),
    ],
    ids=[
        "real_nosort",
        "real_to_complex_sort",
        "complex_sort",
        "complex_nosort_eig",
        "real_sort_eig",
    ],
)
def test_qz(self, output, input_type, sort, return_eigenvalues):
    """Check the Numba QZ dispatch against the Python reference.

    Verifies the real/complex dtype of every output, the reconstruction
    identities ``A = Q @ AA @ Z^H`` and ``B = Q @ BB @ Z^H``, and that
    the compiled function returns identical results for F- and
    C-contiguous copies of the inputs (in-place overwriting must not
    corrupt the decomposition).
    """
    shape = (5, 5)
    dtype = (
        config.floatX
        if input_type == "real"
        else ("complex64" if config.floatX.endswith("32") else "complex128")
    )

    A = pt.tensor("A", shape=shape, dtype=dtype)
    B = pt.tensor("B", shape=shape, dtype=dtype)
    outputs = qz(
        A, B, output=output, sort=sort, return_eigenvalues=return_eigenvalues
    )
    if return_eigenvalues:
        AA, BB, alpha, beta, Q, Z = outputs
        output_list = [AA, BB, alpha, beta, Q, Z]
    else:
        AA, BB, Q, Z = outputs
        output_list = [AA, BB, Q, Z]

    # Seed the generator so any failure is reproducible across runs.
    rng = np.random.default_rng(sum(map(ord, "test_qz")))
    A_val = rng.normal(size=shape).astype(dtype)
    B_val = rng.normal(size=shape).astype(dtype)

    fn, res = compare_numba_and_py(
        [A, B],
        output_list,
        [A_val, B_val],
        numba_mode=numba_inplace_mode,
        inplace=True,
    )

    if return_eigenvalues:
        AA_res, BB_res, alpha_res, beta_res, Q_res, Z_res = res
    else:
        AA_res, BB_res, Q_res, Z_res = res

    expected_complex_output = input_type == "complex" or output == "complex"
    for out_val in (AA_res, BB_res, Q_res, Z_res):
        assert np.iscomplexobj(out_val) == expected_complex_output

    # Verify reconstruction: Q @ AA @ Z.conj().T = A, Q @ BB @ Z.conj().T = B
    A_rebuilt = Q_res @ AA_res @ Z_res.conj().T
    B_rebuilt = Q_res @ BB_res @ Z_res.conj().T
    np.testing.assert_allclose(A_val, A_rebuilt, atol=1e-5, rtol=1e-5)
    np.testing.assert_allclose(B_val, B_rebuilt, atol=1e-5, rtol=1e-5)

    # Results must not depend on the memory layout of the inputs; fresh
    # copies also guard against the in-place mode clobbering A_val/B_val.
    for order in ("F", "C"):
        res_contig = fn(np.copy(A_val, order=order), np.copy(B_val, order=order))
        if return_eigenvalues:
            AA_c, BB_c, alpha_c, beta_c, Q_c, Z_c = res_contig
            np.testing.assert_allclose(alpha_c, alpha_res, atol=1e-6)
            np.testing.assert_allclose(beta_c, beta_res, atol=1e-6)
        else:
            AA_c, BB_c, Q_c, Z_c = res_contig
        np.testing.assert_allclose(AA_c, AA_res, atol=1e-6)
        np.testing.assert_allclose(BB_c, BB_res, atol=1e-6)
        np.testing.assert_allclose(Q_c, Q_res, atol=1e-6)
        np.testing.assert_allclose(Z_c, Z_res, atol=1e-6)
def test_block_diag(): def test_block_diag():
A = pt.matrix("A") A = pt.matrix("A")
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with care.
Please finish editing this comment first!
Register or sign in to comment