提交 f477c6a1 authored 作者: goodfeli's avatar goodfeli

Merge pull request #756 from lamblin/fix_conv_subsampling

Fix conv subsampling
......@@ -80,10 +80,15 @@ class Conv3D(theano.Op):
#quit(-1)
#dCdH = printing.Print("dCdH = ",["shape"])
dCdV = ConvTransp3D.convTransp3D(W, T.zeros_like(V[0,0,0,0,:]), d, dCdH, V.shape[1:4] )
# Make sure the broadcasting pattern of the gradient is the the same
# as the initial variable
dCdV = ConvTransp3D.convTransp3D(W, T.zeros_like(V[0,0,0,0,:]), d, dCdH, V.shape[1:4])
dCdV = T.patternbroadcast(dCdV, V.broadcastable)
WShape = W.shape
dCdW = ConvGrad3D.convGrad3D(V,d,WShape,dCdH)
dCdW = T.patternbroadcast(dCdW, W.broadcastable)
dCdb = T.sum(dCdH, axis=(0,1,2,3))
dCdb = T.patternbroadcast(dCdb, b.broadcastable)
dCdd = None #not differentiable, since d is not continuous
if 'name' in dir(dCdH) and dCdH.name is not None:
......
......@@ -716,6 +716,26 @@ class ConvOp(Op):
if self.imshp != self.imshp_logical or self.kshp != self.kshp_logical:
raise NotImplementedError('todo')
if self.out_mode == 'valid' and (self.dx, self.dy) != (1, 1):
# Use the gradient as defined in conv3D, because the implementation
# by Conv is slow (about 3x slower than conv3D, and probably 10x
# slower than it could be), and incorrect when dx or dy > 2.
# build a "node", that should be equivalent to the one given by
# self.make_node, but using conv3D instead of self.
tmp_node = theano.tensor.nnet.conv3D(
V=inputs.dimshuffle(0, 2, 3, 'x', 1),
W=kerns[:, :, ::-1, ::-1].dimshuffle(0, 2, 3, 'x', 1),
b=theano.tensor.alloc(numpy.asarray(0, dtype=kerns.dtype), kerns.shape[0]),
d=(self.dx, self.dy, 1))
node = theano.tensor.addbroadcast(tmp_node, 3).dimshuffle(0, 4, 1, 2)
# mimic what happens inside theano.grad: get the input gradient
# of the final cost wrt all variables involved.
tmp_gmap = theano.gradient.grad_sources_inputs([(node, gz)], [inputs, kerns])
return [tmp_gmap[inputs], tmp_gmap[kerns]]
if self.dx not in (1, 2) or self.dy not in (1, 2):
raise NotImplementedError("ERROR: We disable ConvOp.grad now when dx or "\
"dy are different from 1 and 2, as there is a bug in it.")
......
......@@ -236,10 +236,11 @@ class TestConv2D(unittest.TestCase):
self.validate((3, 2, 7, 5), (5, 2, 2, 3), 'valid', subsample=(2, 2))
self.validate((3, 2, 7, 5), (5, 2, 2, 3), 'full', subsample=(2, 2))
self.validate((3, 2, 7, 5), (5, 2, 2, 3), 'valid', subsample=(2, 1))
self.validate((1, 1, 6, 6), (1, 1, 3, 3), 'valid', subsample=(3, 3))
# Fails as of 2012-04-12
# Fails as of 2012-07-11
self.assertRaises(NotImplementedError, self.validate, (1, 1, 6, 6),
(1, 1, 3, 3), 'valid', subsample=(3, 3))
(1, 1, 3, 3), 'full', subsample=(3, 3))
def test_shape_Constant_tensor(self):
"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论