提交 501ae605 authored 作者: Ben F. Maier's avatar Ben F. Maier 提交者: Jesse Grabowski

added tests for tridiagonal solve

上级 032ffa29
...@@ -122,6 +122,38 @@ def test_jax_solve(): ...@@ -122,6 +122,38 @@ def test_jax_solve():
) )
def test_jax_tridiagonal_solve():
N = 10
A = pt.matrix("A", shape=(N, N))
b = pt.vector("b", shape=(N,))
out = pt.linalg.solve(A, b, assume_a="tridiagonal")
A_val = np.eye(N)
for i in range(N - 1):
A_val[i, i + 1] = np.random.randn()
A_val[i + 1, i] = np.random.randn()
b_val = np.random.randn(N)
compare_jax_and_py(
[A, b],
[out],
[A_val, b_val],
)
b_ = pt.matrix("b", shape=(N, 2))
out = pt.linalg.solve(A, b_, assume_a="tridiagonal")
b_val = np.random.randn(N, 2)
compare_jax_and_py(
[A, b_],
[out],
[A_val, b_val],
)
def test_jax_SolveTriangular(): def test_jax_SolveTriangular():
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论