提交 e25e8a2a authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix unauthorized inplace update of vector B in numba solve_triangular

上级 cc8c4992
......@@ -126,13 +126,17 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b
B_is_1d = B.ndim == 1
if not overwrite_b:
B_copy = _copy_to_fortran_order(B)
else:
if overwrite_b:
B_copy = B
else:
if B_is_1d:
# _copy_to_fortran_order does nothing with vectors
B_copy = np.copy(B)
else:
B_copy = _copy_to_fortran_order(B)
if B_is_1d:
B_copy = np.expand_dims(B, -1)
B_copy = np.expand_dims(B_copy, -1)
NRHS = 1 if B_is_1d else int(B_copy.shape[-1])
......
......@@ -79,9 +79,9 @@ def test_solve_triangular(b_shape: tuple[int], lower, trans, unit_diag, is_compl
A_val = A_val + np.random.normal(size=(5, 5)) * 1j
b_val = b_val + np.random.normal(size=b_shape) * 1j
X_np = f(A_func(A_val.copy()), b_val.copy())
X_np = f(A_func(A_val), b_val)
test_input = transpose_func(A_func(A_val.copy()), trans)
test_input = transpose_func(A_func(A_val), trans)
ATOL = 1e-8 if floatX.endswith("64") else 1e-4
RTOL = 1e-8 if floatX.endswith("64") else 1e-4
......@@ -92,7 +92,7 @@ def test_solve_triangular(b_shape: tuple[int], lower, trans, unit_diag, is_compl
compare_numba_and_py(
compiled_fgraph.inputs,
compiled_fgraph.outputs,
[A_func(A_val.copy()), b_val.copy()],
[A_func(A_val), b_val],
)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论