Unverified commit 2ea5a544, authored by jessegrabowski, committed by GitHub

Add jax implementation of `pt.linalg.pinv` (#294)

Parent a99a7b22
...@@ -3,7 +3,16 @@ import jax.numpy as jnp ...@@ -3,7 +3,16 @@ import jax.numpy as jnp
from pytensor.link.jax.dispatch import jax_funcify from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.blas import BatchedDot from pytensor.tensor.blas import BatchedDot
from pytensor.tensor.math import Dot, MaxAndArgmax from pytensor.tensor.math import Dot, MaxAndArgmax
from pytensor.tensor.nlinalg import SVD, Det, Eig, Eigh, MatrixInverse, QRFull, SLogDet from pytensor.tensor.nlinalg import (
SVD,
Det,
Eig,
Eigh,
MatrixInverse,
MatrixPinv,
QRFull,
SLogDet,
)
@jax_funcify.register(SVD) @jax_funcify.register(SVD)
...@@ -77,6 +86,14 @@ def jax_funcify_Dot(op, **kwargs): ...@@ -77,6 +86,14 @@ def jax_funcify_Dot(op, **kwargs):
return dot return dot
@jax_funcify.register(MatrixPinv)
def jax_funcify_Pinv(op, **kwargs):
    """Return a JAX implementation of ``MatrixPinv`` (Moore-Penrose pseudo-inverse).

    Parameters
    ----------
    op : MatrixPinv
        The pytensor Op being compiled to JAX.
    **kwargs
        Unused; accepted for dispatch-signature compatibility.

    Returns
    -------
    callable
        A function mapping an array ``x`` to ``jnp.linalg.pinv(x)``.
    """

    def pinv(x):
        # ``jnp.linalg.pinv`` accepts a ``hermitian`` flag that selects a faster
        # eigendecomposition-based path. Forward the Op's setting when present;
        # ``getattr`` keeps older MatrixPinv versions (no such attribute) working.
        return jnp.linalg.pinv(x, hermitian=getattr(op, "hermitian", False))

    return pinv
@jax_funcify.register(BatchedDot) @jax_funcify.register(BatchedDot)
def jax_funcify_BatchedDot(op, **kwargs): def jax_funcify_BatchedDot(op, **kwargs):
def batched_dot(a, b): def batched_dot(a, b):
......
...@@ -134,3 +134,12 @@ def test_tensor_basics(): ...@@ -134,3 +134,12 @@ def test_tensor_basics():
out = at_max(y) out = at_max(y)
fgraph = FunctionGraph([y], [out]) fgraph = FunctionGraph([y], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
def test_pinv():
    """The JAX backend's ``pinv`` output matches the Python-mode reference."""
    inp = matrix("x")
    out = at_nlinalg.pinv(inp)
    graph = FunctionGraph([inp], [out])
    test_value = np.array([[1.0, 2.0], [3.0, 4.0]])
    compare_jax_and_py(graph, [test_value])
Markdown formatting supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment