提交 cd049bb6 作者: James Bergstra

Updated ConvOp grad to use nicer dimshuffle notation

上级 9db4b36a
...@@ -281,8 +281,8 @@ class ConvOp(Op): ...@@ -281,8 +281,8 @@ class ConvOp(Op):
####### Determine gradient on kernels ######## ####### Determine gradient on kernels ########
assert inputs.ndim==4 and kerns.ndim==4 assert inputs.ndim==4 and kerns.ndim==4
newin = tensor.DimShuffle(inputs.broadcastable, (1,0,2,3))(inputs) newin = inputs.dimshuffle((1,0,2,3))
newgz = tensor.DimShuffle(gz.broadcastable, (1,0,2,3))(gz) newgz = gz.dimshuffle((1,0,2,3))
if self.out_mode == 'valid': if self.out_mode == 'valid':
(img, filters) = (newin, newgz) (img, filters) = (newin, newgz)
...@@ -336,12 +336,12 @@ class ConvOp(Op): ...@@ -336,12 +336,12 @@ class ConvOp(Op):
assert (dw.owner.op.outshp==self.kshp).all() assert (dw.owner.op.outshp==self.kshp).all()
if self.out_mode == 'valid': if self.out_mode == 'valid':
# before DimShuffle, dw is of shape visdim x nkern x kshp[0] x kshp[1] # before DimShuffle, dw is of shape visdim x nkern x kshp[0] x kshp[1]
dw = tensor.DimShuffle(dw.broadcastable, (1,0,2,3))(dw) dw = dw.dimshuffle((1,0,2,3))
dw = dw[:,:,::-1,::-1] dw = dw[:,:,::-1,::-1]
####### Determine gradient on inputs ######## ####### Determine gradient on inputs ########
mode = 'valid' if self.out_mode == 'full' else 'full' mode = 'valid' if self.out_mode == 'full' else 'full'
filters = tensor.DimShuffle(kerns.broadcastable, (1,0,2,3))(kerns) filters = kerns.dimshuffle((1,0,2,3))
filters = filters[:,:,::-1,::-1] filters = filters[:,:,::-1,::-1]
nkern = self.imshp[0] nkern = self.imshp[0]
imshp = (self.nkern, self.outshp[0], self.outshp[1]) imshp = (self.nkern, self.outshp[0], self.outshp[1])
......
Markdown 格式
0%
您将要添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论