提交 cd049bb6 authored 作者: James Bergstra's avatar James Bergstra

updated convOp grad to use nicer dimshuffle notation

上级 9db4b36a
......@@ -281,8 +281,8 @@ class ConvOp(Op):
####### Determine gradient on kernels ########
assert inputs.ndim==4 and kerns.ndim==4
newin = tensor.DimShuffle(inputs.broadcastable, (1,0,2,3))(inputs)
newgz = tensor.DimShuffle(gz.broadcastable, (1,0,2,3))(gz)
newin = inputs.dimshuffle((1,0,2,3))
newgz = gz.dimshuffle((1,0,2,3))
if self.out_mode == 'valid':
(img, filters) = (newin, newgz)
......@@ -336,12 +336,12 @@ class ConvOp(Op):
assert (dw.owner.op.outshp==self.kshp).all()
if self.out_mode == 'valid':
# before DimShuffle, dw is of shape visdim x nkern x kshp[0] x kshp[1]
dw = tensor.DimShuffle(dw.broadcastable, (1,0,2,3))(dw)
dw = dw.dimshuffle((1,0,2,3))
dw = dw[:,:,::-1,::-1]
####### Determine gradient on inputs ########
mode = 'valid' if self.out_mode == 'full' else 'full'
filters = tensor.DimShuffle(kerns.broadcastable, (1,0,2,3))(kerns)
filters = kerns.dimshuffle((1,0,2,3))
filters = filters[:,:,::-1,::-1]
nkern = self.imshp[0]
imshp = (self.nkern, self.outshp[0], self.outshp[1])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论