提交 afabe83f authored 作者: jessegrabowski's avatar jessegrabowski 提交者: Jesse Grabowski

Add JAX dispatch for Schur

上级 62ba6c9f
...@@ -13,6 +13,7 @@ from pytensor.tensor.slinalg import ( ...@@ -13,6 +13,7 @@ from pytensor.tensor.slinalg import (
Expm, Expm,
LUFactor, LUFactor,
PivotToPermutations, PivotToPermutations,
Schur,
Solve, Solve,
SolveTriangular, SolveTriangular,
) )
...@@ -183,3 +184,19 @@ def jax_funcify_Expm(op, **kwargs): ...@@ -183,3 +184,19 @@ def jax_funcify_Expm(op, **kwargs):
return jax.scipy.linalg.expm(x) return jax.scipy.linalg.expm(x)
return expm return expm
@jax_funcify.register(Schur)
def jax_funcify_Schur(op, **kwargs):
output = op.output
if op.sort is not None:
warnings.warn(
"jax.scipy.linalg.schur only supports sort=None. The sort argument is ignored."
)
def schur(a):
T, Z = jax.scipy.linalg.schur(a, output=output)
return T, Z
return schur
...@@ -382,3 +382,13 @@ def test_jax_qr(mode): ...@@ -382,3 +382,13 @@ def test_jax_qr(mode):
out = pt_slinalg.qr(A, mode=mode) out = pt_slinalg.qr(A, mode=mode)
compare_jax_and_py([A], out, [A_val]) compare_jax_and_py([A], out, [A_val])
@pytest.mark.parametrize("output", ["real", "complex"])
def test_jax_schur(output):
rng = np.random.default_rng(utt.fetch_seed())
A = pt.tensor(name="A", shape=(5, 5))
A_val = rng.normal(size=(5, 5)).astype(config.floatX)
T, Z = pt_slinalg.schur(A, output=output)
compare_jax_and_py([A], [T, Z], [A_val])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论