Commit 1705a123 authored by Pascal Lamblin

Fix broadcastable pattern of gradient in abstract conv

Parent 2bac5a41
@@ -5,7 +5,7 @@ Define abstract conv2d interface
 import logging
 import theano
-from theano.tensor import as_tensor_variable
+from theano.tensor import as_tensor_variable, patternbroadcast
 from theano.gof import Apply, Op
@@ -314,6 +314,12 @@ class AbstractConv2d(BaseAbstractConv2d):
                                self.filter_flip)(
             bottom, top, weights.shape[-2:])
+        # Make sure that the broadcastable pattern of the inputs is used
+        # for the gradients, even if the grad opts are not able to infer
+        # that the dimensions are broadcastable.
+        d_bottom = patternbroadcast(d_bottom, bottom.broadcastable)
+        d_weights = patternbroadcast(d_weights, weights.broadcastable)
         return d_bottom, d_weights
@@ -369,6 +375,12 @@ class AbstractConv2d_gradWeights(BaseAbstractConv2d):
                                self.border_mode,
                                self.subsample,
                                self.filter_flip)(bottom, weights)
+        # Make sure that the broadcastable pattern of the inputs is used
+        # for the gradients, even if the grad opts are not able to infer
+        # that the dimensions are broadcastable.
+        d_bottom = patternbroadcast(d_bottom, bottom.broadcastable)
+        d_top = patternbroadcast(d_top, top.broadcastable)
         d_height_width = (theano.gradient.DisconnectedType()(),)
         return (d_bottom, d_top) + d_height_width
@@ -425,6 +437,12 @@ class AbstractConv2d_gradInputs(BaseAbstractConv2d):
         d_top = AbstractConv2d(self.imshp, self.kshp,
                                self.border_mode, self.subsample)(
             bottom, weights)
+        # Make sure that the broadcastable pattern of the inputs is used
+        # for the gradients, even if the grad opts are not able to infer
+        # that the dimensions are broadcastable.
+        d_weights = patternbroadcast(d_weights, weights.broadcastable)
+        d_top = patternbroadcast(d_top, top.broadcastable)
         d_height_width = (theano.gradient.DisconnectedType()(),)
         return (d_weights, d_top) + d_height_width
...
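For reference, the patternbroadcast calls added above re-impose a known broadcastable pattern on a gradient whose pattern was lost while the gradient graph was built; without them, a gradient whose type no longer matches its input (for example when the batch or channel dimension is broadcastable) can break later type checks and optimizations. The snippet below is a minimal illustrative sketch, separate from this commit, using only standard Theano tensor helpers (addbroadcast, unbroadcast, patternbroadcast); the variable names are made up for the example.

import theano.tensor as T
from theano.tensor import patternbroadcast

# A 4D input whose first (batch) dimension is known to be broadcastable.
bottom = T.addbroadcast(T.tensor4('bottom'), 0)

# Simulate a gradient expression that came back with an all-False
# broadcastable pattern, as the conv gradient Ops can produce.
d_bottom = T.unbroadcast(bottom * 2.0, 0)
assert d_bottom.broadcastable != bottom.broadcastable

# Re-apply the input's pattern so the gradient's type matches the input's.
d_bottom = patternbroadcast(d_bottom, bottom.broadcastable)
assert d_bottom.broadcastable == bottom.broadcastable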