提交 e25e8a2a authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix unauthorized inplace update of vector B in numba solve_triangular

上级 cc8c4992
...@@ -126,13 +126,17 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b ...@@ -126,13 +126,17 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b
B_is_1d = B.ndim == 1 B_is_1d = B.ndim == 1
if not overwrite_b: if overwrite_b:
B_copy = _copy_to_fortran_order(B)
else:
B_copy = B B_copy = B
else:
if B_is_1d:
# _copy_to_fortran_order does nothing with vectors
B_copy = np.copy(B)
else:
B_copy = _copy_to_fortran_order(B)
if B_is_1d: if B_is_1d:
B_copy = np.expand_dims(B, -1) B_copy = np.expand_dims(B_copy, -1)
NRHS = 1 if B_is_1d else int(B_copy.shape[-1]) NRHS = 1 if B_is_1d else int(B_copy.shape[-1])
......
...@@ -79,9 +79,9 @@ def test_solve_triangular(b_shape: tuple[int], lower, trans, unit_diag, is_compl ...@@ -79,9 +79,9 @@ def test_solve_triangular(b_shape: tuple[int], lower, trans, unit_diag, is_compl
A_val = A_val + np.random.normal(size=(5, 5)) * 1j A_val = A_val + np.random.normal(size=(5, 5)) * 1j
b_val = b_val + np.random.normal(size=b_shape) * 1j b_val = b_val + np.random.normal(size=b_shape) * 1j
X_np = f(A_func(A_val.copy()), b_val.copy()) X_np = f(A_func(A_val), b_val)
test_input = transpose_func(A_func(A_val.copy()), trans) test_input = transpose_func(A_func(A_val), trans)
ATOL = 1e-8 if floatX.endswith("64") else 1e-4 ATOL = 1e-8 if floatX.endswith("64") else 1e-4
RTOL = 1e-8 if floatX.endswith("64") else 1e-4 RTOL = 1e-8 if floatX.endswith("64") else 1e-4
...@@ -92,7 +92,7 @@ def test_solve_triangular(b_shape: tuple[int], lower, trans, unit_diag, is_compl ...@@ -92,7 +92,7 @@ def test_solve_triangular(b_shape: tuple[int], lower, trans, unit_diag, is_compl
compare_numba_and_py( compare_numba_and_py(
compiled_fgraph.inputs, compiled_fgraph.inputs,
compiled_fgraph.outputs, compiled_fgraph.outputs,
[A_func(A_val.copy()), b_val.copy()], [A_func(A_val), b_val],
) )
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论