提交 c52b9396 authored 作者: Frederic Bastien's avatar Frederic Bastien

Don't add twice the Assert about shape in AbstractConv

上级 f1fdacc3
......@@ -1680,19 +1680,20 @@ class AbstractConv2d(AbstractConv):
def grad(self, inp, grads):
bottom, weights = inp
top, = grads
# Don't add the assert again, as it was already added in the forward.
d_bottom = AbstractConv2d_gradInputs(self.imshp, self.kshp,
self.border_mode,
self.subsample,
self.filter_flip,
self.filter_dilation)(
weights, top, bottom.shape[-2:])
weights, top, bottom.shape[-2:], add_assert_shape=False)
d_weights = AbstractConv2d_gradWeights(self.imshp, self.kshp,
self.border_mode,
self.subsample,
self.filter_flip,
self.filter_dilation)(
bottom, top, weights.shape[-2:])
bottom, top, weights.shape[-2:], add_assert_shape=False)
# Make sure that the broadcastable pattern of the inputs is used
# for the gradients, even if the grad opts are not able to infer
......@@ -1781,7 +1782,7 @@ class AbstractConv_gradWeights(BaseAbstractConv):
filter_dilation=filter_dilation)
# Update shape/height_width
def make_node(self, img, topgrad, shape):
def make_node(self, img, topgrad, shape, add_assert_shape=True):
# Make sure both inputs are Variables with the same Type
if not isinstance(img, theano.Variable):
img = as_tensor_variable(img)
......@@ -1795,10 +1796,10 @@ class AbstractConv_gradWeights(BaseAbstractConv):
raise TypeError('img must be %dD tensor' % (2 + self.convdim))
if topgrad.type.ndim != 2 + self.convdim:
raise TypeError('topgrad must be %dD tensor' % (2 + self.convdim))
img = assert_shape(img, self.imshp,
'AbstractConv_gradWeights shape mismatch: shape of '
'image does not match given imshp.')
if add_assert_shape:
img = assert_shape(img, self.imshp,
'AbstractConv_gradWeights shape mismatch: shape of '
'image does not match given imshp.')
shape = as_tensor_variable(shape)
broadcastable = [topgrad.broadcastable[1],
......@@ -2020,7 +2021,7 @@ class AbstractConv_gradInputs(BaseAbstractConv):
filter_dilation=filter_dilation)
# Update shape/height_width
def make_node(self, kern, topgrad, shape):
def make_node(self, kern, topgrad, shape, add_assert_shape=True):
# Make sure both inputs are Variables with the same Type
if not isinstance(kern, theano.Variable):
kern = as_tensor_variable(kern)
......@@ -2035,9 +2036,10 @@ class AbstractConv_gradInputs(BaseAbstractConv):
if topgrad.type.ndim != 2 + self.convdim:
raise TypeError('topgrad must be %dD tensor' % (2 + self.convdim))
kern = assert_shape(kern, self.kshp,
'AbstractConv_gradInputs shape mismatch: shape of '
'filters does not match given kshp.')
if add_assert_shape:
kern = assert_shape(kern, self.kshp,
'AbstractConv_gradInputs shape mismatch: shape of '
'filters does not match given kshp.')
shape = as_tensor_variable(shape)
broadcastable = [topgrad.type.broadcastable[0],
......@@ -2158,8 +2160,9 @@ class AbstractConv2d_gradInputs(AbstractConv_gradInputs):
self.border_mode,
self.subsample,
self.filter_flip,
self.filter_dilation)(bottom, top,
weights.shape[-2:])
self.filter_dilation)(
bottom, top,
weights.shape[-2:])
d_top = AbstractConv2d(self.imshp, self.kshp,
self.border_mode,
self.subsample,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论