提交 2ba89371 authored 作者: Ben F. Maier's avatar Ben F. Maier 提交者: Jesse Grabowski

incorporate changes as asked for

上级 f146af68
......@@ -46,6 +46,7 @@ def jax_funcify_Cholesky(op, **kwargs):
def jax_funcify_Solve(op, **kwargs):
assume_a = op.assume_a
lower = op.lower
b_is_vec = op.b_ndim == 1
if assume_a == "tridiagonal":
# jax.scipy.solve does not yet support tridiagonal matrices
......@@ -54,20 +55,19 @@ def jax_funcify_Solve(op, **kwargs):
dl = jax.numpy.diagonal(a, offset=-1, axis1=-2, axis2=-1)
d = jax.numpy.diagonal(a, offset=0, axis1=-2, axis2=-1)
du = jax.numpy.diagonal(a, offset=1, axis1=-2, axis2=-1)
# jax requires dl and du to have the same shape as d
dl = jax.numpy.pad(dl, (1, 0))
du = jax.numpy.pad(du, (0, 1))
# jax also requires b to be a matrix; reshape it to be a column vector if necessary
b_is_vec = len(b.shape) == 1
if b_is_vec:
b = jax.numpy.expand_dims(b, -1)
res = jax.lax.linalg.tridiagonal_solve(dl, d, du, b)
if b_is_vec:
# if b is a vector, return a vector
return res.flatten()
else:
return jax.numpy.squeeze(res, -1)
return res
else:
......
......@@ -122,19 +122,59 @@ def test_jax_solve():
)
def test_jax_tridiagonal_solve():
N = 10
A = pt.matrix("A", shape=(N, N))
b = pt.vector("b", shape=(N,))
@pytest.mark.parametrize(
"A_size, b_size, b_ndim",
[
(
(
5,
5,
),
(5,),
1,
),
(
(
5,
5,
),
(5, 1),
2,
),
(
(
5,
5,
),
(1, 5),
1,
),
(
(
4,
5,
5,
),
(4, 5, 5),
2,
),
],
ids=["basic_vector", "basic_matrix", "vector_broadcasted", "fully_batched"],
)
def test_jax_tridiagonal_solve(A_size: tuple, b_size: tuple, b_ndim: int):
A = pt.tensor("A", shape=A_size)
b = pt.tensor("b", shape=b_size)
out = pt.linalg.solve(A, b, assume_a="tridiagonal")
out = pt.linalg.solve(A, b, assume_a="tridiagonal", b_ndim=b_ndim)
A_val = np.eye(N)
A_val = np.zeros(A_size)
N = A_size[-1]
A_val[...] = np.eye(N)
for i in range(N - 1):
A_val[i, i + 1] = np.random.randn()
A_val[i + 1, i] = np.random.randn()
A_val[..., i, i + 1] = np.random.randn()
A_val[..., i + 1, i] = np.random.randn()
b_val = np.random.randn(N)
b_val = np.random.randn(*b_size)
compare_jax_and_py(
[A, b],
......@@ -142,17 +182,6 @@ def test_jax_tridiagonal_solve():
[A_val, b_val],
)
b_ = pt.matrix("b", shape=(N, 2))
out = pt.linalg.solve(A, b_, assume_a="tridiagonal")
b_val = np.random.randn(N, 2)
compare_jax_and_py(
[A, b_],
[out],
[A_val, b_val],
)
def test_jax_SolveTriangular():
rng = np.random.default_rng(utt.fetch_seed())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论