提交 93b4fb57 authored 作者: Nicolas Ballas's avatar Nicolas Ballas

small fix to make the tests pass + flake8

上级 ecc4e7b9
...@@ -46,13 +46,14 @@ class TestDnnConv2d(test_abstract_conv.BaseTestConv2d): ...@@ -46,13 +46,14 @@ class TestDnnConv2d(test_abstract_conv.BaseTestConv2d):
provide_shape=provide_shape, border_mode=b, provide_shape=provide_shape, border_mode=b,
filter_flip=flip, target_op=GpuDnnConvGradI) filter_flip=flip, target_op=GpuDnnConvGradI)
class TestCorrMMConv2d(test_abstract_conv.TestConv2d):
class TestCorrMMConv2d(test_abstract_conv.BaseTestConv2d):
def setUp(self): def setUp(self):
super(TestCorrMMConv2d, self).setUp() super(TestCorrMMConv2d, self).setUp()
self.shared = gpu_shared self.shared = gpu_shared
self.mode = mode_with_gpu.excluding('cudnn') self.mode = mode_with_gpu.excluding('cudnn')
def test_gpucorrmm_conv(self, i, f, s, b, flip, provide_shape): def tcase(self, i, f, s, b, flip, provide_shape):
mode = self.mode mode = self.mode
o = self.get_output_shape(i, f, s, b) o = self.get_output_shape(i, f, s, b)
self.run_fwd(inputs_shape=i, filters_shape=f, subsample=s, self.run_fwd(inputs_shape=i, filters_shape=f, subsample=s,
......
...@@ -458,7 +458,6 @@ class BaseAbstractConv2d(Op): ...@@ -458,7 +458,6 @@ class BaseAbstractConv2d(Op):
'invalid mode {}, which must be either ' 'invalid mode {}, which must be either '
'"valid" or "full"'.format(mode)) '"valid" or "full"'.format(mode))
out_shape = get_conv_output_shape(img.shape, kern.shape, mode, [1, 1]) out_shape = get_conv_output_shape(img.shape, kern.shape, mode, [1, 1])
out = numpy.zeros(out_shape, dtype=img.dtype) out = numpy.zeros(out_shape, dtype=img.dtype)
val = _valfrommode(mode) val = _valfrommode(mode)
...@@ -513,7 +512,7 @@ class AbstractConv2d(BaseAbstractConv2d): ...@@ -513,7 +512,7 @@ class AbstractConv2d(BaseAbstractConv2d):
mode = self.border_mode mode = self.border_mode
if mode == "full": if mode == "full":
mode = (kern.shape[2] - 1, kern.shape[3] - 1) mode = (kern.shape[2] - 1, kern.shape[3] - 1)
elif mode == "half": elif mode == "half":
mode = (kern.shape[2] // 2, kern.shape[3] // 2) mode = (kern.shape[2] // 2, kern.shape[3] // 2)
if isinstance(mode, tuple): if isinstance(mode, tuple):
...@@ -522,16 +521,15 @@ class AbstractConv2d(BaseAbstractConv2d): ...@@ -522,16 +521,15 @@ class AbstractConv2d(BaseAbstractConv2d):
new_img = numpy.zeros((img.shape[0], img.shape[1], new_img = numpy.zeros((img.shape[0], img.shape[1],
img.shape[2] + 2 * pad_h, img.shape[2] + 2 * pad_h,
img.shape[3] + 2 * pad_w), dtype=img.dtype) img.shape[3] + 2 * pad_w), dtype=img.dtype)
new_img[:, :, pad_h:img.shape[2]+pad_h, pad_w:img.shape[3]+pad_w] = img new_img[:, :, pad_h:img.shape[2] + pad_h, pad_w:img.shape[3] + pad_w] = img
img = new_img img = new_img
if not self.filter_flip: if not self.filter_flip:
kern = kern[:, :, ::-1, ::-1] kern = kern[:, :, ::-1, ::-1]
conv_out = self.corr2d(img, kern, mode) conv_out = self.corr2d(img, kern, mode)
conv_out =conv_out[:, :, ::self.subsample[0], ::self.subsample[1]] conv_out = conv_out[:, :, ::self.subsample[0], ::self.subsample[1]]
o[0] = node.outputs[0].type.filter(conv_out) o[0] = node.outputs[0].type.filter(conv_out)
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
rval = None rval = None
if eval_points[0] is not None: if eval_points[0] is not None:
...@@ -634,19 +632,18 @@ class AbstractConv2d_gradWeights(BaseAbstractConv2d): ...@@ -634,19 +632,18 @@ class AbstractConv2d_gradWeights(BaseAbstractConv2d):
mode = self.border_mode mode = self.border_mode
if mode == "full": if mode == "full":
mode = (shape[0] - 1, shape[1] - 1) mode = (shape[0] - 1, shape[1] - 1)
elif mode == "half": elif mode == "half":
mode = (shape[0] // 2, shape[1] // 2) mode = (shape[0] // 2, shape[1] // 2)
if isinstance(mode, tuple): if isinstance(mode, tuple):
pad_h, pad_w = map(int, mode) pad_h, pad_w = map(int, mode)
mode = "valid" mode = "valid"
new_img = numpy.zeros((img.shape[0], img.shape[1], new_img = numpy.zeros((img.shape[0], img.shape[1],
img.shape[2] + 2 *pad_h, img.shape[2] + 2 * pad_h,
img.shape[3] + 2 * pad_w), dtype=img.dtype) img.shape[3] + 2 * pad_w), dtype=img.dtype)
new_img[:, :, pad_h:img.shape[2]+pad_h, pad_w:img.shape[3]+pad_w] = img new_img[:, :, pad_h:img.shape[2] + pad_h, pad_w:img.shape[3] + pad_w] = img
img = new_img img = new_img
if self.subsample[0] > 1 or self.subsample[1] > 1: if self.subsample[0] > 1 or self.subsample[1] > 1:
new_shape = (topgrad.shape[0], topgrad.shape[1], new_shape = (topgrad.shape[0], topgrad.shape[1],
img.shape[2] - shape[0] + 1, img.shape[2] - shape[0] + 1,
...@@ -664,8 +661,6 @@ class AbstractConv2d_gradWeights(BaseAbstractConv2d): ...@@ -664,8 +661,6 @@ class AbstractConv2d_gradWeights(BaseAbstractConv2d):
kern = kern.transpose(1, 0, 2, 3) kern = kern.transpose(1, 0, 2, 3)
o[0] = node.outputs[0].type.filter(kern) o[0] = node.outputs[0].type.filter(kern)
def grad(self, inp, grads): def grad(self, inp, grads):
bottom, top = inp[:2] bottom, top = inp[:2]
weights, = grads weights, = grads
...@@ -764,7 +759,7 @@ class AbstractConv2d_gradInputs(BaseAbstractConv2d): ...@@ -764,7 +759,7 @@ class AbstractConv2d_gradInputs(BaseAbstractConv2d):
if mode == "full": if mode == "full":
pad_h, pad_w = (kern.shape[2] - 1, kern.shape[3] - 1) pad_h, pad_w = (kern.shape[2] - 1, kern.shape[3] - 1)
elif mode == "half": elif mode == "half":
pad_h, pad_w = (kern.shape[2] // 2, kern.shape[3] // 2) pad_h, pad_w = (kern.shape[2] // 2, kern.shape[3] // 2)
elif isinstance(mode, tuple): elif isinstance(mode, tuple):
pad_h, pad_w = map(int, self.border_mode) pad_h, pad_w = map(int, self.border_mode)
mode = "valid" mode = "valid"
...@@ -782,7 +777,7 @@ class AbstractConv2d_gradInputs(BaseAbstractConv2d): ...@@ -782,7 +777,7 @@ class AbstractConv2d_gradInputs(BaseAbstractConv2d):
if self.filter_flip: if self.filter_flip:
img = img[:, :, ::-1, ::-1] img = img[:, :, ::-1, ::-1]
if pad_h > 0 or pad_w > 0: if pad_h > 0 or pad_w > 0:
img = img[:, :, pad_h:img.shape[2]-pad_h, pad_w:img.shape[3]-pad_w] img = img[:, :, pad_h:img.shape[2] - pad_h, pad_w:img.shape[3] - pad_w]
o[0] = node.outputs[0].type.filter(img) o[0] = node.outputs[0].type.filter(img)
def grad(self, inp, grads): def grad(self, inp, grads):
......
...@@ -363,32 +363,31 @@ class TestCpuConv2d(BaseTestConv2d): ...@@ -363,32 +363,31 @@ class TestCpuConv2d(BaseTestConv2d):
border_mode=b, border_mode=b,
filter_flip=flip) filter_flip=flip)
class TestDebugMode(BaseTestConv2d):
class TestDebugModeConv2d(BaseTestConv2d):
def setUp(self): def setUp(self):
super(TestDnnConv2d, self).setUp() super(TestDebugModeConv2d, self).setUp()
self.provide_shape = [False] self.provide_shape = [False]
self.shared = gpu_shared
def tcase(self, i, f, s, b, flip, provide_shape): def tcase(self, i, f, s, b, flip, provide_shape):
mode = "DebugMode" mode = "DebugMode"
o = self.get_output_shape(i, f, s, b) o = self.get_output_shape(i, f, s, b)
self.run_fwd(inputs_shape=i, filters_shape=f, subsample=s, self.run_fwd(inputs_shape=i, filters_shape=f, subsample=s,
verify_grad=True, mode=mode, mode_ref=mode_without_gpu, verify_grad=True, mode=mode,
device='cpu', provide_shape=provide_shape, border_mode=b, provide_shape=provide_shape, border_mode=b,
filter_flip=flip, target_op=None) filter_flip=flip, target_op=None)
self.run_gradweight(inputs_shape=i, filters_shape=f, self.run_gradweight(inputs_shape=i, filters_shape=f,
output_shape=o, subsample=s, output_shape=o, subsample=s,
verify_grad=True, mode=mode, mode_ref=mode_without_gpu, verify_grad=True, mode=mode,
device='cpu', provide_shape=provide_shape, border_mode=b, provide_shape=provide_shape, border_mode=b,
filter_flip=flip, target_op=None) filter_flip=flip, target_op=None)
self.run_gradinput(inputs_shape=i, filters_shape=f, self.run_gradinput(inputs_shape=i, filters_shape=f,
output_shape=o, subsample=s, output_shape=o, subsample=s,
verify_grad=True, mode=mode, mode_ref=mode_without_gpu, verify_grad=True, mode=mode,
device='cpu', provide_shape=provide_shape, border_mode=b, provide_shape=provide_shape, border_mode=b,
filter_flip=flip, target_op=None) filter_flip=flip, target_op=None)
class TestConvTypes(unittest.TestCase): class TestConvTypes(unittest.TestCase):
def setUp(self): def setUp(self):
self.input = tensor.ftensor4() self.input = tensor.ftensor4()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论