提交 8ef585b6 authored 作者: Mateusz Sokół's avatar Mateusz Sokół 提交者: Thomas Wiecki

Add slogdet for JAX

上级 06e7afea
...@@ -3,7 +3,7 @@ import jax.numpy as jnp ...@@ -3,7 +3,7 @@ import jax.numpy as jnp
from pytensor.link.jax.dispatch import jax_funcify from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.blas import BatchedDot from pytensor.tensor.blas import BatchedDot
from pytensor.tensor.math import Dot, MaxAndArgmax from pytensor.tensor.math import Dot, MaxAndArgmax
from pytensor.tensor.nlinalg import SVD, Det, Eig, Eigh, MatrixInverse, QRFull from pytensor.tensor.nlinalg import SVD, Det, Eig, Eigh, MatrixInverse, QRFull, SLogDet
@jax_funcify.register(SVD) @jax_funcify.register(SVD)
...@@ -25,6 +25,14 @@ def jax_funcify_Det(op, **kwargs): ...@@ -25,6 +25,14 @@ def jax_funcify_Det(op, **kwargs):
return det return det
@jax_funcify.register(SLogDet)
def jax_funcify_SLogDet(op, **kwargs):
def slogdet(x):
return jnp.linalg.slogdet(x)
return slogdet
@jax_funcify.register(Eig) @jax_funcify.register(Eig)
def jax_funcify_Eig(op, **kwargs): def jax_funcify_Eig(op, **kwargs):
def eig(x): def eig(x):
......
...@@ -85,6 +85,10 @@ def test_jax_basic_multiout(): ...@@ -85,6 +85,10 @@ def test_jax_basic_multiout():
out_fg = FunctionGraph([x], outs) out_fg = FunctionGraph([x], outs)
compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn) compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)
outs = at_nlinalg.slogdet(x)
out_fg = FunctionGraph([x], outs)
compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)
@pytest.mark.xfail( @pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"), version_parse(jax.__version__) >= version_parse("0.2.12"),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论