提交 032ffa29 authored 作者: Ben F. Maier's avatar Ben F. Maier 提交者: Jesse Grabowski

fix shape issues in jax tridiagonal solve; close #1413

上级 6d236f14
...@@ -54,7 +54,21 @@ def jax_funcify_Solve(op, **kwargs): ...@@ -54,7 +54,21 @@ def jax_funcify_Solve(op, **kwargs):
dl = jax.numpy.diagonal(a, offset=-1, axis1=-2, axis2=-1) dl = jax.numpy.diagonal(a, offset=-1, axis1=-2, axis2=-1)
d = jax.numpy.diagonal(a, offset=0, axis1=-2, axis2=-1) d = jax.numpy.diagonal(a, offset=0, axis1=-2, axis2=-1)
du = jax.numpy.diagonal(a, offset=1, axis1=-2, axis2=-1) du = jax.numpy.diagonal(a, offset=1, axis1=-2, axis2=-1)
return jax.lax.linalg.tridiagonal_solve(dl, d, du, b, lower=lower) # jax requires dl and du to have the same shape as d
dl = jax.numpy.pad(dl, (1, 0))
du = jax.numpy.pad(du, (0, 1))
# if b is a vector, broadcast it to be a matrix
b_is_vec = len(b.shape) == 1
if b_is_vec:
b = jax.numpy.expand_dims(b, -1)
res = jax.lax.linalg.tridiagonal_solve(dl, d, du, b)
if b_is_vec:
# if b is a vector, return a vector
return res.flatten()
else:
return res
else: else:
if assume_a not in ("gen", "sym", "her", "pos"): if assume_a not in ("gen", "sym", "her", "pos"):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论