提交 c52b9396 作者: Frederic Bastien

Don't add the shape Assert twice in AbstractConv

上级 f1fdacc3
...@@ -1680,19 +1680,20 @@ class AbstractConv2d(AbstractConv): ...@@ -1680,19 +1680,20 @@ class AbstractConv2d(AbstractConv):
def grad(self, inp, grads): def grad(self, inp, grads):
bottom, weights = inp bottom, weights = inp
top, = grads top, = grads
# Don't add the assert again, as it was already added in the forward pass.
d_bottom = AbstractConv2d_gradInputs(self.imshp, self.kshp, d_bottom = AbstractConv2d_gradInputs(self.imshp, self.kshp,
self.border_mode, self.border_mode,
self.subsample, self.subsample,
self.filter_flip, self.filter_flip,
self.filter_dilation)( self.filter_dilation)(
weights, top, bottom.shape[-2:]) weights, top, bottom.shape[-2:], add_assert_shape=False)
d_weights = AbstractConv2d_gradWeights(self.imshp, self.kshp, d_weights = AbstractConv2d_gradWeights(self.imshp, self.kshp,
self.border_mode, self.border_mode,
self.subsample, self.subsample,
self.filter_flip, self.filter_flip,
self.filter_dilation)( self.filter_dilation)(
bottom, top, weights.shape[-2:]) bottom, top, weights.shape[-2:], add_assert_shape=False)
# Make sure that the broadcastable pattern of the inputs is used # Make sure that the broadcastable pattern of the inputs is used
# for the gradients, even if the grad opts are not able to infer # for the gradients, even if the grad opts are not able to infer
...@@ -1781,7 +1782,7 @@ class AbstractConv_gradWeights(BaseAbstractConv): ...@@ -1781,7 +1782,7 @@ class AbstractConv_gradWeights(BaseAbstractConv):
filter_dilation=filter_dilation) filter_dilation=filter_dilation)
# Update shape/height_width # Update shape/height_width
def make_node(self, img, topgrad, shape): def make_node(self, img, topgrad, shape, add_assert_shape=True):
# Make sure both inputs are Variables with the same Type # Make sure both inputs are Variables with the same Type
if not isinstance(img, theano.Variable): if not isinstance(img, theano.Variable):
img = as_tensor_variable(img) img = as_tensor_variable(img)
...@@ -1795,10 +1796,10 @@ class AbstractConv_gradWeights(BaseAbstractConv): ...@@ -1795,10 +1796,10 @@ class AbstractConv_gradWeights(BaseAbstractConv):
raise TypeError('img must be %dD tensor' % (2 + self.convdim)) raise TypeError('img must be %dD tensor' % (2 + self.convdim))
if topgrad.type.ndim != 2 + self.convdim: if topgrad.type.ndim != 2 + self.convdim:
raise TypeError('topgrad must be %dD tensor' % (2 + self.convdim)) raise TypeError('topgrad must be %dD tensor' % (2 + self.convdim))
if add_assert_shape:
img = assert_shape(img, self.imshp, img = assert_shape(img, self.imshp,
'AbstractConv_gradWeights shape mismatch: shape of ' 'AbstractConv_gradWeights shape mismatch: shape of '
'image does not match given imshp.') 'image does not match given imshp.')
shape = as_tensor_variable(shape) shape = as_tensor_variable(shape)
broadcastable = [topgrad.broadcastable[1], broadcastable = [topgrad.broadcastable[1],
...@@ -2020,7 +2021,7 @@ class AbstractConv_gradInputs(BaseAbstractConv): ...@@ -2020,7 +2021,7 @@ class AbstractConv_gradInputs(BaseAbstractConv):
filter_dilation=filter_dilation) filter_dilation=filter_dilation)
# Update shape/height_width # Update shape/height_width
def make_node(self, kern, topgrad, shape): def make_node(self, kern, topgrad, shape, add_assert_shape=True):
# Make sure both inputs are Variables with the same Type # Make sure both inputs are Variables with the same Type
if not isinstance(kern, theano.Variable): if not isinstance(kern, theano.Variable):
kern = as_tensor_variable(kern) kern = as_tensor_variable(kern)
...@@ -2035,9 +2036,10 @@ class AbstractConv_gradInputs(BaseAbstractConv): ...@@ -2035,9 +2036,10 @@ class AbstractConv_gradInputs(BaseAbstractConv):
if topgrad.type.ndim != 2 + self.convdim: if topgrad.type.ndim != 2 + self.convdim:
raise TypeError('topgrad must be %dD tensor' % (2 + self.convdim)) raise TypeError('topgrad must be %dD tensor' % (2 + self.convdim))
kern = assert_shape(kern, self.kshp, if add_assert_shape:
'AbstractConv_gradInputs shape mismatch: shape of ' kern = assert_shape(kern, self.kshp,
'filters does not match given kshp.') 'AbstractConv_gradInputs shape mismatch: shape of '
'filters does not match given kshp.')
shape = as_tensor_variable(shape) shape = as_tensor_variable(shape)
broadcastable = [topgrad.type.broadcastable[0], broadcastable = [topgrad.type.broadcastable[0],
...@@ -2158,8 +2160,9 @@ class AbstractConv2d_gradInputs(AbstractConv_gradInputs): ...@@ -2158,8 +2160,9 @@ class AbstractConv2d_gradInputs(AbstractConv_gradInputs):
self.border_mode, self.border_mode,
self.subsample, self.subsample,
self.filter_flip, self.filter_flip,
self.filter_dilation)(bottom, top, self.filter_dilation)(
weights.shape[-2:]) bottom, top,
weights.shape[-2:])
d_top = AbstractConv2d(self.imshp, self.kshp, d_top = AbstractConv2d(self.imshp, self.kshp,
self.border_mode, self.border_mode,
self.subsample, self.subsample,
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论