提交 e60cfb57 authored 作者: Rob Zinkov's avatar Rob Zinkov 提交者: Ricardo Vieira

Add JAX dispatch for SolveTriangular OP, resolves #846

上级 ef256e52
...@@ -51,7 +51,7 @@ from aesara.tensor.nlinalg import SVD, Det, Eig, Eigh, MatrixInverse, QRFull ...@@ -51,7 +51,7 @@ from aesara.tensor.nlinalg import SVD, Det, Eig, Eigh, MatrixInverse, QRFull
from aesara.tensor.nnet.basic import LogSoftmax, Softmax, SoftmaxGrad from aesara.tensor.nnet.basic import LogSoftmax, Softmax, SoftmaxGrad
from aesara.tensor.random.op import RandomVariable from aesara.tensor.random.op import RandomVariable
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
from aesara.tensor.slinalg import Cholesky, Solve from aesara.tensor.slinalg import Cholesky, Solve, SolveTriangular
from aesara.tensor.subtensor import ( from aesara.tensor.subtensor import (
AdvancedIncSubtensor, AdvancedIncSubtensor,
AdvancedIncSubtensor1, AdvancedIncSubtensor1,
...@@ -848,6 +848,26 @@ def jax_funcify_Solve(op, **kwargs): ...@@ -848,6 +848,26 @@ def jax_funcify_Solve(op, **kwargs):
return solve return solve
@jax_funcify.register(SolveTriangular)
def jax_funcify_SolveTriangular(op, **kwargs):
lower = op.lower
trans = op.trans
unit_diagonal = op.unit_diagonal
check_finite = op.check_finite
def solve_triangular(A, b):
return jsp.linalg.solve_triangular(
A,
b,
lower=lower,
trans=trans,
unit_diagonal=unit_diagonal,
check_finite=check_finite,
)
return solve_triangular
@jax_funcify.register(Det) @jax_funcify.register(Det)
def jax_funcify_Det(op, **kwargs): def jax_funcify_Det(op, **kwargs):
def det(x): def det(x):
......
...@@ -315,6 +315,30 @@ def test_jax_basic(): ...@@ -315,6 +315,30 @@ def test_jax_basic():
) )
@pytest.mark.parametrize("check_finite", [False, True])
@pytest.mark.parametrize("lower", [False, True])
@pytest.mark.parametrize("trans", [0, 1, 2])
def test_jax_SolveTriangular(trans, lower, check_finite):
x = matrix("x")
b = vector("b")
out = at_slinalg.solve_triangular(
x,
b,
trans=trans,
lower=lower,
check_finite=check_finite,
)
out_fg = FunctionGraph([x, b], [out])
compare_jax_and_py(
out_fg,
[
np.eye(10).astype(config.floatX),
np.arange(10).astype(config.floatX),
],
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"x, y, x_val, y_val", "x, y, x_val, y_val",
[ [
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论