提交 f6986e40 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Numba linalg: Handle empty inputs

上级 ffd999c8
...@@ -362,6 +362,15 @@ def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs): ...@@ -362,6 +362,15 @@ def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs):
@numba_basic.numba_njit(cache=False) @numba_basic.numba_njit(cache=False)
def lu_factor_tridiagonal(dl, d, du): def lu_factor_tridiagonal(dl, d, du):
if d.size == 0:
return (
np.zeros(dl.shape, dtype=out_dtype),
np.zeros(d.shape, dtype=out_dtype),
np.zeros(du.shape, dtype=out_dtype),
np.zeros(d.shape, dtype=out_dtype),
np.zeros(d.shape, dtype="int32"),
)
if must_cast_inputs[0]: if must_cast_inputs[0]:
d = d.astype(out_dtype) d = d.astype(out_dtype)
if must_cast_inputs[1]: if must_cast_inputs[1]:
...@@ -389,6 +398,7 @@ def numba_funcify_SolveLUFactorTridiagonal( ...@@ -389,6 +398,7 @@ def numba_funcify_SolveLUFactorTridiagonal(
return generate_fallback_impl(op, node=node) return generate_fallback_impl(op, node=node)
out_dtype = node.outputs[0].type.numpy_dtype out_dtype = node.outputs[0].type.numpy_dtype
b_ndim = op.b_ndim
overwrite_b = op.overwrite_b overwrite_b = op.overwrite_b
transposed = op.transposed transposed = op.transposed
...@@ -401,6 +411,12 @@ def numba_funcify_SolveLUFactorTridiagonal( ...@@ -401,6 +411,12 @@ def numba_funcify_SolveLUFactorTridiagonal(
@numba_basic.numba_njit(cache=False) @numba_basic.numba_njit(cache=False)
def solve_lu_factor_tridiagonal(dl, d, du, du2, ipiv, b): def solve_lu_factor_tridiagonal(dl, d, du, du2, ipiv, b):
if d.size == 0:
if b_ndim == 1:
return np.zeros(d.shape, dtype=out_dtype)
else:
return np.zeros((d.shape[0], b.shape[1]), dtype=out_dtype)
if must_cast_inputs[0]: if must_cast_inputs[0]:
dl = dl.astype(out_dtype) dl = dl.astype(out_dtype)
if must_cast_inputs[1]: if must_cast_inputs[1]:
......
...@@ -74,6 +74,9 @@ def numba_funcify_Cholesky(op, node, **kwargs): ...@@ -74,6 +74,9 @@ def numba_funcify_Cholesky(op, node, **kwargs):
@numba_basic.numba_njit @numba_basic.numba_njit
def cholesky(a): def cholesky(a):
if a.size == 0:
return np.zeros(a.shape, dtype=out_dtype)
if discrete_inp: if discrete_inp:
a = a.astype(out_dtype) a = a.astype(out_dtype)
elif check_finite: elif check_finite:
...@@ -114,7 +117,8 @@ def pivot_to_permutation(op, node, **kwargs): ...@@ -114,7 +117,8 @@ def pivot_to_permutation(op, node, **kwargs):
return np.argsort(p_inv) return np.argsort(p_inv)
return numba_pivot_to_permutation cache_key = 1
return numba_pivot_to_permutation, cache_key
@numba_funcify.register(LU) @numba_funcify.register(LU)
...@@ -134,6 +138,18 @@ def numba_funcify_LU(op, node, **kwargs): ...@@ -134,6 +138,18 @@ def numba_funcify_LU(op, node, **kwargs):
@numba_basic.numba_njit @numba_basic.numba_njit
def lu(a): def lu(a):
if a.size == 0:
L = np.zeros(a.shape, dtype=a.dtype)
U = np.zeros(a.shape, dtype=a.dtype)
if permute_l:
return L, U
elif p_indices:
P = np.zeros(a.shape[0], dtype="int32")
return P, L, U
else:
P = np.zeros(a.shape, dtype=a.dtype)
return P, L, U
if discrete_inp: if discrete_inp:
a = a.astype(out_dtype) a = a.astype(out_dtype)
elif check_finite: elif check_finite:
...@@ -187,6 +203,12 @@ def numba_funcify_LUFactor(op, node, **kwargs): ...@@ -187,6 +203,12 @@ def numba_funcify_LUFactor(op, node, **kwargs):
@numba_basic.numba_njit @numba_basic.numba_njit
def lu_factor(a): def lu_factor(a):
if a.size == 0:
return (
np.zeros(a.shape, dtype=out_dtype),
np.zeros(a.shape[0], dtype="int32"),
)
if discrete_inp: if discrete_inp:
a = a.astype(out_dtype) a = a.astype(out_dtype)
elif check_finite: elif check_finite:
...@@ -226,7 +248,7 @@ def numba_funcify_BlockDiagonal(op, node, **kwargs): ...@@ -226,7 +248,7 @@ def numba_funcify_BlockDiagonal(op, node, **kwargs):
@numba_funcify.register(Solve) @numba_funcify.register(Solve)
def numba_funcify_Solve(op, node, **kwargs): def numba_funcify_Solve(op, node, **kwargs):
A_dtype, b_dtype = (i.numpy_dtype for i in node.inputs) A_dtype, b_dtype = (i.type.numpy_dtype for i in node.inputs)
out_dtype = node.outputs[0].type.numpy_dtype out_dtype = node.outputs[0].type.numpy_dtype
if A_dtype.kind == "c" or b_dtype.kind == "c": if A_dtype.kind == "c" or b_dtype.kind == "c":
...@@ -269,6 +291,9 @@ def numba_funcify_Solve(op, node, **kwargs): ...@@ -269,6 +291,9 @@ def numba_funcify_Solve(op, node, **kwargs):
@numba_basic.numba_njit @numba_basic.numba_njit
def solve(a, b): def solve(a, b):
if b.size == 0:
return np.zeros(b.shape, dtype=out_dtype)
if must_cast_A: if must_cast_A:
a = a.astype(out_dtype) a = a.astype(out_dtype)
if must_cast_B: if must_cast_B:
...@@ -297,7 +322,7 @@ def numba_funcify_SolveTriangular(op, node, **kwargs): ...@@ -297,7 +322,7 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
overwrite_b = op.overwrite_b overwrite_b = op.overwrite_b
b_ndim = op.b_ndim b_ndim = op.b_ndim
A_dtype, b_dtype = (i.numpy_dtype for i in node.inputs) A_dtype, b_dtype = (i.type.numpy_dtype for i in node.inputs)
out_dtype = node.outputs[0].type.numpy_dtype out_dtype = node.outputs[0].type.numpy_dtype
if A_dtype.kind == "c" or b_dtype.kind == "c": if A_dtype.kind == "c" or b_dtype.kind == "c":
...@@ -311,6 +336,8 @@ def numba_funcify_SolveTriangular(op, node, **kwargs): ...@@ -311,6 +336,8 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
@numba_basic.numba_njit @numba_basic.numba_njit
def solve_triangular(a, b): def solve_triangular(a, b):
if b.size == 0:
return np.zeros(b.shape, dtype=out_dtype)
if must_cast_A: if must_cast_A:
a = a.astype(out_dtype) a = a.astype(out_dtype)
if must_cast_B: if must_cast_B:
...@@ -360,6 +387,8 @@ def numba_funcify_CholeskySolve(op, node, **kwargs): ...@@ -360,6 +387,8 @@ def numba_funcify_CholeskySolve(op, node, **kwargs):
@numba_basic.numba_njit @numba_basic.numba_njit
def cho_solve(c, b): def cho_solve(c, b):
if b.size == 0:
return np.zeros(b.shape, dtype=out_dtype)
if must_cast_c: if must_cast_c:
c = c.astype(out_dtype) c = c.astype(out_dtype)
if check_finite: if check_finite:
......
...@@ -16,6 +16,13 @@ from pytensor.tensor.slinalg import ( ...@@ -16,6 +16,13 @@ from pytensor.tensor.slinalg import (
LUFactor, LUFactor,
Solve, Solve,
SolveTriangular, SolveTriangular,
cho_solve,
cholesky,
lu,
lu_factor,
lu_solve,
solve,
solve_triangular,
) )
from tests.link.numba.test_basic import compare_numba_and_py, numba_inplace_mode from tests.link.numba.test_basic import compare_numba_and_py, numba_inplace_mode
...@@ -483,6 +490,27 @@ class TestSolves: ...@@ -483,6 +490,27 @@ class TestSolves:
# Can never destroy non-contiguous inputs # Can never destroy non-contiguous inputs
np.testing.assert_allclose(b_val_not_contig, b_val) np.testing.assert_allclose(b_val_not_contig, b_val)
@pytest.mark.parametrize(
"solve_op",
[solve, solve_triangular, cho_solve, lu_solve],
ids=lambda x: x.__name__,
)
def test_empty(self, solve_op):
a = pt.matrix("x")
b = pt.vector("b")
if solve_op is cho_solve:
out = solve_op((a, True), b)
elif solve_op is lu_solve:
out = solve_op((a, b.astype("int32")), b)
else:
out = solve_op(a, b)
compare_numba_and_py(
[a, b],
[out],
[np.zeros((0, 0)), np.zeros(0)],
eval_obj_mode=False, # pivot_to_permutation seems to still be jitted despite the monkey patching
)
class TestDecompositions: class TestDecompositions:
@pytest.mark.parametrize("lower", [True, False], ids=lambda x: f"lower={x}") @pytest.mark.parametrize("lower", [True, False], ids=lambda x: f"lower={x}")
...@@ -750,6 +778,20 @@ class TestDecompositions: ...@@ -750,6 +778,20 @@ class TestDecompositions:
# Cannot destroy non-contiguous input # Cannot destroy non-contiguous input
np.testing.assert_allclose(val_not_contig, A_val) np.testing.assert_allclose(val_not_contig, A_val)
@pytest.mark.parametrize(
"decomp_op", (cholesky, lu, lu_factor), ids=lambda x: x.__name__
)
def test_empty(self, decomp_op):
x = pt.matrix("x")
outs = decomp_op(x)
if not isinstance(outs, tuple | list):
outs = [outs]
compare_numba_and_py(
[x],
outs,
[np.zeros((0, 0))],
)
def test_block_diag(): def test_block_diag():
A = pt.matrix("A") A = pt.matrix("A")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论