提交 8b8ba028 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Numba tridiagonal: avoid inference error when casting inputs

Numba fails to infer the right type when the cast flags are read by indexing a static tuple inside the jitted function, but infers it correctly when each flag is a separate boolean variable.
上级 eba75f6f
......@@ -356,8 +356,10 @@ def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs):
overwrite_du = op.overwrite_du
out_dtype = node.outputs[1].type.numpy_dtype
must_cast_inputs = tuple(inp.type.numpy_dtype != out_dtype for inp in node.inputs)
if any(must_cast_inputs) and config.compiler_verbose:
cast_inputs = (cast_dl, cast_d, cast_du) = tuple(
inp.type.numpy_dtype != out_dtype for inp in node.inputs
)
if any(cast_inputs) and config.compiler_verbose:
print("LUFactorTridiagonal requires casting at least one input") # noqa: T201
@numba_basic.numba_njit(cache=False)
......@@ -371,11 +373,11 @@ def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs):
np.zeros(d.shape, dtype="int32"),
)
if must_cast_inputs[0]:
if cast_d:
d = d.astype(out_dtype)
if must_cast_inputs[1]:
if cast_dl:
dl = dl.astype(out_dtype)
if must_cast_inputs[2]:
if cast_du:
du = du.astype(out_dtype)
dl, d, du, du2, ipiv, _ = _gttrf(
dl,
......@@ -402,7 +404,7 @@ def numba_funcify_SolveLUFactorTridiagonal(
overwrite_b = op.overwrite_b
transposed = op.transposed
must_cast_inputs = tuple(
must_cast_inputs = (cast_dl, cast_d, cast_du, cast_du2, cast_ipiv, cast_b) = tuple(
inp.type.numpy_dtype != (np.int32 if i == 4 else out_dtype)
for i, inp in enumerate(node.inputs)
)
......@@ -417,17 +419,17 @@ def numba_funcify_SolveLUFactorTridiagonal(
else:
return np.zeros((d.shape[0], b.shape[1]), dtype=out_dtype)
if must_cast_inputs[0]:
if cast_dl:
dl = dl.astype(out_dtype)
if must_cast_inputs[1]:
if cast_d:
d = d.astype(out_dtype)
if must_cast_inputs[2]:
if cast_du:
du = du.astype(out_dtype)
if must_cast_inputs[3]:
if cast_du2:
du2 = du2.astype(out_dtype)
if must_cast_inputs[4]:
ipiv = ipiv.astype("int32")
if must_cast_inputs[5]:
if cast_ipiv:
ipiv = ipiv.astype(np.int32)
if cast_b:
b = b.astype(out_dtype)
x, _ = _gttrs(
dl,
......
......@@ -112,3 +112,34 @@ def test_tridiagonal_lu_solve(b_ndim, transposed, inplace):
assert (res_non_contig == res).all()
# b must be copied when not contiguous so it can't be inplaced
assert (b_test == b_test_non_contig).all()
def test_cast_needed():
    """Exercise the tridiagonal LU factor/solve ops with mixed input dtypes.

    The symbolic inputs deliberately use int16, float32 and float64 so that
    not every input already matches the float64 output dtype asserted below.
    """
    dl = pt.vector("dl", shape=(4,), dtype="int16")
    d = pt.vector("d", shape=(5,), dtype="float32")
    du = pt.vector("du", shape=(4,), dtype="float64")
    b = pt.vector("b", shape=(5,), dtype="float32")

    factor_outputs = LUFactorTridiagonal()(dl, d, du)
    for position, factor_out in enumerate(factor_outputs):
        # Output 4 is ipiv, which is int32; every other output is float64.
        expected_dtype = "int32" if position == 4 else "float64"
        assert factor_out.type.dtype == expected_dtype

    solve_op = SolveLUFactorTridiagonal(b_ndim=1, transposed=False)
    solution = solve_op(*factor_outputs, b)
    assert solution.type.dtype == "float64"

    # Concrete values mirror the symbolic dtypes: 1..4 for the off-diagonals,
    # 1..5 for the main diagonal and the right-hand side.
    numeric_inputs = [
        np.arange(1, 5, dtype="int16"),
        np.arange(1, 6, dtype="float32"),
        np.arange(1, 5, dtype="float64"),
        np.arange(1, 6, dtype="float32"),
    ]
    compare_numba_and_py(
        [dl, d, du, b],
        solution,
        test_inputs=numeric_inputs,
        eval_obj_mode=False,
    )
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论