提交 b9468e04 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

XFAIL conv tests of Ops without Python implementation

Mark overly specific tests as xfail
上级 9caa886c
......@@ -3,9 +3,11 @@ import pytest
import pytensor
import pytensor.tensor as pt
from pytensor.compile import get_default_mode
from pytensor.compile.mode import Mode
from pytensor.configdefaults import config
from pytensor.graph.rewriting.basic import check_stack_trace
from pytensor.link.numba import NumbaLinker
from pytensor.tensor.conv import abstract_conv
from pytensor.tensor.conv.abstract_conv import (
AbstractConv2d,
......@@ -757,6 +759,10 @@ class BaseTestConv:
def run_test_case(self, *args, **kargs):
raise NotImplementedError()
@pytest.mark.xfail(
condition=isinstance(get_default_mode().linker, NumbaLinker),
reason="Involves Ops with no Python implementation for numba to use as fallback",
)
def test_all(self):
ds = self.default_subsamples
db = self.default_border_mode
......@@ -815,6 +821,10 @@ class BaseTestConv2d(BaseTestConv):
def run_test_case_gi(self, *args, **kwargs):
raise NotImplementedError()
@pytest.mark.xfail(
condition=isinstance(get_default_mode().linker, NumbaLinker),
reason="Involves Ops with no Python implementation for numba to use as fallback",
)
def test_gradinput_arbitrary_output_shapes(self):
# this computes the grad wrt inputs for an output shape
# that the forward convolution would not produce
......@@ -948,10 +958,7 @@ class BaseTestConv2d(BaseTestConv):
)
@pytest.mark.skipif(
config.cxx == "",
reason="SciPy and cxx needed",
)
@pytest.mark.skipif(config.cxx == "", reason="cxx needed")
class TestAbstractConvNoOptim(BaseTestConv2d):
@classmethod
def setup_class(cls):
......@@ -1884,9 +1891,10 @@ class TestConv2dGrads:
)
@pytest.mark.skipif(
config.cxx == "",
reason="SciPy and cxx needed",
@pytest.mark.skipif(config.cxx == "", reason="cxx needed")
@pytest.mark.xfail(
condition=isinstance(get_default_mode().linker, NumbaLinker),
reason="Involves Ops with no Python implementation for numba to use as fallback",
)
class TestGroupedConvNoOptim:
conv = abstract_conv.AbstractConv2d
......@@ -2096,9 +2104,10 @@ class TestGroupedConvNoOptim:
utt.verify_grad(conv_gradinputs, [kern, top], mode=self.mode, eps=1)
@pytest.mark.skipif(
config.cxx == "",
reason="SciPy and cxx needed",
@pytest.mark.skipif(config.cxx == "", reason="cxx needed")
@pytest.mark.xfail(
condition=isinstance(get_default_mode().linker, NumbaLinker),
reason="Involves Ops with no Python implementation for numba to use as fallback",
)
class TestGroupedConv3dNoOptim(TestGroupedConvNoOptim):
conv = abstract_conv.AbstractConv3d
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论