提交 f146af68 authored 作者: Benjamin F. Maier's avatar Benjamin F. Maier 提交者: Jesse Grabowski

Update pytensor/link/jax/dispatch/slinalg.py

上级 501ae605
...@@ -57,7 +57,7 @@ def jax_funcify_Solve(op, **kwargs): ...@@ -57,7 +57,7 @@ def jax_funcify_Solve(op, **kwargs):
# jax requires dl and du to have the same shape as d # jax requires dl and du to have the same shape as d
dl = jax.numpy.pad(dl, (1, 0)) dl = jax.numpy.pad(dl, (1, 0))
du = jax.numpy.pad(du, (0, 1)) du = jax.numpy.pad(du, (0, 1))
# if b is a vector, broadcast it to be a matrix # jax also requires b to be a matrix; reshape it to be a column vector if necessary
b_is_vec = len(b.shape) == 1 b_is_vec = len(b.shape) == 1
if b_is_vec: if b_is_vec:
b = jax.numpy.expand_dims(b, -1) b = jax.numpy.expand_dims(b, -1)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论