提交 0f2afc48 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Make test_ExtractDiag_exhaustive less exhaustive

上级 02b7738d
...@@ -15,7 +15,6 @@ from tests.tensor.test_basic import check_alloc_runtime_broadcast ...@@ -15,7 +15,6 @@ from tests.tensor.test_basic import check_alloc_runtime_broadcast
pytest.importorskip("numba") pytest.importorskip("numba")
from pytensor.link.numba.dispatch import numba_funcify
rng = np.random.default_rng(42849) rng = np.random.default_rng(42849)
...@@ -274,28 +273,19 @@ def test_ExtractDiag(val, offset): ...@@ -274,28 +273,19 @@ def test_ExtractDiag(val, offset):
) )
@pytest.mark.parametrize("k", range(-5, 4)) @pytest.mark.parametrize("k", (-5, -1, 0, 1, 4))
@pytest.mark.parametrize( @pytest.mark.parametrize("axis1, axis2", ((0, 1), (0, 3), (1, 2), (2, 1), (2, 3)))
"axis1, axis2", ((0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)) def test_ExtractDiag_exhaustive(k, axis1, axis2):
)
@pytest.mark.parametrize("reverse_axis", (False, True))
def test_ExtractDiag_exhaustive(k, axis1, axis2, reverse_axis):
from pytensor.link.numba.dispatch.basic import numba_njit
if reverse_axis:
axis1, axis2 = axis2, axis1
x = pt.tensor4("x") x = pt.tensor4("x")
x_shape = (2, 3, 4, 5) x_shape = (2, 3, 4, 5)
x_test = np.arange(np.prod(x_shape)).reshape(x_shape) x_test = np.arange(np.prod(x_shape)).reshape(x_shape)
out = pt.diagonal(x, k, axis1, axis2) out = pt.diagonal(x, k, axis1, axis2)
numba_fn = numba_funcify(out.owner.op, out.owner)
@numba_njit(no_cpython_wrapper=False) compare_numba_and_py(
def wrap(x): [x],
return numba_fn(x) out,
[x_test],
np.testing.assert_allclose(wrap(x_test), np.diagonal(x_test, k, axis1, axis2)) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论