提交 1ef9312b authored 作者: Frederic Bastien's avatar Frederic Bastien

Raise NotImplementedError or SkipTest for CorrMM* op

上级 df20dcfd
......@@ -149,6 +149,8 @@ class BaseCorrMM(gof.Op):
If self.border_mode == 'half', a variable giving the width of the
filters for direction="backprop weights". Ignored otherwise.
"""
if not theano.config.blas.ldflags:
raise NotImplementedError("C code for CorrMM* classes need a blas library.")
dH, dW = self.subsample
if self.border_mode == "half":
padH = padW = -1
......
......@@ -245,15 +245,26 @@ class BaseTestConv2d(unittest.TestCase):
db = (0, 0)
dflip = True in self.filter_flip
dprovide_shape = True in self.provide_shape
skipped = False
for (i, f) in zip(self.inputs_shapes, self.filters_shapes):
for provide_shape in self.provide_shape:
self.tcase(i, f, ds, db, dflip, provide_shape)
try:
self.tcase(i, f, ds, db, dflip, provide_shape)
except SkipTest as e:
skipped = e
for s in self.subsamples:
for b in self.border_modes:
self.tcase(i, f, s, db, dflip, dprovide_shape)
try:
self.tcase(i, f, s, db, dflip, dprovide_shape)
except SkipTest as e:
skipped = e
for flip in self.filter_flip:
self.tcase(i, f, ds, db, flip, dprovide_shape)
try:
self.tcase(i, f, ds, db, flip, dprovide_shape)
except SkipTest as e:
skipped = e
if skipped:
raise e
class TestCorrConv2d(BaseTestConv2d):
def setUp(self):
......@@ -263,6 +274,8 @@ class TestCorrConv2d(BaseTestConv2d):
def tcase(self, i, f, s, b, flip, provide_shape):
o = self.get_output_shape(i, f, s, b)
if not theano.config.blas.ldflags:
raise SkipTest("Need blas to test conv2d")
self.run_fwd(inputs_shape=i, filters_shape=f, subsample=s,
verify_grad=True, provide_shape=provide_shape,
border_mode=b, filter_flip=flip, target_op=CorrMM)
......@@ -307,6 +320,8 @@ class TestCpuConv2d(BaseTestConv2d):
gradinput_OK = False
if fwd_OK:
if not theano.config.blas.ldflags:
raise SkipTest("Need blas to test conv2d")
self.run_fwd(inputs_shape=i, filters_shape=f, subsample=s,
verify_grad=(gradweight_OK and gradinput_OK),
mode=mode, provide_shape=provide_shape,
......@@ -324,6 +339,8 @@ class TestCpuConv2d(BaseTestConv2d):
filter_flip=flip)
if gradweight_OK:
if not theano.config.blas.ldflags:
raise SkipTest("Need blas to test conv2d")
self.run_gradweight(inputs_shape=i, filters_shape=f,
output_shape=o, subsample=s,
verify_grad=False, mode=mode,
......@@ -344,6 +361,8 @@ class TestCpuConv2d(BaseTestConv2d):
filter_flip=flip)
if gradinput_OK:
if not theano.config.blas.ldflags:
raise SkipTest("Need blas to test conv2d")
self.run_gradinput(inputs_shape=i, filters_shape=f,
output_shape=o, subsample=s,
verify_grad=False, mode=mode,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论