提交 750f6c02 authored 作者: Guillaume Alain's avatar Guillaume Alain 提交者: Arnaud Bergeron

Added support for images whose last dimension is odd. This is more expensive…

Added support for images whose last dimension is odd. This is more expensive computationally, and it costs more in memory, but it's optional.
上级 b93cc86d
......@@ -374,7 +374,7 @@ def mult_and_reduce(input_fft_v, filters_fft_v, input_shape=None,
def conv2d_fft(input, filters, image_shape=None, filter_shape=None,
border_mode='valid'):
border_mode='valid', want_padding_on_last_input_dimension=False):
"""
expects bc01 input
performs a valid/full convolution
......@@ -383,6 +383,14 @@ def conv2d_fft(input, filters, image_shape=None, filter_shape=None,
filters: (oc, ic, f0, f1)
border_mode: 'valid' of 'full'
want_padding_on_last_input_dimension: This code does not support
images for which the last dimension (the "width") is odd. To support
this, you can either pad your images on your own, or call this function
with the `want_padding_on_last_input_dimension` flag set to `True`.
This introduces an extra copying step and consumes memory.
The return value will still be of the appropriate shape because
the padding is trimmed right before the output is returned.
"""
# use symbolic shapes to compute shape info at runtime if not specified
......@@ -400,11 +408,19 @@ def conv2d_fft(input, filters, image_shape=None, filter_shape=None,
# pad filters/image to output shape
if border_mode == 'valid':
o0 = i0
o1 = i1
if want_padding_on_last_input_dimension:
o1 = i1 + 1
input_padded = T.zeros((b, ic, o0, o1), dtype='float32')
input_padded = T.set_subtensor(input_padded[:, :, :i0, :i1],
input)
else:
o1 = i1
input_padded = input
filters_padded = T.zeros((oc, ic, o0, o1), dtype='float32')
filters_padded = T.set_subtensor(filters_padded[:, :, :f0, :f1],
filters)
input_padded = input
elif border_mode == 'full':
o0 = i0 + f0 - 1
o1 = i1 + f1 - 1
......@@ -447,7 +463,10 @@ def conv2d_fft(input, filters, image_shape=None, filter_shape=None,
# slice because the convolution was circular, we need it to be valid
if border_mode == 'valid':
output = output_circ[:, :, f0 - 1:, f1 - 1:]
if want_padding_on_last_input_dimension:
output = output_circ[:, :, f0 - 1:, f1 - 1:(o1-1)]
else:
output = output_circ[:, :, f0 - 1:, f1 - 1:o1]
else:
output = output_circ
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论