提交 61380247 · 作者: Ricardo Vieira · 提交者: Ricardo Vieira

Numba Cholesky: Allow inplace on C_contiguous inputs

上级 ced9939d
......@@ -29,21 +29,27 @@ def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
numba_potrf = _LAPACK().numba_xpotrf(dtype)
def impl(A, lower=0, overwrite_a=False, check_finite=True):
def impl(A, lower=False, overwrite_a=False, check_finite=True):
_N = np.int32(A.shape[-1])
if A.shape[-2] != _N:
raise linalg.LinAlgError("Last 2 dimensions of A must be square")
UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
N = val_to_int_ptr(_N)
LDA = val_to_int_ptr(_N)
INFO = val_to_int_ptr(0)
transposed = False
if overwrite_a and A.flags.f_contiguous:
A_copy = A
elif overwrite_a and A.flags.c_contiguous:
# We can work on the transpose of A directly
A_copy = A.T
transposed = True
lower = not lower
else:
A_copy = _copy_to_fortran_order(A)
UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
N = val_to_int_ptr(_N)
LDA = val_to_int_ptr(_N)
INFO = val_to_int_ptr(0)
numba_potrf(
UPLO,
N,
......@@ -61,6 +67,10 @@ def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
for i in range(j + 1, _N):
A_copy[i, j] = 0.0
return A_copy, int_ptr_to_val(INFO)
info_int = int_ptr_to_val(INFO)
if transposed:
return A_copy.T, info_int
return A_copy, info_int
return impl
......@@ -551,8 +551,8 @@ class TestDecompositions:
val_c_contig = np.copy(val, order="C")
res_c_contig = fn(val_c_contig)
np.testing.assert_allclose(res_c_contig, res)
# Cannot destroy C-contiguous input
np.testing.assert_allclose(val_c_contig, val)
# Should always be destroyable
assert (val == val_c_contig).all() == (not overwrite_a)
# Test non-contiguous input
val_not_contig = np.repeat(val, 2, axis=0)[::2]
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论