提交 61380247 作者: Ricardo Vieira 提交者: Ricardo Vieira

Numba Cholesky: Allow inplace on C_contiguous inputs

上级 ced9939d
...@@ -29,21 +29,27 @@ def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True): ...@@ -29,21 +29,27 @@ def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
numba_potrf = _LAPACK().numba_xpotrf(dtype) numba_potrf = _LAPACK().numba_xpotrf(dtype)
def impl(A, lower=0, overwrite_a=False, check_finite=True): def impl(A, lower=False, overwrite_a=False, check_finite=True):
_N = np.int32(A.shape[-1]) _N = np.int32(A.shape[-1])
if A.shape[-2] != _N: if A.shape[-2] != _N:
raise linalg.LinAlgError("Last 2 dimensions of A must be square") raise linalg.LinAlgError("Last 2 dimensions of A must be square")
UPLO = val_to_int_ptr(ord("L") if lower else ord("U")) transposed = False
N = val_to_int_ptr(_N)
LDA = val_to_int_ptr(_N)
INFO = val_to_int_ptr(0)
if overwrite_a and A.flags.f_contiguous: if overwrite_a and A.flags.f_contiguous:
A_copy = A A_copy = A
elif overwrite_a and A.flags.c_contiguous:
# We can work on the transpose of A directly
A_copy = A.T
transposed = True
lower = not lower
else: else:
A_copy = _copy_to_fortran_order(A) A_copy = _copy_to_fortran_order(A)
UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
N = val_to_int_ptr(_N)
LDA = val_to_int_ptr(_N)
INFO = val_to_int_ptr(0)
numba_potrf( numba_potrf(
UPLO, UPLO,
N, N,
...@@ -61,6 +67,10 @@ def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True): ...@@ -61,6 +67,10 @@ def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
for i in range(j + 1, _N): for i in range(j + 1, _N):
A_copy[i, j] = 0.0 A_copy[i, j] = 0.0
return A_copy, int_ptr_to_val(INFO) info_int = int_ptr_to_val(INFO)
if transposed:
return A_copy.T, info_int
return A_copy, info_int
return impl return impl
...@@ -551,8 +551,8 @@ class TestDecompositions: ...@@ -551,8 +551,8 @@ class TestDecompositions:
val_c_contig = np.copy(val, order="C") val_c_contig = np.copy(val, order="C")
res_c_contig = fn(val_c_contig) res_c_contig = fn(val_c_contig)
np.testing.assert_allclose(res_c_contig, res) np.testing.assert_allclose(res_c_contig, res)
# Cannot destroy C-contiguous input # Should always be destroyable
np.testing.assert_allclose(val_c_contig, val) assert (val == val_c_contig).all() == (not overwrite_a)
# Test non-contiguous input # Test non-contiguous input
val_not_contig = np.repeat(val, 2, axis=0)[::2] val_not_contig = np.repeat(val, 2, axis=0)[::2]
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论