conv2d_offset: avoid unnecessary join.

上级 3ebfe09b
......@@ -53,7 +53,11 @@ def conv2d(input, filters, image_shape=None, filter_shape=None,
(batch size, nb filters, output row, output col)
"""
if image_shape and filter_shape:
try:
assert image_shape[1]==filter_shape[1]
except:
print 'image ', image_shape, ' filters ', filter_shape
raise
if filter_shape is not None:
nkern = filter_shape[0]
......@@ -149,7 +153,10 @@ def conv2d_offset(input, filters, image_shape=None, filter_shape=None,
outputs.append(out)
# Join the outputs on the leading axis.
output = tensor.join(1, *outputs)
if len(outputs) > 1:
output = tensor.join(1, *outputs)
else:
output = outputs[0]
outshp = ConvOp.getOutputShape(sub_image_shape[2:], filter_shape[2:], subsample, border_mode)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论