Unverified 提交 2ea5a544 authored 作者: jessegrabowski's avatar jessegrabowski 提交者: GitHub

Add jax implementation of `pt.linalg.pinv` (#294)

上级 a99a7b22
......@@ -3,7 +3,16 @@ import jax.numpy as jnp
from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.blas import BatchedDot
from pytensor.tensor.math import Dot, MaxAndArgmax
from pytensor.tensor.nlinalg import SVD, Det, Eig, Eigh, MatrixInverse, QRFull, SLogDet
from pytensor.tensor.nlinalg import (
SVD,
Det,
Eig,
Eigh,
MatrixInverse,
MatrixPinv,
QRFull,
SLogDet,
)
@jax_funcify.register(SVD)
......@@ -77,6 +86,14 @@ def jax_funcify_Dot(op, **kwargs):
return dot
@jax_funcify.register(MatrixPinv)
def jax_funcify_Pinv(op, **kwargs):
def pinv(x):
return jnp.linalg.pinv(x)
return pinv
@jax_funcify.register(BatchedDot)
def jax_funcify_BatchedDot(op, **kwargs):
def batched_dot(a, b):
......
......@@ -134,3 +134,12 @@ def test_tensor_basics():
out = at_max(y)
fgraph = FunctionGraph([y], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
def test_pinv():
x = matrix("x")
x_inv = at_nlinalg.pinv(x)
fgraph = FunctionGraph([x], [x_inv])
x_np = np.array([[1.0, 2.0], [3.0, 4.0]])
compare_jax_and_py(fgraph, [x_np])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论