提交 eba75f6f authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Tridiagonal Solve: Fix dtype inference

Scipy helper doesn't have special handling for ipiv int32 variable, and always assumes it must be cast to a float64
上级 f4fb4833
...@@ -146,9 +146,8 @@ class SolveLUFactorTridiagonal(Op): ...@@ -146,9 +146,8 @@ class SolveLUFactorTridiagonal(Op):
n = nb n = nb
dummy_arrays = [ dummy_arrays = [
np.zeros((), dtype=inp.type.dtype) for inp in (dl, d, du, du2, ipiv) np.zeros((), dtype=inp.type.dtype) for inp in (dl, d, du, du2, b)
] ]
# Seems to always be float64?
out_dtype = get_lapack_funcs("gttrs", dummy_arrays).dtype out_dtype = get_lapack_funcs("gttrs", dummy_arrays).dtype
if self.b_ndim == 1: if self.b_ndim == 1:
output_shape = (n,) output_shape = (n,)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论