提交 5f6c0103 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Test organization: group Solve and Decompositions tests

上级 993c2c64
...@@ -410,381 +410,379 @@ class TestSolves: ...@@ -410,381 +410,379 @@ class TestSolves:
# Can never destroy non-contiguous inputs # Can never destroy non-contiguous inputs
np.testing.assert_allclose(b_val_not_contig, b_val) np.testing.assert_allclose(b_val_not_contig, b_val)
@pytest.mark.parametrize("trans", [True, False], ids=lambda x: f"trans = {x}")
@pytest.mark.parametrize("lower", [True, False], ids=lambda x: f"lower={x}") @pytest.mark.parametrize(
@pytest.mark.parametrize( "overwrite_b", [False, True], ids=["no_overwrite", "overwrite_b"]
"overwrite_a", [False, True], ids=["no_overwrite", "overwrite_a"] )
) @pytest.mark.parametrize(
def test_cholesky(lower: bool, overwrite_a: bool): "b_func, b_shape",
cov = pt.matrix("cov") [(pt.matrix, (5, 1)), (pt.matrix, (5, 5)), (pt.vector, (5,))],
chol = pt.linalg.cholesky(cov, lower=lower) ids=["b_col_vec", "b_matrix", "b_vec"],
x = np.array([0.1, 0.2, 0.3]).astype(floatX)
val = np.eye(3).astype(floatX) + x[None, :] * x[:, None]
fn, res = compare_numba_and_py(
[In(cov, mutable=overwrite_a)],
[chol],
[val],
numba_mode=numba_inplace_mode,
inplace=True,
) )
def test_lu_solve(
self, b_func, b_shape: tuple[int, ...], trans: bool, overwrite_b: bool
):
A = pt.matrix("A", dtype=floatX)
b = pt.tensor("b", shape=b_shape, dtype=floatX)
op = fn.maker.fgraph.outputs[0].owner.op rng = np.random.default_rng(418)
assert isinstance(op, Cholesky) A_val = rng.normal(size=(5, 5)).astype(floatX)
destroy_map = op.destroy_map b_val = rng.normal(size=b_shape).astype(floatX)
if overwrite_a:
assert destroy_map == {0: [0]}
else:
assert destroy_map == {}
# Test F-contiguous input lu_and_piv = pt.linalg.lu_factor(A)
val_f_contig = np.copy(val, order="F") X = pt.linalg.lu_solve(
res_f_contig = fn(val_f_contig) lu_and_piv,
np.testing.assert_allclose(res_f_contig, res) b,
# Should always be destroyable b_ndim=len(b_shape),
assert (val == val_f_contig).all() == (not overwrite_a) trans=trans,
)
# Test C-contiguous input f, res = compare_numba_and_py(
val_c_contig = np.copy(val, order="C") [A, In(b, mutable=overwrite_b)],
res_c_contig = fn(val_c_contig) X,
np.testing.assert_allclose(res_c_contig, res) test_inputs=[A_val, b_val],
# Cannot destroy C-contiguous input inplace=True,
np.testing.assert_allclose(val_c_contig, val) numba_mode=numba_inplace_mode,
eval_obj_mode=False,
)
# Test non-contiguous input # Test with F_contiguous inputs
val_not_contig = np.repeat(val, 2, axis=0)[::2] A_val_f_contig = np.copy(A_val, order="F")
res_not_contig = fn(val_not_contig) b_val_f_contig = np.copy(b_val, order="F")
np.testing.assert_allclose(res_not_contig, res) res_f_contig = f(A_val_f_contig, b_val_f_contig)
# Cannot destroy non-contiguous input np.testing.assert_allclose(res_f_contig, res)
np.testing.assert_allclose(val_not_contig, val)
all_equal = (b_val == b_val_f_contig).all()
should_destroy = overwrite_b and trans
def test_cholesky_raises_on_nan_input(): if should_destroy:
test_value = rng.random(size=(3, 3)).astype(floatX) assert not all_equal
test_value[0, 0] = np.nan else:
assert all_equal
x = pt.tensor(dtype=floatX, shape=(3, 3)) # Test with C_contiguous inputs
x = x.T.dot(x) A_val_c_contig = np.copy(A_val, order="C")
g = pt.linalg.cholesky(x, check_finite=True) b_val_c_contig = np.copy(b_val, order="C")
f = pytensor.function([x], g, mode="NUMBA") res_c_contig = f(A_val_c_contig, b_val_c_contig)
with pytest.raises(np.linalg.LinAlgError, match=r"Non-numeric values"): np.testing.assert_allclose(res_c_contig, res)
f(test_value) np.testing.assert_allclose(A_val_c_contig, A_val)
# b c_contiguous vectors are also f_contiguous and destroyable
assert not (
should_destroy and b_val_c_contig.flags.f_contiguous
) == np.allclose(b_val_c_contig, b_val)
@pytest.mark.parametrize("on_error", ["nan", "raise"]) # Test with non-contiguous inputs
def test_cholesky_raise_on(on_error): A_val_not_contig = np.repeat(A_val, 2, axis=0)[::2]
test_value = rng.random(size=(3, 3)).astype(floatX) b_val_not_contig = np.repeat(b_val, 2, axis=0)[::2]
res_not_contig = f(A_val_not_contig, b_val_not_contig)
np.testing.assert_allclose(res_not_contig, res)
np.testing.assert_allclose(A_val_not_contig, A_val)
x = pt.tensor(dtype=floatX, shape=(3, 3)) # Can never destroy non-contiguous inputs
g = pt.linalg.cholesky(x, on_error=on_error) np.testing.assert_allclose(b_val_not_contig, b_val)
f = pytensor.function([x], g, mode="NUMBA")
if on_error == "raise":
with pytest.raises(
np.linalg.LinAlgError, match=r"Input to cholesky is not positive definite"
):
f(test_value)
else:
assert np.all(np.isnan(f(test_value)))
class TestDecompositions:
@pytest.mark.parametrize("lower", [True, False], ids=lambda x: f"lower={x}")
@pytest.mark.parametrize(
"overwrite_a", [False, True], ids=["no_overwrite", "overwrite_a"]
)
def test_cholesky(self, lower: bool, overwrite_a: bool):
cov = pt.matrix("cov")
chol = pt.linalg.cholesky(cov, lower=lower)
def test_block_diag(): x = np.array([0.1, 0.2, 0.3]).astype(floatX)
A = pt.matrix("A") val = np.eye(3).astype(floatX) + x[None, :] * x[:, None]
B = pt.matrix("B")
C = pt.matrix("C")
D = pt.matrix("D")
X = pt.linalg.block_diag(A, B, C, D)
A_val = np.random.normal(size=(5, 5)).astype(floatX) fn, res = compare_numba_and_py(
B_val = np.random.normal(size=(3, 3)).astype(floatX) [In(cov, mutable=overwrite_a)],
C_val = np.random.normal(size=(2, 2)).astype(floatX) [chol],
D_val = np.random.normal(size=(4, 4)).astype(floatX) [val],
compare_numba_and_py([A, B, C, D], [X], [A_val, B_val, C_val, D_val]) numba_mode=numba_inplace_mode,
inplace=True,
)
op = fn.maker.fgraph.outputs[0].owner.op
assert isinstance(op, Cholesky)
destroy_map = op.destroy_map
if overwrite_a:
assert destroy_map == {0: [0]}
else:
assert destroy_map == {}
@pytest.mark.parametrize("inverse", [True, False], ids=["p_inv", "p"]) # Test F-contiguous input
def test_pivot_to_permutation(inverse): val_f_contig = np.copy(val, order="F")
from pytensor.tensor.slinalg import pivot_to_permutation res_f_contig = fn(val_f_contig)
np.testing.assert_allclose(res_f_contig, res)
# Should always be destroyable
assert (val == val_f_contig).all() == (not overwrite_a)
rng = np.random.default_rng(123) # Test C-contiguous input
A = rng.normal(size=(5, 5)).astype(floatX) val_c_contig = np.copy(val, order="C")
res_c_contig = fn(val_c_contig)
np.testing.assert_allclose(res_c_contig, res)
# Cannot destroy C-contiguous input
np.testing.assert_allclose(val_c_contig, val)
perm_pt = pt.vector("p", dtype="int32") # Test non-contiguous input
piv_pt = pivot_to_permutation(perm_pt, inverse=inverse) val_not_contig = np.repeat(val, 2, axis=0)[::2]
f = pytensor.function([perm_pt], piv_pt, mode="NUMBA") res_not_contig = fn(val_not_contig)
np.testing.assert_allclose(res_not_contig, res)
# Cannot destroy non-contiguous input
np.testing.assert_allclose(val_not_contig, val)
_, piv = scipy.linalg.lu_factor(A) def test_cholesky_raises_on_nan_input(self):
test_value = rng.random(size=(3, 3)).astype(floatX)
test_value[0, 0] = np.nan
if inverse: x = pt.tensor(dtype=floatX, shape=(3, 3))
p = np.arange(len(piv)) x = x.T.dot(x)
for i in range(len(piv)): g = pt.linalg.cholesky(x, check_finite=True)
p[i], p[piv[i]] = p[piv[i]], p[i] f = pytensor.function([x], g, mode="NUMBA")
np.testing.assert_allclose(f(piv), p)
else:
p, *_ = scipy.linalg.lu(A, p_indices=True)
np.testing.assert_allclose(f(piv), p)
with pytest.raises(np.linalg.LinAlgError, match=r"Non-numeric values"):
f(test_value)
@pytest.mark.parametrize( @pytest.mark.parametrize("on_error", ["nan", "raise"])
"permute_l, p_indices", def test_cholesky_raise_on(self, on_error):
[(True, False), (False, True), (False, False)], test_value = rng.random(size=(3, 3)).astype(floatX)
ids=["PL", "p_indices", "P"],
) x = pt.tensor(dtype=floatX, shape=(3, 3))
@pytest.mark.parametrize( g = pt.linalg.cholesky(x, on_error=on_error)
"overwrite_a", [True, False], ids=["overwrite_a", "no_overwrite"] f = pytensor.function([x], g, mode="NUMBA")
)
def test_lu(permute_l, p_indices, overwrite_a):
shape = (5, 5)
rng = np.random.default_rng()
A = pt.tensor(
"A",
shape=shape,
dtype=config.floatX,
)
A_val = rng.normal(size=shape).astype(config.floatX)
lu_outputs = pt.linalg.lu(A, permute_l=permute_l, p_indices=p_indices) if on_error == "raise":
with pytest.raises(
np.linalg.LinAlgError,
match=r"Input to cholesky is not positive definite",
):
f(test_value)
else:
assert np.all(np.isnan(f(test_value)))
fn, res = compare_numba_and_py( @pytest.mark.parametrize(
[In(A, mutable=overwrite_a)], "permute_l, p_indices",
lu_outputs, [(True, False), (False, True), (False, False)],
[A_val], ids=["PL", "p_indices", "P"],
numba_mode=numba_inplace_mode, )
inplace=True, @pytest.mark.parametrize(
"overwrite_a", [True, False], ids=["overwrite_a", "no_overwrite"]
) )
def test_lu(self, permute_l, p_indices, overwrite_a):
shape = (5, 5)
rng = np.random.default_rng()
A = pt.tensor(
"A",
shape=shape,
dtype=config.floatX,
)
A_val = rng.normal(size=shape).astype(config.floatX)
op = fn.maker.fgraph.outputs[0].owner.op lu_outputs = pt.linalg.lu(A, permute_l=permute_l, p_indices=p_indices)
assert isinstance(op, LU)
destroy_map = op.destroy_map fn, res = compare_numba_and_py(
[In(A, mutable=overwrite_a)],
lu_outputs,
[A_val],
numba_mode=numba_inplace_mode,
inplace=True,
)
if overwrite_a and permute_l: op = fn.maker.fgraph.outputs[0].owner.op
assert destroy_map == {0: [0]} assert isinstance(op, LU)
elif overwrite_a:
assert destroy_map == {1: [0]}
else:
assert destroy_map == {}
# Test F-contiguous input destroy_map = op.destroy_map
val_f_contig = np.copy(A_val, order="F")
res_f_contig = fn(val_f_contig)
for x, x_f_contig in zip(res, res_f_contig, strict=True): if overwrite_a and permute_l:
np.testing.assert_allclose(x, x_f_contig) assert destroy_map == {0: [0]}
elif overwrite_a:
assert destroy_map == {1: [0]}
else:
assert destroy_map == {}
# Should always be destroyable # Test F-contiguous input
assert (A_val == val_f_contig).all() == (not overwrite_a) val_f_contig = np.copy(A_val, order="F")
res_f_contig = fn(val_f_contig)
# Test C-contiguous input for x, x_f_contig in zip(res, res_f_contig, strict=True):
val_c_contig = np.copy(A_val, order="C") np.testing.assert_allclose(x, x_f_contig)
res_c_contig = fn(val_c_contig)
for x, x_c_contig in zip(res, res_c_contig, strict=True):
np.testing.assert_allclose(x, x_c_contig)
# Cannot destroy C-contiguous input # Should always be destroyable
np.testing.assert_allclose(val_c_contig, A_val) assert (A_val == val_f_contig).all() == (not overwrite_a)
# Test non-contiguous input # Test C-contiguous input
val_not_contig = np.repeat(A_val, 2, axis=0)[::2] val_c_contig = np.copy(A_val, order="C")
res_not_contig = fn(val_not_contig) res_c_contig = fn(val_c_contig)
for x, x_not_contig in zip(res, res_not_contig, strict=True): for x, x_c_contig in zip(res, res_c_contig, strict=True):
np.testing.assert_allclose(x, x_not_contig) np.testing.assert_allclose(x, x_c_contig)
# Cannot destroy non-contiguous input # Cannot destroy C-contiguous input
np.testing.assert_allclose(val_not_contig, A_val) np.testing.assert_allclose(val_c_contig, A_val)
# Test non-contiguous input
val_not_contig = np.repeat(A_val, 2, axis=0)[::2]
res_not_contig = fn(val_not_contig)
for x, x_not_contig in zip(res, res_not_contig, strict=True):
np.testing.assert_allclose(x, x_not_contig)
@pytest.mark.parametrize( # Cannot destroy non-contiguous input
"overwrite_a", [True, False], ids=["overwrite_a", "no_overwrite"] np.testing.assert_allclose(val_not_contig, A_val)
)
def test_lu_factor(overwrite_a):
shape = (5, 5)
rng = np.random.default_rng()
A = pt.tensor("A", shape=shape, dtype=config.floatX) @pytest.mark.parametrize(
A_val = rng.normal(size=shape).astype(config.floatX) "overwrite_a", [True, False], ids=["overwrite_a", "no_overwrite"]
)
def test_lu_factor(self, overwrite_a):
shape = (5, 5)
rng = np.random.default_rng()
LU, piv = pt.linalg.lu_factor(A) A = pt.tensor("A", shape=shape, dtype=config.floatX)
A_val = rng.normal(size=shape).astype(config.floatX)
fn, res = compare_numba_and_py( LU, piv = pt.linalg.lu_factor(A)
[In(A, mutable=overwrite_a)],
[LU, piv],
[A_val],
numba_mode=numba_inplace_mode,
inplace=True,
)
op = fn.maker.fgraph.outputs[0].owner.op fn, res = compare_numba_and_py(
assert isinstance(op, LUFactor) [In(A, mutable=overwrite_a)],
[LU, piv],
[A_val],
numba_mode=numba_inplace_mode,
inplace=True,
)
if overwrite_a: op = fn.maker.fgraph.outputs[0].owner.op
assert op.destroy_map == {1: [0]} assert isinstance(op, LUFactor)
# Test F-contiguous input if overwrite_a:
val_f_contig = np.copy(A_val, order="F") assert op.destroy_map == {1: [0]}
res_f_contig = fn(val_f_contig)
for x, x_f_contig in zip(res, res_f_contig, strict=True): # Test F-contiguous input
np.testing.assert_allclose(x, x_f_contig) val_f_contig = np.copy(A_val, order="F")
res_f_contig = fn(val_f_contig)
# Should always be destroyable for x, x_f_contig in zip(res, res_f_contig, strict=True):
assert (A_val == val_f_contig).all() == (not overwrite_a) np.testing.assert_allclose(x, x_f_contig)
# Test C-contiguous input # Should always be destroyable
val_c_contig = np.copy(A_val, order="C") assert (A_val == val_f_contig).all() == (not overwrite_a)
res_c_contig = fn(val_c_contig)
for x, x_c_contig in zip(res, res_c_contig, strict=True):
np.testing.assert_allclose(x, x_c_contig)
# Cannot destroy C-contiguous input # Test C-contiguous input
np.testing.assert_allclose(val_c_contig, A_val) val_c_contig = np.copy(A_val, order="C")
res_c_contig = fn(val_c_contig)
for x, x_c_contig in zip(res, res_c_contig, strict=True):
np.testing.assert_allclose(x, x_c_contig)
# Test non-contiguous input # Cannot destroy C-contiguous input
val_not_contig = np.repeat(A_val, 2, axis=0)[::2] np.testing.assert_allclose(val_c_contig, A_val)
res_not_contig = fn(val_not_contig)
for x, x_not_contig in zip(res, res_not_contig, strict=True):
np.testing.assert_allclose(x, x_not_contig)
# Cannot destroy non-contiguous input # Test non-contiguous input
np.testing.assert_allclose(val_not_contig, A_val) val_not_contig = np.repeat(A_val, 2, axis=0)[::2]
res_not_contig = fn(val_not_contig)
for x, x_not_contig in zip(res, res_not_contig, strict=True):
np.testing.assert_allclose(x, x_not_contig)
# Cannot destroy non-contiguous input
np.testing.assert_allclose(val_not_contig, A_val)
@pytest.mark.parametrize("trans", [True, False], ids=lambda x: f"trans = {x}") @pytest.mark.parametrize(
@pytest.mark.parametrize( "mode, pivoting",
"overwrite_b", [False, True], ids=["no_overwrite", "overwrite_b"] [("economic", False), ("full", True), ("r", False), ("raw", True)],
) ids=["economic", "full_pivot", "r", "raw_pivot"],
@pytest.mark.parametrize(
"b_func, b_shape",
[(pt.matrix, (5, 1)), (pt.matrix, (5, 5)), (pt.vector, (5,))],
ids=["b_col_vec", "b_matrix", "b_vec"],
)
def test_lu_solve(b_func, b_shape: tuple[int, ...], trans: bool, overwrite_b: bool):
A = pt.matrix("A", dtype=floatX)
b = pt.tensor("b", shape=b_shape, dtype=floatX)
rng = np.random.default_rng(418)
A_val = rng.normal(size=(5, 5)).astype(floatX)
b_val = rng.normal(size=b_shape).astype(floatX)
lu_and_piv = pt.linalg.lu_factor(A)
X = pt.linalg.lu_solve(
lu_and_piv,
b,
b_ndim=len(b_shape),
trans=trans,
) )
@pytest.mark.parametrize(
f, res = compare_numba_and_py( "overwrite_a", [True, False], ids=["overwrite_a", "no_overwrite"]
[A, In(b, mutable=overwrite_b)],
X,
test_inputs=[A_val, b_val],
inplace=True,
numba_mode=numba_inplace_mode,
eval_obj_mode=False,
) )
def test_qr(self, mode, pivoting, overwrite_a):
shape = (5, 5)
rng = np.random.default_rng()
A = pt.tensor(
"A",
shape=shape,
dtype=config.floatX,
)
A_val = rng.normal(size=shape).astype(config.floatX)
# Test with F_contiguous inputs qr_outputs = pt.linalg.qr(A, mode=mode, pivoting=pivoting)
A_val_f_contig = np.copy(A_val, order="F")
b_val_f_contig = np.copy(b_val, order="F")
res_f_contig = f(A_val_f_contig, b_val_f_contig)
np.testing.assert_allclose(res_f_contig, res)
all_equal = (b_val == b_val_f_contig).all()
should_destroy = overwrite_b and trans
if should_destroy: fn, res = compare_numba_and_py(
assert not all_equal [In(A, mutable=overwrite_a)],
else: qr_outputs,
assert all_equal [A_val],
numba_mode=numba_inplace_mode,
inplace=True,
)
# Test with C_contiguous inputs op = fn.maker.fgraph.outputs[0].owner.op
A_val_c_contig = np.copy(A_val, order="C") assert isinstance(op, QR)
b_val_c_contig = np.copy(b_val, order="C")
res_c_contig = f(A_val_c_contig, b_val_c_contig)
np.testing.assert_allclose(res_c_contig, res) destroy_map = op.destroy_map
np.testing.assert_allclose(A_val_c_contig, A_val)
# b c_contiguous vectors are also f_contiguous and destroyable if overwrite_a:
assert not (should_destroy and b_val_c_contig.flags.f_contiguous) == np.allclose( assert destroy_map == {0: [0]}
b_val_c_contig, b_val else:
) assert destroy_map == {}
# Test with non-contiguous inputs # Test F-contiguous input
A_val_not_contig = np.repeat(A_val, 2, axis=0)[::2] val_f_contig = np.copy(A_val, order="F")
b_val_not_contig = np.repeat(b_val, 2, axis=0)[::2] res_f_contig = fn(val_f_contig)
res_not_contig = f(A_val_not_contig, b_val_not_contig)
np.testing.assert_allclose(res_not_contig, res)
np.testing.assert_allclose(A_val_not_contig, A_val)
# Can never destroy non-contiguous inputs for x, x_f_contig in zip(res, res_f_contig, strict=True):
np.testing.assert_allclose(b_val_not_contig, b_val) np.testing.assert_allclose(x, x_f_contig)
# Should always be destroyable
assert (A_val == val_f_contig).all() == (not overwrite_a)
@pytest.mark.parametrize( # Test C-contiguous input
"mode, pivoting", val_c_contig = np.copy(A_val, order="C")
[("economic", False), ("full", True), ("r", False), ("raw", True)], res_c_contig = fn(val_c_contig)
ids=["economic", "full_pivot", "r", "raw_pivot"], for x, x_c_contig in zip(res, res_c_contig, strict=True):
) np.testing.assert_allclose(x, x_c_contig)
@pytest.mark.parametrize(
"overwrite_a", [True, False], ids=["overwrite_a", "no_overwrite"]
)
def test_qr(mode, pivoting, overwrite_a):
shape = (5, 5)
rng = np.random.default_rng()
A = pt.tensor(
"A",
shape=shape,
dtype=config.floatX,
)
A_val = rng.normal(size=shape).astype(config.floatX)
qr_outputs = pt.linalg.qr(A, mode=mode, pivoting=pivoting) # Cannot destroy C-contiguous input
np.testing.assert_allclose(val_c_contig, A_val)
fn, res = compare_numba_and_py( # Test non-contiguous input
[In(A, mutable=overwrite_a)], val_not_contig = np.repeat(A_val, 2, axis=0)[::2]
qr_outputs, res_not_contig = fn(val_not_contig)
[A_val], for x, x_not_contig in zip(res, res_not_contig, strict=True):
numba_mode=numba_inplace_mode, np.testing.assert_allclose(x, x_not_contig)
inplace=True,
)
op = fn.maker.fgraph.outputs[0].owner.op # Cannot destroy non-contiguous input
assert isinstance(op, QR) np.testing.assert_allclose(val_not_contig, A_val)
destroy_map = op.destroy_map
if overwrite_a: def test_block_diag():
assert destroy_map == {0: [0]} A = pt.matrix("A")
else: B = pt.matrix("B")
assert destroy_map == {} C = pt.matrix("C")
D = pt.matrix("D")
X = pt.linalg.block_diag(A, B, C, D)
# Test F-contiguous input A_val = np.random.normal(size=(5, 5)).astype(floatX)
val_f_contig = np.copy(A_val, order="F") B_val = np.random.normal(size=(3, 3)).astype(floatX)
res_f_contig = fn(val_f_contig) C_val = np.random.normal(size=(2, 2)).astype(floatX)
D_val = np.random.normal(size=(4, 4)).astype(floatX)
compare_numba_and_py([A, B, C, D], [X], [A_val, B_val, C_val, D_val])
for x, x_f_contig in zip(res, res_f_contig, strict=True):
np.testing.assert_allclose(x, x_f_contig)
# Should always be destroyable @pytest.mark.parametrize("inverse", [True, False], ids=["p_inv", "p"])
assert (A_val == val_f_contig).all() == (not overwrite_a) def test_pivot_to_permutation(inverse):
from pytensor.tensor.slinalg import pivot_to_permutation
# Test C-contiguous input rng = np.random.default_rng(123)
val_c_contig = np.copy(A_val, order="C") A = rng.normal(size=(5, 5)).astype(floatX)
res_c_contig = fn(val_c_contig)
for x, x_c_contig in zip(res, res_c_contig, strict=True):
np.testing.assert_allclose(x, x_c_contig)
# Cannot destroy C-contiguous input perm_pt = pt.vector("p", dtype="int32")
np.testing.assert_allclose(val_c_contig, A_val) piv_pt = pivot_to_permutation(perm_pt, inverse=inverse)
f = pytensor.function([perm_pt], piv_pt, mode="NUMBA")
# Test non-contiguous input _, piv = scipy.linalg.lu_factor(A)
val_not_contig = np.repeat(A_val, 2, axis=0)[::2]
res_not_contig = fn(val_not_contig)
for x, x_not_contig in zip(res, res_not_contig, strict=True):
np.testing.assert_allclose(x, x_not_contig)
# Cannot destroy non-contiguous input if inverse:
np.testing.assert_allclose(val_not_contig, A_val) p = np.arange(len(piv))
for i in range(len(piv)):
p[i], p[piv[i]] = p[piv[i]], p[i]
np.testing.assert_allclose(f(piv), p)
else:
p, *_ = scipy.linalg.lu(A, p_indices=True)
np.testing.assert_allclose(f(piv), p)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论