提交 04956d0a authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Make sure gradients have the same type as inputs

上级 44e57b3c
...@@ -80,10 +80,15 @@ class Conv3D(theano.Op): ...@@ -80,10 +80,15 @@ class Conv3D(theano.Op):
#quit(-1) #quit(-1)
#dCdH = printing.Print("dCdH = ",["shape"]) #dCdH = printing.Print("dCdH = ",["shape"])
dCdV = ConvTransp3D.convTransp3D(W, T.zeros_like(V[0,0,0,0,:]), d, dCdH, V.shape[1:4] ) # Make sure the broadcasting pattern of the gradient is the the same
# as the initial variable
dCdV = ConvTransp3D.convTransp3D(W, T.zeros_like(V[0,0,0,0,:]), d, dCdH, V.shape[1:4])
dCdV = T.patternbroadcast(dCdV, V.broadcastable)
WShape = W.shape WShape = W.shape
dCdW = ConvGrad3D.convGrad3D(V,d,WShape,dCdH) dCdW = ConvGrad3D.convGrad3D(V,d,WShape,dCdH)
dCdW = T.patternbroadcast(dCdW, W.broadcastable)
dCdb = T.sum(dCdH, axis=(0,1,2,3)) dCdb = T.sum(dCdH, axis=(0,1,2,3))
dCdb = T.patternbroadcast(dCdb, b.broadcastable)
dCdd = None #not differentiable, since d is not continuous dCdd = None #not differentiable, since d is not continuous
if 'name' in dir(dCdH) and dCdH.name is not None: if 'name' in dir(dCdH) and dCdH.name is not None:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论