提交 e685f292 authored 作者: Gijs van Tulder's avatar Gijs van Tulder

Also check shapes inside the ops.

上级 5caafd99
...@@ -523,14 +523,14 @@ def conv2d(input, ...@@ -523,14 +523,14 @@ def conv2d(input,
input = as_tensor_variable(input) input = as_tensor_variable(input)
filters = as_tensor_variable(filters) filters = as_tensor_variable(filters)
if input_shape is not None:
input = assert_shape(input, input_shape, input = assert_shape(input, input_shape,
'conv2d shape mismatch: shape of ' 'conv2d shape mismatch: shape of '
'input does not match given input_shape.') 'input does not match given input_shape.')
if filter_shape is not None: filters = assert_shape(filters, filter_shape,
filters = assert_shape(filters, filter_shape, 'conv2d shape mismatch: shape of '
'conv2d shape mismatch: shape of ' 'filters does not match given filter_shape.')
'filters does not match given filter_shape.')
conv_op = AbstractConv2d(imshp=input_shape, conv_op = AbstractConv2d(imshp=input_shape,
kshp=filter_shape, kshp=filter_shape,
border_mode=border_mode, border_mode=border_mode,
...@@ -630,14 +630,14 @@ def conv3d(input, ...@@ -630,14 +630,14 @@ def conv3d(input,
input = as_tensor_variable(input) input = as_tensor_variable(input)
filters = as_tensor_variable(filters) filters = as_tensor_variable(filters)
if input_shape is not None:
input = assert_shape(input, input_shape, input = assert_shape(input, input_shape,
'conv3d shape mismatch: shape of ' 'conv3d shape mismatch: shape of '
'input does not match given input_shape.') 'input does not match given input_shape.')
if filter_shape is not None: filters = assert_shape(filters, filter_shape,
filters = assert_shape(filters, filter_shape, 'conv3d shape mismatch: shape of '
'conv3d shape mismatch: shape of ' 'filters does not match given filter_shape.')
'filters does not match given filter_shape.')
conv_op = AbstractConv3d(imshp=input_shape, conv_op = AbstractConv3d(imshp=input_shape,
kshp=filter_shape, kshp=filter_shape,
border_mode=border_mode, border_mode=border_mode,
...@@ -1600,6 +1600,13 @@ class AbstractConv(BaseAbstractConv): ...@@ -1600,6 +1600,13 @@ class AbstractConv(BaseAbstractConv):
if kern.type.ndim != 2 + self.convdim: if kern.type.ndim != 2 + self.convdim:
raise TypeError('kern must be %dD tensor' % (2 + self.convdim)) raise TypeError('kern must be %dD tensor' % (2 + self.convdim))
img = assert_shape(img, self.imshp,
'AbstractConv shape mismatch: shape of '
'image does not match given imshp.')
kern = assert_shape(kern, self.kshp,
'AbstractConv shape mismatch: shape of '
'filters does not match given kshp.')
broadcastable = [img.broadcastable[0], broadcastable = [img.broadcastable[0],
kern.broadcastable[0]] + ([False] * self.convdim) kern.broadcastable[0]] + ([False] * self.convdim)
output = img.type.clone(broadcastable=broadcastable)() output = img.type.clone(broadcastable=broadcastable)()
...@@ -1811,6 +1818,10 @@ class AbstractConv_gradWeights(BaseAbstractConv): ...@@ -1811,6 +1818,10 @@ class AbstractConv_gradWeights(BaseAbstractConv):
if topgrad.type.ndim != 2 + self.convdim: if topgrad.type.ndim != 2 + self.convdim:
raise TypeError('topgrad must be %dD tensor' % (2 + self.convdim)) raise TypeError('topgrad must be %dD tensor' % (2 + self.convdim))
img = assert_shape(img, self.imshp,
'AbstractConv_gradWeights shape mismatch: shape of '
'image does not match given imshp.')
shape = as_tensor_variable(shape) shape = as_tensor_variable(shape)
broadcastable = [topgrad.broadcastable[1], broadcastable = [topgrad.broadcastable[1],
img.broadcastable[1]] + ([False] * self.convdim) img.broadcastable[1]] + ([False] * self.convdim)
...@@ -2046,6 +2057,10 @@ class AbstractConv_gradInputs(BaseAbstractConv): ...@@ -2046,6 +2057,10 @@ class AbstractConv_gradInputs(BaseAbstractConv):
if topgrad.type.ndim != 2 + self.convdim: if topgrad.type.ndim != 2 + self.convdim:
raise TypeError('topgrad must be %dD tensor' % (2 + self.convdim)) raise TypeError('topgrad must be %dD tensor' % (2 + self.convdim))
kern = assert_shape(kern, self.kshp,
'AbstractConv_gradInputs shape mismatch: shape of '
'filters does not match given kshp.')
shape = as_tensor_variable(shape) shape = as_tensor_variable(shape)
broadcastable = [topgrad.type.broadcastable[0], broadcastable = [topgrad.type.broadcastable[0],
kern.type.broadcastable[1]] + ([False] * self.convdim) kern.type.broadcastable[1]] + ([False] * self.convdim)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论