提交 5f6c0103 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Test organization: group Solve and Decompositions tests

上级 993c2c64
......@@ -410,381 +410,379 @@ class TestSolves:
# Can never destroy non-contiguous inputs
np.testing.assert_allclose(b_val_not_contig, b_val)
@pytest.mark.parametrize("lower", [True, False], ids=lambda x: f"lower={x}")
@pytest.mark.parametrize(
"overwrite_a", [False, True], ids=["no_overwrite", "overwrite_a"]
)
def test_cholesky(lower: bool, overwrite_a: bool):
cov = pt.matrix("cov")
chol = pt.linalg.cholesky(cov, lower=lower)
x = np.array([0.1, 0.2, 0.3]).astype(floatX)
val = np.eye(3).astype(floatX) + x[None, :] * x[:, None]
fn, res = compare_numba_and_py(
[In(cov, mutable=overwrite_a)],
[chol],
[val],
numba_mode=numba_inplace_mode,
inplace=True,
@pytest.mark.parametrize("trans", [True, False], ids=lambda x: f"trans = {x}")
@pytest.mark.parametrize(
"overwrite_b", [False, True], ids=["no_overwrite", "overwrite_b"]
)
@pytest.mark.parametrize(
"b_func, b_shape",
[(pt.matrix, (5, 1)), (pt.matrix, (5, 5)), (pt.vector, (5,))],
ids=["b_col_vec", "b_matrix", "b_vec"],
)
def test_lu_solve(
self, b_func, b_shape: tuple[int, ...], trans: bool, overwrite_b: bool
):
A = pt.matrix("A", dtype=floatX)
b = pt.tensor("b", shape=b_shape, dtype=floatX)
op = fn.maker.fgraph.outputs[0].owner.op
assert isinstance(op, Cholesky)
destroy_map = op.destroy_map
if overwrite_a:
assert destroy_map == {0: [0]}
else:
assert destroy_map == {}
rng = np.random.default_rng(418)
A_val = rng.normal(size=(5, 5)).astype(floatX)
b_val = rng.normal(size=b_shape).astype(floatX)
# Test F-contiguous input
val_f_contig = np.copy(val, order="F")
res_f_contig = fn(val_f_contig)
np.testing.assert_allclose(res_f_contig, res)
# Should always be destroyable
assert (val == val_f_contig).all() == (not overwrite_a)
lu_and_piv = pt.linalg.lu_factor(A)
X = pt.linalg.lu_solve(
lu_and_piv,
b,
b_ndim=len(b_shape),
trans=trans,
)
# Test C-contiguous input
val_c_contig = np.copy(val, order="C")
res_c_contig = fn(val_c_contig)
np.testing.assert_allclose(res_c_contig, res)
# Cannot destroy C-contiguous input
np.testing.assert_allclose(val_c_contig, val)
f, res = compare_numba_and_py(
[A, In(b, mutable=overwrite_b)],
X,
test_inputs=[A_val, b_val],
inplace=True,
numba_mode=numba_inplace_mode,
eval_obj_mode=False,
)
# Test non-contiguous input
val_not_contig = np.repeat(val, 2, axis=0)[::2]
res_not_contig = fn(val_not_contig)
np.testing.assert_allclose(res_not_contig, res)
# Cannot destroy non-contiguous input
np.testing.assert_allclose(val_not_contig, val)
# Test with F_contiguous inputs
A_val_f_contig = np.copy(A_val, order="F")
b_val_f_contig = np.copy(b_val, order="F")
res_f_contig = f(A_val_f_contig, b_val_f_contig)
np.testing.assert_allclose(res_f_contig, res)
all_equal = (b_val == b_val_f_contig).all()
should_destroy = overwrite_b and trans
def test_cholesky_raises_on_nan_input():
test_value = rng.random(size=(3, 3)).astype(floatX)
test_value[0, 0] = np.nan
if should_destroy:
assert not all_equal
else:
assert all_equal
x = pt.tensor(dtype=floatX, shape=(3, 3))
x = x.T.dot(x)
g = pt.linalg.cholesky(x, check_finite=True)
f = pytensor.function([x], g, mode="NUMBA")
# Test with C_contiguous inputs
A_val_c_contig = np.copy(A_val, order="C")
b_val_c_contig = np.copy(b_val, order="C")
res_c_contig = f(A_val_c_contig, b_val_c_contig)
with pytest.raises(np.linalg.LinAlgError, match=r"Non-numeric values"):
f(test_value)
np.testing.assert_allclose(res_c_contig, res)
np.testing.assert_allclose(A_val_c_contig, A_val)
# b c_contiguous vectors are also f_contiguous and destroyable
assert not (
should_destroy and b_val_c_contig.flags.f_contiguous
) == np.allclose(b_val_c_contig, b_val)
@pytest.mark.parametrize("on_error", ["nan", "raise"])
def test_cholesky_raise_on(on_error):
test_value = rng.random(size=(3, 3)).astype(floatX)
# Test with non-contiguous inputs
A_val_not_contig = np.repeat(A_val, 2, axis=0)[::2]
b_val_not_contig = np.repeat(b_val, 2, axis=0)[::2]
res_not_contig = f(A_val_not_contig, b_val_not_contig)
np.testing.assert_allclose(res_not_contig, res)
np.testing.assert_allclose(A_val_not_contig, A_val)
x = pt.tensor(dtype=floatX, shape=(3, 3))
g = pt.linalg.cholesky(x, on_error=on_error)
f = pytensor.function([x], g, mode="NUMBA")
# Can never destroy non-contiguous inputs
np.testing.assert_allclose(b_val_not_contig, b_val)
if on_error == "raise":
with pytest.raises(
np.linalg.LinAlgError, match=r"Input to cholesky is not positive definite"
):
f(test_value)
else:
assert np.all(np.isnan(f(test_value)))
class TestDecompositions:
@pytest.mark.parametrize("lower", [True, False], ids=lambda x: f"lower={x}")
@pytest.mark.parametrize(
"overwrite_a", [False, True], ids=["no_overwrite", "overwrite_a"]
)
def test_cholesky(self, lower: bool, overwrite_a: bool):
cov = pt.matrix("cov")
chol = pt.linalg.cholesky(cov, lower=lower)
def test_block_diag():
A = pt.matrix("A")
B = pt.matrix("B")
C = pt.matrix("C")
D = pt.matrix("D")
X = pt.linalg.block_diag(A, B, C, D)
x = np.array([0.1, 0.2, 0.3]).astype(floatX)
val = np.eye(3).astype(floatX) + x[None, :] * x[:, None]
A_val = np.random.normal(size=(5, 5)).astype(floatX)
B_val = np.random.normal(size=(3, 3)).astype(floatX)
C_val = np.random.normal(size=(2, 2)).astype(floatX)
D_val = np.random.normal(size=(4, 4)).astype(floatX)
compare_numba_and_py([A, B, C, D], [X], [A_val, B_val, C_val, D_val])
fn, res = compare_numba_and_py(
[In(cov, mutable=overwrite_a)],
[chol],
[val],
numba_mode=numba_inplace_mode,
inplace=True,
)
op = fn.maker.fgraph.outputs[0].owner.op
assert isinstance(op, Cholesky)
destroy_map = op.destroy_map
if overwrite_a:
assert destroy_map == {0: [0]}
else:
assert destroy_map == {}
@pytest.mark.parametrize("inverse", [True, False], ids=["p_inv", "p"])
def test_pivot_to_permutation(inverse):
from pytensor.tensor.slinalg import pivot_to_permutation
# Test F-contiguous input
val_f_contig = np.copy(val, order="F")
res_f_contig = fn(val_f_contig)
np.testing.assert_allclose(res_f_contig, res)
# Should always be destroyable
assert (val == val_f_contig).all() == (not overwrite_a)
rng = np.random.default_rng(123)
A = rng.normal(size=(5, 5)).astype(floatX)
# Test C-contiguous input
val_c_contig = np.copy(val, order="C")
res_c_contig = fn(val_c_contig)
np.testing.assert_allclose(res_c_contig, res)
# Cannot destroy C-contiguous input
np.testing.assert_allclose(val_c_contig, val)
perm_pt = pt.vector("p", dtype="int32")
piv_pt = pivot_to_permutation(perm_pt, inverse=inverse)
f = pytensor.function([perm_pt], piv_pt, mode="NUMBA")
# Test non-contiguous input
val_not_contig = np.repeat(val, 2, axis=0)[::2]
res_not_contig = fn(val_not_contig)
np.testing.assert_allclose(res_not_contig, res)
# Cannot destroy non-contiguous input
np.testing.assert_allclose(val_not_contig, val)
_, piv = scipy.linalg.lu_factor(A)
def test_cholesky_raises_on_nan_input(self):
test_value = rng.random(size=(3, 3)).astype(floatX)
test_value[0, 0] = np.nan
if inverse:
p = np.arange(len(piv))
for i in range(len(piv)):
p[i], p[piv[i]] = p[piv[i]], p[i]
np.testing.assert_allclose(f(piv), p)
else:
p, *_ = scipy.linalg.lu(A, p_indices=True)
np.testing.assert_allclose(f(piv), p)
x = pt.tensor(dtype=floatX, shape=(3, 3))
x = x.T.dot(x)
g = pt.linalg.cholesky(x, check_finite=True)
f = pytensor.function([x], g, mode="NUMBA")
with pytest.raises(np.linalg.LinAlgError, match=r"Non-numeric values"):
f(test_value)
@pytest.mark.parametrize(
"permute_l, p_indices",
[(True, False), (False, True), (False, False)],
ids=["PL", "p_indices", "P"],
)
@pytest.mark.parametrize(
"overwrite_a", [True, False], ids=["overwrite_a", "no_overwrite"]
)
def test_lu(permute_l, p_indices, overwrite_a):
shape = (5, 5)
rng = np.random.default_rng()
A = pt.tensor(
"A",
shape=shape,
dtype=config.floatX,
)
A_val = rng.normal(size=shape).astype(config.floatX)
@pytest.mark.parametrize("on_error", ["nan", "raise"])
def test_cholesky_raise_on(self, on_error):
test_value = rng.random(size=(3, 3)).astype(floatX)
x = pt.tensor(dtype=floatX, shape=(3, 3))
g = pt.linalg.cholesky(x, on_error=on_error)
f = pytensor.function([x], g, mode="NUMBA")
lu_outputs = pt.linalg.lu(A, permute_l=permute_l, p_indices=p_indices)
if on_error == "raise":
with pytest.raises(
np.linalg.LinAlgError,
match=r"Input to cholesky is not positive definite",
):
f(test_value)
else:
assert np.all(np.isnan(f(test_value)))
fn, res = compare_numba_and_py(
[In(A, mutable=overwrite_a)],
lu_outputs,
[A_val],
numba_mode=numba_inplace_mode,
inplace=True,
@pytest.mark.parametrize(
"permute_l, p_indices",
[(True, False), (False, True), (False, False)],
ids=["PL", "p_indices", "P"],
)
@pytest.mark.parametrize(
"overwrite_a", [True, False], ids=["overwrite_a", "no_overwrite"]
)
def test_lu(self, permute_l, p_indices, overwrite_a):
shape = (5, 5)
rng = np.random.default_rng()
A = pt.tensor(
"A",
shape=shape,
dtype=config.floatX,
)
A_val = rng.normal(size=shape).astype(config.floatX)
op = fn.maker.fgraph.outputs[0].owner.op
assert isinstance(op, LU)
lu_outputs = pt.linalg.lu(A, permute_l=permute_l, p_indices=p_indices)
destroy_map = op.destroy_map
fn, res = compare_numba_and_py(
[In(A, mutable=overwrite_a)],
lu_outputs,
[A_val],
numba_mode=numba_inplace_mode,
inplace=True,
)
if overwrite_a and permute_l:
assert destroy_map == {0: [0]}
elif overwrite_a:
assert destroy_map == {1: [0]}
else:
assert destroy_map == {}
op = fn.maker.fgraph.outputs[0].owner.op
assert isinstance(op, LU)
# Test F-contiguous input
val_f_contig = np.copy(A_val, order="F")
res_f_contig = fn(val_f_contig)
destroy_map = op.destroy_map
for x, x_f_contig in zip(res, res_f_contig, strict=True):
np.testing.assert_allclose(x, x_f_contig)
if overwrite_a and permute_l:
assert destroy_map == {0: [0]}
elif overwrite_a:
assert destroy_map == {1: [0]}
else:
assert destroy_map == {}
# Should always be destroyable
assert (A_val == val_f_contig).all() == (not overwrite_a)
# Test F-contiguous input
val_f_contig = np.copy(A_val, order="F")
res_f_contig = fn(val_f_contig)
# Test C-contiguous input
val_c_contig = np.copy(A_val, order="C")
res_c_contig = fn(val_c_contig)
for x, x_c_contig in zip(res, res_c_contig, strict=True):
np.testing.assert_allclose(x, x_c_contig)
for x, x_f_contig in zip(res, res_f_contig, strict=True):
np.testing.assert_allclose(x, x_f_contig)
# Cannot destroy C-contiguous input
np.testing.assert_allclose(val_c_contig, A_val)
# Should always be destroyable
assert (A_val == val_f_contig).all() == (not overwrite_a)
# Test non-contiguous input
val_not_contig = np.repeat(A_val, 2, axis=0)[::2]
res_not_contig = fn(val_not_contig)
for x, x_not_contig in zip(res, res_not_contig, strict=True):
np.testing.assert_allclose(x, x_not_contig)
# Test C-contiguous input
val_c_contig = np.copy(A_val, order="C")
res_c_contig = fn(val_c_contig)
for x, x_c_contig in zip(res, res_c_contig, strict=True):
np.testing.assert_allclose(x, x_c_contig)
# Cannot destroy non-contiguous input
np.testing.assert_allclose(val_not_contig, A_val)
# Cannot destroy C-contiguous input
np.testing.assert_allclose(val_c_contig, A_val)
# Test non-contiguous input
val_not_contig = np.repeat(A_val, 2, axis=0)[::2]
res_not_contig = fn(val_not_contig)
for x, x_not_contig in zip(res, res_not_contig, strict=True):
np.testing.assert_allclose(x, x_not_contig)
@pytest.mark.parametrize(
"overwrite_a", [True, False], ids=["overwrite_a", "no_overwrite"]
)
def test_lu_factor(overwrite_a):
shape = (5, 5)
rng = np.random.default_rng()
# Cannot destroy non-contiguous input
np.testing.assert_allclose(val_not_contig, A_val)
A = pt.tensor("A", shape=shape, dtype=config.floatX)
A_val = rng.normal(size=shape).astype(config.floatX)
@pytest.mark.parametrize(
"overwrite_a", [True, False], ids=["overwrite_a", "no_overwrite"]
)
def test_lu_factor(self, overwrite_a):
shape = (5, 5)
rng = np.random.default_rng()
LU, piv = pt.linalg.lu_factor(A)
A = pt.tensor("A", shape=shape, dtype=config.floatX)
A_val = rng.normal(size=shape).astype(config.floatX)
fn, res = compare_numba_and_py(
[In(A, mutable=overwrite_a)],
[LU, piv],
[A_val],
numba_mode=numba_inplace_mode,
inplace=True,
)
LU, piv = pt.linalg.lu_factor(A)
op = fn.maker.fgraph.outputs[0].owner.op
assert isinstance(op, LUFactor)
fn, res = compare_numba_and_py(
[In(A, mutable=overwrite_a)],
[LU, piv],
[A_val],
numba_mode=numba_inplace_mode,
inplace=True,
)
if overwrite_a:
assert op.destroy_map == {1: [0]}
op = fn.maker.fgraph.outputs[0].owner.op
assert isinstance(op, LUFactor)
# Test F-contiguous input
val_f_contig = np.copy(A_val, order="F")
res_f_contig = fn(val_f_contig)
if overwrite_a:
assert op.destroy_map == {1: [0]}
for x, x_f_contig in zip(res, res_f_contig, strict=True):
np.testing.assert_allclose(x, x_f_contig)
# Test F-contiguous input
val_f_contig = np.copy(A_val, order="F")
res_f_contig = fn(val_f_contig)
# Should always be destroyable
assert (A_val == val_f_contig).all() == (not overwrite_a)
for x, x_f_contig in zip(res, res_f_contig, strict=True):
np.testing.assert_allclose(x, x_f_contig)
# Test C-contiguous input
val_c_contig = np.copy(A_val, order="C")
res_c_contig = fn(val_c_contig)
for x, x_c_contig in zip(res, res_c_contig, strict=True):
np.testing.assert_allclose(x, x_c_contig)
# Should always be destroyable
assert (A_val == val_f_contig).all() == (not overwrite_a)
# Cannot destroy C-contiguous input
np.testing.assert_allclose(val_c_contig, A_val)
# Test C-contiguous input
val_c_contig = np.copy(A_val, order="C")
res_c_contig = fn(val_c_contig)
for x, x_c_contig in zip(res, res_c_contig, strict=True):
np.testing.assert_allclose(x, x_c_contig)
# Test non-contiguous input
val_not_contig = np.repeat(A_val, 2, axis=0)[::2]
res_not_contig = fn(val_not_contig)
for x, x_not_contig in zip(res, res_not_contig, strict=True):
np.testing.assert_allclose(x, x_not_contig)
# Cannot destroy C-contiguous input
np.testing.assert_allclose(val_c_contig, A_val)
# Cannot destroy non-contiguous input
np.testing.assert_allclose(val_not_contig, A_val)
# Test non-contiguous input
val_not_contig = np.repeat(A_val, 2, axis=0)[::2]
res_not_contig = fn(val_not_contig)
for x, x_not_contig in zip(res, res_not_contig, strict=True):
np.testing.assert_allclose(x, x_not_contig)
# Cannot destroy non-contiguous input
np.testing.assert_allclose(val_not_contig, A_val)
@pytest.mark.parametrize("trans", [True, False], ids=lambda x: f"trans = {x}")
@pytest.mark.parametrize(
"overwrite_b", [False, True], ids=["no_overwrite", "overwrite_b"]
)
@pytest.mark.parametrize(
"b_func, b_shape",
[(pt.matrix, (5, 1)), (pt.matrix, (5, 5)), (pt.vector, (5,))],
ids=["b_col_vec", "b_matrix", "b_vec"],
)
def test_lu_solve(b_func, b_shape: tuple[int, ...], trans: bool, overwrite_b: bool):
A = pt.matrix("A", dtype=floatX)
b = pt.tensor("b", shape=b_shape, dtype=floatX)
rng = np.random.default_rng(418)
A_val = rng.normal(size=(5, 5)).astype(floatX)
b_val = rng.normal(size=b_shape).astype(floatX)
lu_and_piv = pt.linalg.lu_factor(A)
X = pt.linalg.lu_solve(
lu_and_piv,
b,
b_ndim=len(b_shape),
trans=trans,
@pytest.mark.parametrize(
"mode, pivoting",
[("economic", False), ("full", True), ("r", False), ("raw", True)],
ids=["economic", "full_pivot", "r", "raw_pivot"],
)
f, res = compare_numba_and_py(
[A, In(b, mutable=overwrite_b)],
X,
test_inputs=[A_val, b_val],
inplace=True,
numba_mode=numba_inplace_mode,
eval_obj_mode=False,
@pytest.mark.parametrize(
"overwrite_a", [True, False], ids=["overwrite_a", "no_overwrite"]
)
def test_qr(self, mode, pivoting, overwrite_a):
shape = (5, 5)
rng = np.random.default_rng()
A = pt.tensor(
"A",
shape=shape,
dtype=config.floatX,
)
A_val = rng.normal(size=shape).astype(config.floatX)
# Test with F_contiguous inputs
A_val_f_contig = np.copy(A_val, order="F")
b_val_f_contig = np.copy(b_val, order="F")
res_f_contig = f(A_val_f_contig, b_val_f_contig)
np.testing.assert_allclose(res_f_contig, res)
all_equal = (b_val == b_val_f_contig).all()
should_destroy = overwrite_b and trans
qr_outputs = pt.linalg.qr(A, mode=mode, pivoting=pivoting)
if should_destroy:
assert not all_equal
else:
assert all_equal
fn, res = compare_numba_and_py(
[In(A, mutable=overwrite_a)],
qr_outputs,
[A_val],
numba_mode=numba_inplace_mode,
inplace=True,
)
# Test with C_contiguous inputs
A_val_c_contig = np.copy(A_val, order="C")
b_val_c_contig = np.copy(b_val, order="C")
res_c_contig = f(A_val_c_contig, b_val_c_contig)
op = fn.maker.fgraph.outputs[0].owner.op
assert isinstance(op, QR)
np.testing.assert_allclose(res_c_contig, res)
np.testing.assert_allclose(A_val_c_contig, A_val)
destroy_map = op.destroy_map
# b c_contiguous vectors are also f_contiguous and destroyable
assert not (should_destroy and b_val_c_contig.flags.f_contiguous) == np.allclose(
b_val_c_contig, b_val
)
if overwrite_a:
assert destroy_map == {0: [0]}
else:
assert destroy_map == {}
# Test with non-contiguous inputs
A_val_not_contig = np.repeat(A_val, 2, axis=0)[::2]
b_val_not_contig = np.repeat(b_val, 2, axis=0)[::2]
res_not_contig = f(A_val_not_contig, b_val_not_contig)
np.testing.assert_allclose(res_not_contig, res)
np.testing.assert_allclose(A_val_not_contig, A_val)
# Test F-contiguous input
val_f_contig = np.copy(A_val, order="F")
res_f_contig = fn(val_f_contig)
# Can never destroy non-contiguous inputs
np.testing.assert_allclose(b_val_not_contig, b_val)
for x, x_f_contig in zip(res, res_f_contig, strict=True):
np.testing.assert_allclose(x, x_f_contig)
# Should always be destroyable
assert (A_val == val_f_contig).all() == (not overwrite_a)
@pytest.mark.parametrize(
"mode, pivoting",
[("economic", False), ("full", True), ("r", False), ("raw", True)],
ids=["economic", "full_pivot", "r", "raw_pivot"],
)
@pytest.mark.parametrize(
"overwrite_a", [True, False], ids=["overwrite_a", "no_overwrite"]
)
def test_qr(mode, pivoting, overwrite_a):
shape = (5, 5)
rng = np.random.default_rng()
A = pt.tensor(
"A",
shape=shape,
dtype=config.floatX,
)
A_val = rng.normal(size=shape).astype(config.floatX)
# Test C-contiguous input
val_c_contig = np.copy(A_val, order="C")
res_c_contig = fn(val_c_contig)
for x, x_c_contig in zip(res, res_c_contig, strict=True):
np.testing.assert_allclose(x, x_c_contig)
qr_outputs = pt.linalg.qr(A, mode=mode, pivoting=pivoting)
# Cannot destroy C-contiguous input
np.testing.assert_allclose(val_c_contig, A_val)
fn, res = compare_numba_and_py(
[In(A, mutable=overwrite_a)],
qr_outputs,
[A_val],
numba_mode=numba_inplace_mode,
inplace=True,
)
# Test non-contiguous input
val_not_contig = np.repeat(A_val, 2, axis=0)[::2]
res_not_contig = fn(val_not_contig)
for x, x_not_contig in zip(res, res_not_contig, strict=True):
np.testing.assert_allclose(x, x_not_contig)
op = fn.maker.fgraph.outputs[0].owner.op
assert isinstance(op, QR)
# Cannot destroy non-contiguous input
np.testing.assert_allclose(val_not_contig, A_val)
destroy_map = op.destroy_map
if overwrite_a:
assert destroy_map == {0: [0]}
else:
assert destroy_map == {}
def test_block_diag():
A = pt.matrix("A")
B = pt.matrix("B")
C = pt.matrix("C")
D = pt.matrix("D")
X = pt.linalg.block_diag(A, B, C, D)
# Test F-contiguous input
val_f_contig = np.copy(A_val, order="F")
res_f_contig = fn(val_f_contig)
A_val = np.random.normal(size=(5, 5)).astype(floatX)
B_val = np.random.normal(size=(3, 3)).astype(floatX)
C_val = np.random.normal(size=(2, 2)).astype(floatX)
D_val = np.random.normal(size=(4, 4)).astype(floatX)
compare_numba_and_py([A, B, C, D], [X], [A_val, B_val, C_val, D_val])
for x, x_f_contig in zip(res, res_f_contig, strict=True):
np.testing.assert_allclose(x, x_f_contig)
# Should always be destroyable
assert (A_val == val_f_contig).all() == (not overwrite_a)
@pytest.mark.parametrize("inverse", [True, False], ids=["p_inv", "p"])
def test_pivot_to_permutation(inverse):
from pytensor.tensor.slinalg import pivot_to_permutation
# Test C-contiguous input
val_c_contig = np.copy(A_val, order="C")
res_c_contig = fn(val_c_contig)
for x, x_c_contig in zip(res, res_c_contig, strict=True):
np.testing.assert_allclose(x, x_c_contig)
rng = np.random.default_rng(123)
A = rng.normal(size=(5, 5)).astype(floatX)
# Cannot destroy C-contiguous input
np.testing.assert_allclose(val_c_contig, A_val)
perm_pt = pt.vector("p", dtype="int32")
piv_pt = pivot_to_permutation(perm_pt, inverse=inverse)
f = pytensor.function([perm_pt], piv_pt, mode="NUMBA")
# Test non-contiguous input
val_not_contig = np.repeat(A_val, 2, axis=0)[::2]
res_not_contig = fn(val_not_contig)
for x, x_not_contig in zip(res, res_not_contig, strict=True):
np.testing.assert_allclose(x, x_not_contig)
_, piv = scipy.linalg.lu_factor(A)
# Cannot destroy non-contiguous input
np.testing.assert_allclose(val_not_contig, A_val)
if inverse:
p = np.arange(len(piv))
for i in range(len(piv)):
p[i], p[piv[i]] = p[piv[i]], p[i]
np.testing.assert_allclose(f(piv), p)
else:
p, *_ = scipy.linalg.lu(A, p_indices=True)
np.testing.assert_allclose(f(piv), p)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论