提交 e685f292 authored 作者: Gijs van Tulder's avatar Gijs van Tulder

Also check shapes inside the ops.

上级 5caafd99
......@@ -523,14 +523,14 @@ def conv2d(input,
input = as_tensor_variable(input)
filters = as_tensor_variable(filters)
if input_shape is not None:
input = assert_shape(input, input_shape,
'conv2d shape mismatch: shape of '
'input does not match given input_shape.')
if filter_shape is not None:
filters = assert_shape(filters, filter_shape,
'conv2d shape mismatch: shape of '
'filters does not match given filter_shape.')
input = assert_shape(input, input_shape,
'conv2d shape mismatch: shape of '
'input does not match given input_shape.')
filters = assert_shape(filters, filter_shape,
'conv2d shape mismatch: shape of '
'filters does not match given filter_shape.')
conv_op = AbstractConv2d(imshp=input_shape,
kshp=filter_shape,
border_mode=border_mode,
......@@ -630,14 +630,14 @@ def conv3d(input,
input = as_tensor_variable(input)
filters = as_tensor_variable(filters)
if input_shape is not None:
input = assert_shape(input, input_shape,
'conv3d shape mismatch: shape of '
'input does not match given input_shape.')
if filter_shape is not None:
filters = assert_shape(filters, filter_shape,
'conv3d shape mismatch: shape of '
'filters does not match given filter_shape.')
input = assert_shape(input, input_shape,
'conv3d shape mismatch: shape of '
'input does not match given input_shape.')
filters = assert_shape(filters, filter_shape,
'conv3d shape mismatch: shape of '
'filters does not match given filter_shape.')
conv_op = AbstractConv3d(imshp=input_shape,
kshp=filter_shape,
border_mode=border_mode,
......@@ -1600,6 +1600,13 @@ class AbstractConv(BaseAbstractConv):
if kern.type.ndim != 2 + self.convdim:
raise TypeError('kern must be %dD tensor' % (2 + self.convdim))
img = assert_shape(img, self.imshp,
'AbstractConv shape mismatch: shape of '
'image does not match given imshp.')
kern = assert_shape(kern, self.kshp,
'AbstractConv shape mismatch: shape of '
'filters does not match given kshp.')
broadcastable = [img.broadcastable[0],
kern.broadcastable[0]] + ([False] * self.convdim)
output = img.type.clone(broadcastable=broadcastable)()
......@@ -1811,6 +1818,10 @@ class AbstractConv_gradWeights(BaseAbstractConv):
if topgrad.type.ndim != 2 + self.convdim:
raise TypeError('topgrad must be %dD tensor' % (2 + self.convdim))
img = assert_shape(img, self.imshp,
'AbstractConv_gradWeights shape mismatch: shape of '
'image does not match given imshp.')
shape = as_tensor_variable(shape)
broadcastable = [topgrad.broadcastable[1],
img.broadcastable[1]] + ([False] * self.convdim)
......@@ -2046,6 +2057,10 @@ class AbstractConv_gradInputs(BaseAbstractConv):
if topgrad.type.ndim != 2 + self.convdim:
raise TypeError('topgrad must be %dD tensor' % (2 + self.convdim))
kern = assert_shape(kern, self.kshp,
'AbstractConv_gradInputs shape mismatch: shape of '
'filters does not match given kshp.')
shape = as_tensor_variable(shape)
broadcastable = [topgrad.type.broadcastable[0],
kern.type.broadcastable[1]] + ([False] * self.convdim)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论