提交 0f2afc48 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Make test_ExtractDiag_exhaustive less exhaustive

上级 02b7738d
......@@ -15,7 +15,6 @@ from tests.tensor.test_basic import check_alloc_runtime_broadcast
pytest.importorskip("numba")
from pytensor.link.numba.dispatch import numba_funcify
rng = np.random.default_rng(42849)
......@@ -274,28 +273,19 @@ def test_ExtractDiag(val, offset):
)
@pytest.mark.parametrize("k", range(-5, 4))
@pytest.mark.parametrize(
"axis1, axis2", ((0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3))
)
@pytest.mark.parametrize("reverse_axis", (False, True))
def test_ExtractDiag_exhaustive(k, axis1, axis2, reverse_axis):
from pytensor.link.numba.dispatch.basic import numba_njit
if reverse_axis:
axis1, axis2 = axis2, axis1
@pytest.mark.parametrize("k", (-5, -1, 0, 1, 4))
@pytest.mark.parametrize("axis1, axis2", ((0, 1), (0, 3), (1, 2), (2, 1), (2, 3)))
def test_ExtractDiag_exhaustive(k, axis1, axis2):
x = pt.tensor4("x")
x_shape = (2, 3, 4, 5)
x_test = np.arange(np.prod(x_shape)).reshape(x_shape)
out = pt.diagonal(x, k, axis1, axis2)
numba_fn = numba_funcify(out.owner.op, out.owner)
@numba_njit(no_cpython_wrapper=False)
def wrap(x):
return numba_fn(x)
np.testing.assert_allclose(wrap(x_test), np.diagonal(x_test, k, axis1, axis2))
compare_numba_and_py(
[x],
out,
[x_test],
)
@pytest.mark.parametrize(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论