提交 ad8dca48 authored 作者: jessegrabowski's avatar jessegrabowski 提交者: Jesse Grabowski

Implement QZ Op

上级 9834e96c
差异被折叠。
......@@ -37,6 +37,7 @@ from pytensor.tensor.slinalg import (
lu_solve,
pivot_to_permutation,
qr,
qz,
schur,
solve,
solve_triangular,
......@@ -1366,3 +1367,103 @@ class TestSchur:
assert Z_out.size == 0
assert T_out.dtype == config.floatX
assert Z_out.dtype == config.floatX
class TestQZ:
@pytest.mark.parametrize(
"shape, output",
[((5, 5), "real"), ((5, 5), "complex"), ((2, 4, 4), "real")],
ids=["not_batched_real", "not_batched_complex", "batched_real"],
)
@pytest.mark.parametrize("complex", [False, True], ids=["real", "complex"])
@pytest.mark.parametrize("sort", [None, "lhp", "rhp", "iuc", "ouc"])
def test_qz_decomposition(self, shape, output, complex, sort):
dtype = (
config.floatX if not complex else f"complex{int(config.floatX[-2:]) * 2}"
)
A = tensor("A", shape=shape, dtype=dtype)
B = tensor("B", shape=shape, dtype=dtype)
outputs = qz(
A, B, output=output, sort=sort, return_eigenvalues=sort is not None
)
f = function([A, B], outputs)
rng = np.random.default_rng(utt.fetch_seed())
A_val, B_val = rng.normal(size=(2, *shape))
A_val = A_val.astype(config.floatX)
B_val = B_val.astype(config.floatX)
if complex:
A_val = A_val + 1j * rng.normal(size=shape).astype(config.floatX)
B_val = B_val + 1j * rng.normal(size=shape).astype(config.floatX)
output_values = f(A_val, B_val)
if sort is None:
AA_val, BB_val, Q_val, Z_val = output_values
else:
AA_val, BB_val, alpha_val, beta_val, Q_val, Z_val = output_values
# Verify reconstruction
A_rebuilt = np.einsum("...ij,...jk,...lk->...il", Q_val, AA_val, Z_val.conj())
B_rebuilt = np.einsum("...ij,...jk,...lk->...il", Q_val, BB_val, Z_val.conj())
np.testing.assert_allclose(
A_val,
A_rebuilt,
atol=1e-6 if config.floatX == "float64" else 1e-3,
rtol=1e-6 if config.floatX == "float64" else 1e-3,
)
np.testing.assert_allclose(
B_val,
B_rebuilt,
atol=1e-6 if config.floatX == "float64" else 1e-3,
rtol=1e-6 if config.floatX == "float64" else 1e-3,
)
scipy_fn = (
scipy_linalg.qz
if sort is None
else functools.partial(scipy_linalg.ordqz, sort=sort)
)
scipy_signature = (
"(m,m),(m,m)->(m,m),(m,m),(m,m),(m,m)"
if sort is None
else ("(m,m),(m,m)->(m,m),(m,m),(m),(m),(m,m),(m,m)")
)
vec_qz = np.vectorize(
lambda a, b: scipy_fn(a, b, output=output),
signature=scipy_signature,
)
scipy_result = vec_qz(A_val, B_val)
if sort is None:
scipy_AA, scipy_BB, scipy_Q, scipy_Z = scipy_result
else:
scipy_AA, scipy_BB, scipy_alpha, scipy_beta, scipy_Q, scipy_Z = scipy_result
np.testing.assert_allclose(AA_val, scipy_AA, atol=1e-6, rtol=1e-6)
np.testing.assert_allclose(BB_val, scipy_BB, atol=1e-6, rtol=1e-6)
np.testing.assert_allclose(Q_val, scipy_Q, atol=1e-6, rtol=1e-6)
np.testing.assert_allclose(Z_val, scipy_Z, atol=1e-6, rtol=1e-6)
if sort is not None:
np.testing.assert_allclose(alpha_val, scipy_alpha, atol=1e-6, rtol=1e-6)
np.testing.assert_allclose(beta_val, scipy_beta, atol=1e-6, rtol=1e-6)
if len(shape) == 2 and (output == "complex") == complex:
A_f = np.asfortranarray(A_val.copy())
B_f = np.asfortranarray(B_val.copy())
f_mut = function(
[In(A, mutable=True), In(B, mutable=True)],
outputs,
mode=get_default_mode().including("inplace"),
)
f_mut(A_f, B_f)
np.testing.assert_allclose(A_f, scipy_AA, atol=1e-6, rtol=1e-6)
np.testing.assert_allclose(B_f, scipy_BB, atol=1e-6, rtol=1e-6)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论