提交 0743dbc0 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #4176 from nouiz/blas

Make BatchedDot don't try to use c code if there is no blas.
......@@ -2134,6 +2134,10 @@ class BatchedDot(Op):
_z, = out
fail = sub["fail"]
if not config.blas.ldflags:
return super(BatchedDot, self).c_code(node, name,
inp, out, sub)
# generate contiguity condition
def contiguous(var, ndim):
strides = "PyArray_STRIDES(%s)" % var
......
......@@ -149,6 +149,8 @@ class BaseCorrMM(gof.Op):
If self.border_mode == 'half', a variable giving the width of the
filters for direction="backprop weights". Ignored otherwise.
"""
if not theano.config.blas.ldflags:
raise NotImplementedError("C code for CorrMM* classes need a blas library.")
dH, dW = self.subsample
if self.border_mode == "half":
padH = padW = -1
......
......@@ -245,14 +245,26 @@ class BaseTestConv2d(unittest.TestCase):
db = (0, 0)
dflip = True in self.filter_flip
dprovide_shape = True in self.provide_shape
skipped = False
for (i, f) in zip(self.inputs_shapes, self.filters_shapes):
for provide_shape in self.provide_shape:
self.tcase(i, f, ds, db, dflip, provide_shape)
try:
self.tcase(i, f, ds, db, dflip, provide_shape)
except SkipTest as e:
skipped = e
for s in self.subsamples:
for b in self.border_modes:
self.tcase(i, f, s, db, dflip, dprovide_shape)
try:
self.tcase(i, f, s, db, dflip, dprovide_shape)
except SkipTest as e:
skipped = e
for flip in self.filter_flip:
self.tcase(i, f, ds, db, flip, dprovide_shape)
try:
self.tcase(i, f, ds, db, flip, dprovide_shape)
except SkipTest as e:
skipped = e
if skipped:
raise skipped
class TestCorrConv2d(BaseTestConv2d):
......@@ -263,6 +275,8 @@ class TestCorrConv2d(BaseTestConv2d):
def tcase(self, i, f, s, b, flip, provide_shape):
o = self.get_output_shape(i, f, s, b)
if not theano.config.blas.ldflags:
raise SkipTest("Need blas to test conv2d")
self.run_fwd(inputs_shape=i, filters_shape=f, subsample=s,
verify_grad=True, provide_shape=provide_shape,
border_mode=b, filter_flip=flip, target_op=CorrMM)
......@@ -307,6 +321,8 @@ class TestCpuConv2d(BaseTestConv2d):
gradinput_OK = False
if fwd_OK:
if not theano.config.blas.ldflags:
raise SkipTest("Need blas to test conv2d")
self.run_fwd(inputs_shape=i, filters_shape=f, subsample=s,
verify_grad=(gradweight_OK and gradinput_OK),
mode=mode, provide_shape=provide_shape,
......@@ -324,6 +340,8 @@ class TestCpuConv2d(BaseTestConv2d):
filter_flip=flip)
if gradweight_OK:
if not theano.config.blas.ldflags:
raise SkipTest("Need blas to test conv2d")
self.run_gradweight(inputs_shape=i, filters_shape=f,
output_shape=o, subsample=s,
verify_grad=False, mode=mode,
......@@ -344,6 +362,8 @@ class TestCpuConv2d(BaseTestConv2d):
filter_flip=flip)
if gradinput_OK:
if not theano.config.blas.ldflags:
raise SkipTest("Need blas to test conv2d")
self.run_gradinput(inputs_shape=i, filters_shape=f,
output_shape=o, subsample=s,
verify_grad=False, mode=mode,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论