提交 8ef585b6 authored 作者: Mateusz Sokół's avatar Mateusz Sokół 提交者: Thomas Wiecki

Add slogdet for JAX

上级 06e7afea
......@@ -3,7 +3,7 @@ import jax.numpy as jnp
from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.blas import BatchedDot
from pytensor.tensor.math import Dot, MaxAndArgmax
from pytensor.tensor.nlinalg import SVD, Det, Eig, Eigh, MatrixInverse, QRFull
from pytensor.tensor.nlinalg import SVD, Det, Eig, Eigh, MatrixInverse, QRFull, SLogDet
@jax_funcify.register(SVD)
......@@ -25,6 +25,14 @@ def jax_funcify_Det(op, **kwargs):
return det
@jax_funcify.register(SLogDet)
def jax_funcify_SLogDet(op, **kwargs):
def slogdet(x):
return jnp.linalg.slogdet(x)
return slogdet
@jax_funcify.register(Eig)
def jax_funcify_Eig(op, **kwargs):
def eig(x):
......
......@@ -85,6 +85,10 @@ def test_jax_basic_multiout():
out_fg = FunctionGraph([x], outs)
compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)
outs = at_nlinalg.slogdet(x)
out_fg = FunctionGraph([x], outs)
compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)
@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论