提交 7ec0c018 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Attempt at adding support for the 'full' border mode.

上级 a50a7986
...@@ -373,13 +373,16 @@ def mult_and_reduce(input_fft_v, filters_fft_v, input_shape=None, ...@@ -373,13 +373,16 @@ def mult_and_reduce(input_fft_v, filters_fft_v, input_shape=None,
return output return output
def conv2d_fft(input, filters, image_shape=None, filter_shape=None): def conv2d_fft(input, filters, image_shape=None, filter_shape=None,
border_mode='valid'):
""" """
expects bc01 input expects bc01 input
performs a valid convolution performs a valid/full convolution
input: (b, ic, i0, i1) input: (b, ic, i0, i1)
filters: (oc, ic, f0, f1) filters: (oc, ic, f0, f1)
border_mode: 'valid' or 'full'
""" """
# use symbolic shapes to compute shape info at runtime if not specified # use symbolic shapes to compute shape info at runtime if not specified
...@@ -394,44 +397,63 @@ def conv2d_fft(input, filters, image_shape=None, filter_shape=None): ...@@ -394,44 +397,63 @@ def conv2d_fft(input, filters, image_shape=None, filter_shape=None):
# output channels, input channels, filter dim 0, filter dim 1 # output channels, input channels, filter dim 0, filter dim 1
oc, ic_, f0, f1 = filter_shape oc, ic_, f0, f1 = filter_shape
# pad filters to input shape # pad filters/image to output shape
filters_padded = T.zeros((oc, ic, i0, i1)) if border_mode == 'valid':
filters_padded = T.set_subtensor(filters_padded[:, :, :f0, :f1], filters) o0 = i0
o1 = i1
filters_padded = T.zeros((oc, ic, o0, o1), dtype='float32')
filters_padded = T.set_subtensor(filters_padded[:, :, :f0, :f1],
filters)
input_padded = input
elif mode == 'full':
o0 = i0 + f0 - 1
o1 = i1 + f1 - 1
filters_padded = T.zeros((oc, ic, o0, o1), dtype='float32')
filters_padded = T.set_subtensor(filters_padded[:, :, :f0, :f1],
filters)
input_padded = T.zeros((oc, ic, o0, o1), dtype='float32')
input_padded = T.set_subtensor(input_padded[:, :, :i0, :i1],
input)
else:
raise ValueError('invalid mode')
# reshape for FFT # reshape for FFT
input_flat = input.reshape((b * ic, i0, i1)) input_flat = input_padded.reshape((b * ic, o0, o1))
filters_flat = filters_padded.reshape((oc * ic, i0, i1)) filters_flat = filters_padded.reshape((oc * ic, o0, o1))
# perform FFT # perform FFT
input_fft_flat = cufft(input_flat) # (b * ic, i0, i1//2 + 1, 2) input_fft_flat = cufft(input_flat) # (b * ic, o0, o1//2 + 1, 2)
filters_fft_flat = cufft(filters_flat) # (oc * ic, i0, i1//2 + 1, 2) filters_fft_flat = cufft(filters_flat) # (oc * ic, o0, o1//2 + 1, 2)
# unfold ic dimension # unfold ic dimension
input_fft_v_shape = (b, ic, i0, i1 // 2 + 1, 2) input_fft_v_shape = (b, ic, o0, o1 // 2 + 1, 2)
filters_fft_v_shape = (oc, ic, i0, i1 // 2 + 1, 2) filters_fft_v_shape = (oc, ic, o0, o1 // 2 + 1, 2)
input_fft_v = input_fft_flat.reshape(input_fft_v_shape) input_fft_v = input_fft_flat.reshape(input_fft_v_shape)
filters_fft_v = filters_fft_flat.reshape(filters_fft_v_shape) filters_fft_v = filters_fft_flat.reshape(filters_fft_v_shape)
# (b, oc, i0, i1//2 + 1, 2) # (b, oc, o0, o1//2 + 1, 2)
output_fft_s = mult_and_reduce(input_fft_v, filters_fft_v, output_fft_s = mult_and_reduce(input_fft_v, filters_fft_v,
input_shape=input_fft_v_shape, input_shape=input_fft_v_shape,
filter_shape=filters_fft_v_shape) filter_shape=filters_fft_v_shape)
# reshape for IFFT # reshape for IFFT
output_fft_flat = output_fft_s.reshape((b * oc, i0, i1 // 2 + 1, 2)) output_fft_flat = output_fft_s.reshape((b * oc, o0, o1 // 2 + 1, 2))
# perform IFFT # perform IFFT
output_flat = cuifft(output_fft_flat) # (b * oc, i0, i1) output_flat = cuifft(output_fft_flat) # (b * oc, o0, o1)
# reshape # reshape
output_circ = output_flat.reshape((b, oc, i0, i1)) # circular! output_circ = output_flat.reshape((b, oc, o0, o1)) # circular!
# slice because the convolution was circular, we need it to be valid # slice because the convolution was circular, we need it to be valid
output = output_circ[:, :, f0 - 1:, f1 - 1:] if border_mode == 'valid':
output = output_circ[:, :, f0 - 1:, f1 - 1:]
else:
output = output_circ
# rescale manually # rescale manually
output = (1.0 / T.cast(i0 * i1, theano.config.floatX)) * output output = (1.0 / T.cast(o0 * o1, 'float32')) * output
# output should now be the result of a batched valid convolution # output should now be the result of a batched valid convolution
# of the input with the filters. # of the input with the filters.
return output return basic_ops.as_cuda_ndarray_variable(output)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论