提交 b9468e04 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

XFAIL conv tests of Ops without Python implementation

Mark overly specific tests as xfail
上级 9caa886c
...@@ -3,9 +3,11 @@ import pytest ...@@ -3,9 +3,11 @@ import pytest
import pytensor import pytensor
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor.compile import get_default_mode
from pytensor.compile.mode import Mode from pytensor.compile.mode import Mode
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.rewriting.basic import check_stack_trace from pytensor.graph.rewriting.basic import check_stack_trace
from pytensor.link.numba import NumbaLinker
from pytensor.tensor.conv import abstract_conv from pytensor.tensor.conv import abstract_conv
from pytensor.tensor.conv.abstract_conv import ( from pytensor.tensor.conv.abstract_conv import (
AbstractConv2d, AbstractConv2d,
...@@ -757,6 +759,10 @@ class BaseTestConv: ...@@ -757,6 +759,10 @@ class BaseTestConv:
def run_test_case(self, *args, **kargs): def run_test_case(self, *args, **kargs):
raise NotImplementedError() raise NotImplementedError()
@pytest.mark.xfail(
condition=isinstance(get_default_mode().linker, NumbaLinker),
reason="Involves Ops with no Python implementation for numba to use as fallback",
)
def test_all(self): def test_all(self):
ds = self.default_subsamples ds = self.default_subsamples
db = self.default_border_mode db = self.default_border_mode
...@@ -815,6 +821,10 @@ class BaseTestConv2d(BaseTestConv): ...@@ -815,6 +821,10 @@ class BaseTestConv2d(BaseTestConv):
def run_test_case_gi(self, *args, **kwargs): def run_test_case_gi(self, *args, **kwargs):
raise NotImplementedError() raise NotImplementedError()
@pytest.mark.xfail(
condition=isinstance(get_default_mode().linker, NumbaLinker),
reason="Involves Ops with no Python implementation for numba to use as fallback",
)
def test_gradinput_arbitrary_output_shapes(self): def test_gradinput_arbitrary_output_shapes(self):
# this computes the grad wrt inputs for an output shape # this computes the grad wrt inputs for an output shape
# that the forward convolution would not produce # that the forward convolution would not produce
...@@ -948,10 +958,7 @@ class BaseTestConv2d(BaseTestConv): ...@@ -948,10 +958,7 @@ class BaseTestConv2d(BaseTestConv):
) )
@pytest.mark.skipif( @pytest.mark.skipif(config.cxx == "", reason="cxx needed")
config.cxx == "",
reason="SciPy and cxx needed",
)
class TestAbstractConvNoOptim(BaseTestConv2d): class TestAbstractConvNoOptim(BaseTestConv2d):
@classmethod @classmethod
def setup_class(cls): def setup_class(cls):
...@@ -1884,9 +1891,10 @@ class TestConv2dGrads: ...@@ -1884,9 +1891,10 @@ class TestConv2dGrads:
) )
@pytest.mark.skipif( @pytest.mark.skipif(config.cxx == "", reason="cxx needed")
config.cxx == "", @pytest.mark.xfail(
reason="SciPy and cxx needed", condition=isinstance(get_default_mode().linker, NumbaLinker),
reason="Involves Ops with no Python implementation for numba to use as fallback",
) )
class TestGroupedConvNoOptim: class TestGroupedConvNoOptim:
conv = abstract_conv.AbstractConv2d conv = abstract_conv.AbstractConv2d
...@@ -2096,9 +2104,10 @@ class TestGroupedConvNoOptim: ...@@ -2096,9 +2104,10 @@ class TestGroupedConvNoOptim:
utt.verify_grad(conv_gradinputs, [kern, top], mode=self.mode, eps=1) utt.verify_grad(conv_gradinputs, [kern, top], mode=self.mode, eps=1)
@pytest.mark.skipif( @pytest.mark.skipif(config.cxx == "", reason="cxx needed")
config.cxx == "", @pytest.mark.xfail(
reason="SciPy and cxx needed", condition=isinstance(get_default_mode().linker, NumbaLinker),
reason="Involves Ops with no Python implementation for numba to use as fallback",
) )
class TestGroupedConv3dNoOptim(TestGroupedConvNoOptim): class TestGroupedConv3dNoOptim(TestGroupedConvNoOptim):
conv = abstract_conv.AbstractConv3d conv = abstract_conv.AbstractConv3d
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论