提交 e60cfb57 authored 作者: Rob Zinkov's avatar Rob Zinkov 提交者: Ricardo Vieira

Add JAX dispatch for SolveTriangular OP, resolves #846

上级 ef256e52
......@@ -51,7 +51,7 @@ from aesara.tensor.nlinalg import SVD, Det, Eig, Eigh, MatrixInverse, QRFull
from aesara.tensor.nnet.basic import LogSoftmax, Softmax, SoftmaxGrad
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
from aesara.tensor.slinalg import Cholesky, Solve
from aesara.tensor.slinalg import Cholesky, Solve, SolveTriangular
from aesara.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
......@@ -848,6 +848,26 @@ def jax_funcify_Solve(op, **kwargs):
return solve
@jax_funcify.register(SolveTriangular)
def jax_funcify_SolveTriangular(op, **kwargs):
lower = op.lower
trans = op.trans
unit_diagonal = op.unit_diagonal
check_finite = op.check_finite
def solve_triangular(A, b):
return jsp.linalg.solve_triangular(
A,
b,
lower=lower,
trans=trans,
unit_diagonal=unit_diagonal,
check_finite=check_finite,
)
return solve_triangular
@jax_funcify.register(Det)
def jax_funcify_Det(op, **kwargs):
def det(x):
......
......@@ -315,6 +315,30 @@ def test_jax_basic():
)
@pytest.mark.parametrize("check_finite", [False, True])
@pytest.mark.parametrize("lower", [False, True])
@pytest.mark.parametrize("trans", [0, 1, 2])
def test_jax_SolveTriangular(trans, lower, check_finite):
x = matrix("x")
b = vector("b")
out = at_slinalg.solve_triangular(
x,
b,
trans=trans,
lower=lower,
check_finite=check_finite,
)
out_fg = FunctionGraph([x, b], [out])
compare_jax_and_py(
out_fg,
[
np.eye(10).astype(config.floatX),
np.arange(10).astype(config.floatX),
],
)
@pytest.mark.parametrize(
"x, y, x_val, y_val",
[
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论