提交 7f1c3677 authored 作者: Gijs van Tulder's avatar Gijs van Tulder

Remove duplicate assert_shape from helper functions.

上级 e0cdd248
......@@ -523,14 +523,6 @@ def conv2d(input,
input = as_tensor_variable(input)
filters = as_tensor_variable(filters)
input = assert_shape(input, input_shape,
'conv2d shape mismatch: shape of '
'input does not match given input_shape.')
filters = assert_shape(filters, filter_shape,
'conv2d shape mismatch: shape of '
'filters does not match given filter_shape.')
conv_op = AbstractConv2d(imshp=input_shape,
kshp=filter_shape,
border_mode=border_mode,
......@@ -630,14 +622,6 @@ def conv3d(input,
input = as_tensor_variable(input)
filters = as_tensor_variable(filters)
input = assert_shape(input, input_shape,
'conv3d shape mismatch: shape of '
'input does not match given input_shape.')
filters = assert_shape(filters, filter_shape,
'conv3d shape mismatch: shape of '
'filters does not match given filter_shape.')
conv_op = AbstractConv3d(imshp=input_shape,
kshp=filter_shape,
border_mode=border_mode,
......@@ -764,9 +748,6 @@ def conv2d_grad_wrt_inputs(output_grad,
for dim in [0, 1, 2, 3]:
assert isinstance(filter_shape[dim], (theano.tensor.TensorConstant,
integer_types, type(None)))
filters = assert_shape(filters, filter_shape,
'conv2d_grad_wrt_inputs shape mismatch: shape of '
'filters does not match given filter_shape.')
# setting the last two dimensions of input_shape to None, if
# the type of these dimensions is TensorVariable.
......@@ -902,9 +883,6 @@ def conv3d_grad_wrt_inputs(output_grad,
for dim in [0, 1, 2, 3, 4]:
assert isinstance(filter_shape[dim], (theano.tensor.TensorConstant,
integer_types, type(None)))
filters = assert_shape(filters, filter_shape,
'conv3d_grad_wrt_inputs shape mismatch: shape of '
'filters does not match given filter_shape.')
# setting the last three dimensions of input_shape to None, if
# the type of these dimensions is TensorVariable.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论