Unverified 提交 f3d2ede9 authored 作者: Pham Nguyen Hung's avatar Pham Nguyen Hung 提交者: GitHub

Implemented JAX backend for Eigvalsh (#867)

上级 920b409b
import jax import jax
from pytensor.link.jax.dispatch.basic import jax_funcify from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.tensor.slinalg import BlockDiagonal, Cholesky, Solve, SolveTriangular from pytensor.tensor.slinalg import (
BlockDiagonal,
Cholesky,
Eigvalsh,
Solve,
SolveTriangular,
)
@jax_funcify.register(Eigvalsh)
def jax_funcify_Eigvalsh(op, **kwargs):
if op.lower:
UPLO = "L"
else:
UPLO = "U"
def eigvalsh(a, b):
if b is not None:
raise NotImplementedError(
"jax.numpy.linalg.eigvalsh does not support generalized eigenvector problems (b != None)"
)
return jax.numpy.linalg.eigvalsh(a, UPLO=UPLO)
return eigvalsh
@jax_funcify.register(Cholesky) @jax_funcify.register(Cholesky)
......
...@@ -163,3 +163,34 @@ def test_jax_block_diag_blockwise(): ...@@ -163,3 +163,34 @@ def test_jax_block_diag_blockwise():
np.random.normal(size=(5, 3, 3)).astype(config.floatX), np.random.normal(size=(5, 3, 3)).astype(config.floatX),
], ],
) )
@pytest.mark.parametrize("lower", [False, True])
def test_jax_eigvalsh(lower):
A = matrix("A")
B = matrix("B")
out = pt_slinalg.eigvalsh(A, B, lower=lower)
out_fg = FunctionGraph([A, B], [out])
with pytest.raises(NotImplementedError):
compare_jax_and_py(
out_fg,
[
np.array(
[[6, 3, 1, 5], [3, 0, 5, 1], [1, 5, 6, 2], [5, 1, 2, 2]]
).astype(config.floatX),
np.array(
[[10, 0, 1, 3], [0, 12, 7, 8], [1, 7, 14, 2], [3, 8, 2, 16]]
).astype(config.floatX),
],
)
compare_jax_and_py(
out_fg,
[
np.array([[6, 3, 1, 5], [3, 0, 5, 1], [1, 5, 6, 2], [5, 1, 2, 2]]).astype(
config.floatX
),
None,
],
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论