提交 eaef458c authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Use the conv3D implementation for grad of conv

When using "valid" convolutions with subsampling, the gradient implementation was wrong (if dx or dy > 2) or slow. In that case, we now use the implementation in conv3D instead.
上级 04956d0a
......@@ -716,6 +716,26 @@ class ConvOp(Op):
if self.imshp != self.imshp_logical or self.kshp != self.kshp_logical:
raise NotImplementedError('todo')
if self.out_mode == 'valid' and (self.dx, self.dy) != (1, 1):
# Use the gradient as defined in conv3D, because the implementation
# by Conv is slow (about 3x slower than conv3D, and probably 10x
# slower than it could be), and incorrect when dx or dy > 2.
# build a "node", that should be equivalent to the one given by
# self.make_node, but using conv3D instead of self.
tmp_node = theano.tensor.nnet.conv3D(
V=inputs.dimshuffle(0, 2, 3, 'x', 1),
W=kerns[:, :, ::-1, ::-1].dimshuffle(0, 2, 3, 'x', 1),
b=theano.tensor.alloc(numpy.asarray(0, dtype=kerns.dtype), kerns.shape[0]),
d=(self.dx, self.dy, 1))
node = theano.tensor.addbroadcast(tmp_node, 3).dimshuffle(0, 4, 1, 2)
# mimic what happens inside theano.grad: get the input gradient
# of the final cost wrt all variables involved.
tmp_gmap = theano.gradient.grad_sources_inputs([(node, gz)], [inputs, kerns])
return [tmp_gmap[inputs], tmp_gmap[kerns]]
if self.dx not in (1, 2) or self.dy not in (1, 2):
raise NotImplementedError("ERROR: We disable ConvOp.grad now when dx or "\
"dy are different from 1 and 2, as there is a bug in it.")
......
......@@ -236,10 +236,11 @@ class TestConv2D(unittest.TestCase):
self.validate((3, 2, 7, 5), (5, 2, 2, 3), 'valid', subsample=(2, 2))
self.validate((3, 2, 7, 5), (5, 2, 2, 3), 'full', subsample=(2, 2))
self.validate((3, 2, 7, 5), (5, 2, 2, 3), 'valid', subsample=(2, 1))
self.validate((1, 1, 6, 6), (1, 1, 3, 3), 'valid', subsample=(3, 3))
# Fails as of 2012-04-12
# Fails as of 2012-07-11
self.assertRaises(NotImplementedError, self.validate, (1, 1, 6, 6),
(1, 1, 3, 3), 'valid', subsample=(3, 3))
(1, 1, 3, 3), 'full', subsample=(3, 3))
def test_shape_Constant_tensor(self):
"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论