提交 297da372 authored 作者: Gijs van Tulder's avatar Gijs van Tulder

AbstractConv_gradInputs.perform should also check the shapes.

上级 59a17284
...@@ -1966,6 +1966,21 @@ class AbstractConv_gradInputs(BaseAbstractConv): ...@@ -1966,6 +1966,21 @@ class AbstractConv_gradInputs(BaseAbstractConv):
'"valid", "full", "half", an integer or a tuple of' '"valid", "full", "half", an integer or a tuple of'
' integers'.format(mode)) ' integers'.format(mode))
imshp = self.imshp[:] if self.imshp is not None else [None] * (2 + self.convdim)
fallback_imshp = ([topgrad.shape[0], kern.shape[1]] +
[shape[i] for i in range(self.convdim)])
imshp = [fallback_imshp[i] if imshp[i] is None else imshp[i]
for i in range(2 + self.convdim)]
expected_topgrad_shape = get_conv_output_shape(
imshp, kern.shape,
self.border_mode, self.subsample, self.filter_dilation)
if not tuple(expected_topgrad_shape) == tuple(topgrad.shape):
raise ValueError(
'invalid input_shape for gradInputs: the given input_shape '
'would produce an output of shape {}, but the given topgrad '
'has shape {}'.format(tuple(expected_topgrad_shape),
tuple(topgrad.shape)))
dil_kernshp = tuple((kern.shape[i + 2] - 1) * self.filter_dilation[i] + 1 dil_kernshp = tuple((kern.shape[i + 2] - 1) * self.filter_dilation[i] + 1
for i in range(self.convdim)) for i in range(self.convdim))
pad = (0,) * self.convdim pad = (0,) * self.convdim
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论