Unverified 提交 e7dec4d9 authored 作者: Jesse Grabowski's avatar Jesse Grabowski 提交者: GitHub

Fix solve_triangular output when overwrite_b=True (#1235)

* Fix bug in solve_triangular when `overwrite_b = True`
* Add regression test
上级 5d4e9e07
...@@ -124,20 +124,26 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b ...@@ -124,20 +124,26 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b
_N = np.int32(A.shape[-1]) _N = np.int32(A.shape[-1])
_solve_check_input_shapes(A, B) _solve_check_input_shapes(A, B)
# Seems weird to not use the b_ndim input directly, but when I did that Numba complained that the output type
# could potentially be 3d (it didn't understand b_ndim was always equal to B.ndim)
B_is_1d = B.ndim == 1 B_is_1d = B.ndim == 1
# This will only copy if A is not already fortran contiguous
A_f = np.asfortranarray(A)
if overwrite_b: if overwrite_b:
B_copy = B if B_is_1d:
B_copy = np.expand_dims(B, -1)
else:
# This *will* allow inplace destruction of B, but only if it is already fortran contiguous.
# Otherwise, there's no way to get around the need to copy the data before going into TRTRS
B_copy = np.asfortranarray(B)
else: else:
if B_is_1d: if B_is_1d:
# _copy_to_fortran_order does nothing with vectors B_copy = np.copy(np.expand_dims(B, -1))
B_copy = np.copy(B)
else: else:
B_copy = _copy_to_fortran_order(B) B_copy = _copy_to_fortran_order(B)
if B_is_1d:
B_copy = np.expand_dims(B_copy, -1)
NRHS = 1 if B_is_1d else int(B_copy.shape[-1]) NRHS = 1 if B_is_1d else int(B_copy.shape[-1])
UPLO = val_to_int_ptr(ord("L") if lower else ord("U")) UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
...@@ -155,7 +161,7 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b ...@@ -155,7 +161,7 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b
DIAG, DIAG,
N, N,
NRHS, NRHS,
np.asfortranarray(A).T.view(w_type).ctypes, A_f.view(w_type).ctypes,
LDA, LDA,
B_copy.view(w_type).ctypes, B_copy.view(w_type).ctypes,
LDB, LDB,
......
...@@ -10,6 +10,7 @@ from scipy import linalg as scipy_linalg ...@@ -10,6 +10,7 @@ from scipy import linalg as scipy_linalg
import pytensor import pytensor
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor import config from pytensor import config
from pytensor.tensor.slinalg import SolveTriangular
from tests import unittest_tools as utt from tests import unittest_tools as utt
from tests.link.numba.test_basic import compare_numba_and_py from tests.link.numba.test_basic import compare_numba_and_py
...@@ -130,6 +131,48 @@ def test_solve_triangular_grad(lower, unit_diag, trans): ...@@ -130,6 +131,48 @@ def test_solve_triangular_grad(lower, unit_diag, trans):
) )
@pytest.mark.parametrize("overwrite_b", [True, False], ids=["inplace", "not_inplace"])
def test_solve_triangular_overwrite_b_correct(overwrite_b):
    """Regression test for issue #1233.

    Compiles the same ``SolveTriangular`` graph with the default mode and
    with the NUMBA mode, runs both on identical inputs, and checks that the
    two agree on the returned solution and on the state of the ``b`` buffer
    afterwards. When ``overwrite_b`` is requested, ``b`` is additionally
    expected to equal the returned solution.
    """
    rng = np.random.default_rng(utt.fetch_seed())

    # Draw order matters for reproducibility: A first, then B.
    a_py = np.tril(np.asfortranarray(rng.normal(size=(3, 3))))
    b_py = np.asfortranarray(rng.normal(size=(3, 2)))

    # Each backend gets its own buffers; copy(order="F") requests a
    # Fortran-ordered copy so both backends see the same memory layout.
    a_nb = a_py.copy(order="F")
    b_nb = b_py.copy(order="F")

    solve_op = SolveTriangular(
        trans=0,
        unit_diagonal=False,
        lower=False,
        check_finite=True,
        b_ndim=2,
        overwrite_b=overwrite_b,
    )

    a_sym = pt.matrix("a", shape=(3, 3))
    b_sym = pt.matrix("b", shape=(3, 2))
    solution = solve_op(a_sym, b_sym)
    graph_inputs = [a_sym, b_sym]

    py_fn = pytensor.function(graph_inputs, solution, accept_inplace=True)
    numba_fn = pytensor.function(
        graph_inputs, solution, accept_inplace=True, mode="NUMBA"
    )

    # First evaluation: keep the returned solutions for the in-place checks.
    x_py = py_fn(a_py, b_py)
    x_nb = numba_fn(a_nb, b_nb)

    # Second evaluation on the same buffers (python backend first, then
    # numba); the two backends must produce matching results.
    second_py = py_fn(a_py, b_py)
    second_nb = numba_fn(a_nb, b_nb)
    np.testing.assert_allclose(second_py, second_nb)

    # Whatever happened to b, it must have happened identically in both.
    np.testing.assert_allclose(b_py, b_nb)

    if not overwrite_b:
        return

    # With overwrite_b requested, b is expected to hold the solution.
    np.testing.assert_allclose(b_py, x_py)
    np.testing.assert_allclose(b_nb, x_nb)
@pytest.mark.parametrize("value", [np.nan, np.inf]) @pytest.mark.parametrize("value", [np.nan, np.inf])
@pytest.mark.filterwarnings( @pytest.mark.filterwarnings(
'ignore:Cannot cache compiled function "numba_funcified_fgraph"' 'ignore:Cannot cache compiled function "numba_funcified_fgraph"'
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论