提交 8b8ba028 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Numba tridiagonal: avoid inference error when casting inputs

Numba fails to infer the correct types when the cast flags are read by indexing a static tuple, but infers them correctly when each flag is a separate boolean variable.
上级 eba75f6f
...@@ -356,8 +356,10 @@ def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs): ...@@ -356,8 +356,10 @@ def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs):
overwrite_du = op.overwrite_du overwrite_du = op.overwrite_du
out_dtype = node.outputs[1].type.numpy_dtype out_dtype = node.outputs[1].type.numpy_dtype
must_cast_inputs = tuple(inp.type.numpy_dtype != out_dtype for inp in node.inputs) cast_inputs = (cast_dl, cast_d, cast_du) = tuple(
if any(must_cast_inputs) and config.compiler_verbose: inp.type.numpy_dtype != out_dtype for inp in node.inputs
)
if any(cast_inputs) and config.compiler_verbose:
print("LUFactorTridiagonal requires casting at least one input") # noqa: T201 print("LUFactorTridiagonal requires casting at least one input") # noqa: T201
@numba_basic.numba_njit(cache=False) @numba_basic.numba_njit(cache=False)
...@@ -371,11 +373,11 @@ def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs): ...@@ -371,11 +373,11 @@ def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs):
np.zeros(d.shape, dtype="int32"), np.zeros(d.shape, dtype="int32"),
) )
if must_cast_inputs[0]: if cast_d:
d = d.astype(out_dtype) d = d.astype(out_dtype)
if must_cast_inputs[1]: if cast_dl:
dl = dl.astype(out_dtype) dl = dl.astype(out_dtype)
if must_cast_inputs[2]: if cast_du:
du = du.astype(out_dtype) du = du.astype(out_dtype)
dl, d, du, du2, ipiv, _ = _gttrf( dl, d, du, du2, ipiv, _ = _gttrf(
dl, dl,
...@@ -402,7 +404,7 @@ def numba_funcify_SolveLUFactorTridiagonal( ...@@ -402,7 +404,7 @@ def numba_funcify_SolveLUFactorTridiagonal(
overwrite_b = op.overwrite_b overwrite_b = op.overwrite_b
transposed = op.transposed transposed = op.transposed
must_cast_inputs = tuple( must_cast_inputs = (cast_dl, cast_d, cast_du, cast_du2, cast_ipiv, cast_b) = tuple(
inp.type.numpy_dtype != (np.int32 if i == 4 else out_dtype) inp.type.numpy_dtype != (np.int32 if i == 4 else out_dtype)
for i, inp in enumerate(node.inputs) for i, inp in enumerate(node.inputs)
) )
...@@ -417,17 +419,17 @@ def numba_funcify_SolveLUFactorTridiagonal( ...@@ -417,17 +419,17 @@ def numba_funcify_SolveLUFactorTridiagonal(
else: else:
return np.zeros((d.shape[0], b.shape[1]), dtype=out_dtype) return np.zeros((d.shape[0], b.shape[1]), dtype=out_dtype)
if must_cast_inputs[0]: if cast_dl:
dl = dl.astype(out_dtype) dl = dl.astype(out_dtype)
if must_cast_inputs[1]: if cast_d:
d = d.astype(out_dtype) d = d.astype(out_dtype)
if must_cast_inputs[2]: if cast_du:
du = du.astype(out_dtype) du = du.astype(out_dtype)
if must_cast_inputs[3]: if cast_du2:
du2 = du2.astype(out_dtype) du2 = du2.astype(out_dtype)
if must_cast_inputs[4]: if cast_ipiv:
ipiv = ipiv.astype("int32") ipiv = ipiv.astype(np.int32)
if must_cast_inputs[5]: if cast_b:
b = b.astype(out_dtype) b = b.astype(out_dtype)
x, _ = _gttrs( x, _ = _gttrs(
dl, dl,
......
...@@ -112,3 +112,34 @@ def test_tridiagonal_lu_solve(b_ndim, transposed, inplace): ...@@ -112,3 +112,34 @@ def test_tridiagonal_lu_solve(b_ndim, transposed, inplace):
assert (res_non_contig == res).all() assert (res_non_contig == res).all()
# b must be copied when not contiguous so it can't be inplaced # b must be copied when not contiguous so it can't be inplaced
assert (b_test == b_test_non_contig).all() assert (b_test == b_test_non_contig).all()
def test_cast_needed():
    """Run tridiagonal LU factor + solve on inputs whose dtypes all differ.

    The diagonals and ``b`` are declared as int16 / float32 / float64 /
    float32, while every factorization output other than the pivots, and the
    final solve output, is asserted to be float64. The graph is then handed
    to ``compare_numba_and_py`` with concrete arrays of those same dtypes.
    """
    lower_diag = pt.vector("dl", shape=(4,), dtype="int16")
    main_diag = pt.vector("d", shape=(5,), dtype="float32")
    upper_diag = pt.vector("du", shape=(4,), dtype="float64")
    rhs = pt.vector("b", shape=(5,), dtype="float32")

    lu_factor_outs = LUFactorTridiagonal()(lower_diag, main_diag, upper_diag)
    for position, factor_out in enumerate(lu_factor_outs):
        # ipiv (index 4) is int32; every other output is float64
        expected_dtype = "int32" if position == 4 else "float64"
        assert factor_out.type.dtype == expected_dtype

    solve_op = SolveLUFactorTridiagonal(b_ndim=1, transposed=False)
    lu_solve_out = solve_op(*lu_factor_outs, rhs)
    assert lu_solve_out.type.dtype == "float64"

    # Concrete values carrying the same dtypes as the symbolic inputs above.
    concrete_inputs = [
        np.array([1, 2, 3, 4], dtype="int16"),
        np.array([1, 2, 3, 4, 5], dtype="float32"),
        np.array([1, 2, 3, 4], dtype="float64"),
        np.array([1, 2, 3, 4, 5], dtype="float32"),
    ]
    compare_numba_and_py(
        [lower_diag, main_diag, upper_diag, rhs],
        lu_solve_out,
        test_inputs=concrete_inputs,
        eval_obj_mode=False,
    )
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论