提交 215962f7 authored 作者: Guillaume Alain's avatar Guillaume Alain 提交者: Arnaud Bergeron

border_mode 'full' now working and tested.

上级 51831978
...@@ -418,13 +418,29 @@ def conv2d_fft(input, filters, image_shape=None, filter_shape=None, ...@@ -418,13 +418,29 @@ def conv2d_fft(input, filters, image_shape=None, filter_shape=None,
filters) filters)
elif border_mode == 'full': elif border_mode == 'full':
o0 = i0 + f0 - 1
o1 = i1 + f1 - 1 # In this particular case, the values of (o0, o1) represent
# the dimensions of the work buffer more than the actual dimensions
# of the desired output.
o0 = i0 + 2 * (f0 - 1)
o1 = i1 + 2 * (f1 - 1)
if pad_last_dim:
o1 = o1 + 1
# We line up the filters and the images in a way
# such that the filters are tightly placed against the
# top-left of the array, and the images intersect with
# them on one pixel. The top-left pixel of the images
# is the bottom-right pixel of the filters when we
# do the layout here.
filters_padded = T.zeros((oc, ic, o0, o1), dtype='float32') filters_padded = T.zeros((oc, ic, o0, o1), dtype='float32')
filters_padded = T.set_subtensor(filters_padded[:, :, :f0, :f1], filters_padded = T.set_subtensor(filters_padded[:, :, :f0, :f1],
filters) filters)
input_padded = T.zeros((b, ic, o0, o1), dtype='float32') input_padded = T.zeros((b, ic, o0, o1), dtype='float32')
input_padded = T.set_subtensor(input_padded[:, :, :i0, :i1], input_padded = T.set_subtensor(input_padded[:, :, (f0 - 1):(f0 - 1 + i0), (f1 - 1):(f1 - 1 + i1)],
input) input)
else: else:
raise ValueError('invalid mode') raise ValueError('invalid mode')
...@@ -457,16 +473,21 @@ def conv2d_fft(input, filters, image_shape=None, filter_shape=None, ...@@ -457,16 +473,21 @@ def conv2d_fft(input, filters, image_shape=None, filter_shape=None,
# reshape # reshape
output_circ = output_flat.reshape((b, oc, o0, o1)) # circular! output_circ = output_flat.reshape((b, oc, o0, o1)) # circular!
# slice because the convolution was circular, we need it to be valid # Now we extract the region of interest.
# We just cut it out from the output_circ
# array that was used for the computation.
# We do not need to handle pad_last_dim in a
# special way because we specify explicitly here
# how much values are expected.
if border_mode == 'valid': if border_mode == 'valid':
if pad_last_dim: output = output_circ[:, :, (f0-1):(f0-1 + i0-f0+1), (f1-1):(f1-1 + i1-f1+1)]
output = output_circ[:, :, f0 - 1:, f1 - 1:(o1-1)] elif border_mode == 'full':
else: output = output_circ[:, :, (f0-1):(f0-1 + i0+f0-1), (f1-1):(f1-1 + i1+f1-1)]
output = output_circ[:, :, f0 - 1:, f1 - 1:o1]
else: else:
output = output_circ raise ValueError('invalid mode')
# rescale manually # Rescale manually. This is just a factor that comes in during the
# trip through FFT and inverse FFT.
output = (1.0 / T.cast(o0 * o1, 'float32')) * output output = (1.0 / T.cast(o0 * o1, 'float32')) * output
# output should now be the result of a batched valid convolution # output should now be the result of a batched valid convolution
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论