提交 5f0572d8 authored 作者: slefrancois's avatar slefrancois

edit doc, make tests compatible with numpy<1.10

上级 6478aecf
......@@ -8,6 +8,11 @@ Performs Fast Fourier Transforms (FFT) on the GPU.
FFT gradients are implemented as the opposite Fourier transform of the output gradients.
.. note ::
You must install `scikit-cuda <http://scikit-cuda.readthedocs.io/en/latest>`_
to compute Fourier transforms on the GPU.
.. warning ::
The real and imaginary parts of the Fourier domain arrays are stored as a pair of float32
arrays, emulating complex64. Since theano has limited support for complex
......
......@@ -123,10 +123,9 @@ class TestFFT(unittest.TestCase):
res_rfft_comp = (np.asarray(res_rfft[:, :, :, 0]) +
1j * np.asarray(res_rfft[:, :, :, 1]))
rfft_ref_ortho = numpy.fft.rfftn(inputs_val, axes=(1, 2), norm='ortho')
rfft_ref = numpy.fft.rfftn(inputs_val, axes=(1, 2))
utt.assert_allclose(rfft_ref_ortho, res_rfft_comp,
atol=1e-4, rtol=1e-4)
utt.assert_allclose(rfft_ref / N, res_rfft_comp, atol=1e-4, rtol=1e-4)
# No normalization
rfft = theano.gpuarray.fft.curfft(inputs, norm='no_norm')
......@@ -135,8 +134,7 @@ class TestFFT(unittest.TestCase):
res_rfft_comp = (np.asarray(res_rfft[:, :, :, 0]) +
1j * np.asarray(res_rfft[:, :, :, 1]))
utt.assert_allclose(rfft_ref_ortho * np.sqrt(N * N),
res_rfft_comp, atol=1e-4, rtol=1e-4)
utt.assert_allclose(rfft_ref, res_rfft_comp, atol=1e-4, rtol=1e-4)
# Inverse FFT inputs
inputs_val = np.random.random((1, N, N // 2 + 1, 2)).astype('float32')
......@@ -148,19 +146,16 @@ class TestFFT(unittest.TestCase):
f_irfft = theano.function([], irfft, mode=mode_with_gpu)
res_irfft = f_irfft()
irfft_ref_ortho = numpy.fft.irfftn(
inputs_ref, axes=(1, 2), norm='ortho')
irfft_ref = numpy.fft.irfftn(inputs_ref, axes=(1, 2))
utt.assert_allclose(irfft_ref_ortho,
res_irfft, atol=1e-4, rtol=1e-4)
utt.assert_allclose(irfft_ref * N, res_irfft, atol=1e-4, rtol=1e-4)
# No normalization inverse FFT
irfft = theano.gpuarray.fft.cuirfft(inputs, norm='no_norm')
f_irfft = theano.function([], irfft, mode=mode_with_gpu)
res_irfft = f_irfft()
utt.assert_allclose(irfft_ref_ortho * np.sqrt(N * N),
res_irfft, atol=1e-4, rtol=1e-4)
utt.assert_allclose(irfft_ref * N**2, res_irfft, atol=1e-4, rtol=1e-4)
def test_grad(self):
# The numerical gradient of the FFT is sensitive, must set large
......
......@@ -114,10 +114,9 @@ class TestFFT(unittest.TestCase):
res_rfft_comp = (numpy.asarray(res_rfft[:, :, :, 0]) +
1j * numpy.asarray(res_rfft[:, :, :, 1]))
rfft_ref_ortho = numpy.fft.rfftn(inputs_val, axes=(1, 2), norm='ortho')
rfft_ref = numpy.fft.rfftn(inputs_val, axes=(1, 2))
utt.assert_allclose(rfft_ref_ortho, res_rfft_comp,
atol=1e-4, rtol=1e-4)
utt.assert_allclose(rfft_ref / N, res_rfft_comp, atol=1e-4, rtol=1e-4)
# No normalization
rfft = fft.rfft(inputs, norm='no_norm')
......@@ -126,8 +125,7 @@ class TestFFT(unittest.TestCase):
res_rfft_comp = (numpy.asarray(res_rfft[:, :, :, 0]) +
1j * numpy.asarray(res_rfft[:, :, :, 1]))
utt.assert_allclose(rfft_ref_ortho * numpy.sqrt(N * N),
res_rfft_comp, atol=1e-4, rtol=1e-4)
utt.assert_allclose(rfft_ref, res_rfft_comp, atol=1e-4, rtol=1e-4)
# Inverse FFT inputs
inputs_val = numpy.random.random((1, N, N // 2 + 1, 2))
......@@ -139,18 +137,16 @@ class TestFFT(unittest.TestCase):
f_irfft = theano.function([], irfft)
res_irfft = f_irfft()
irfft_ref_ortho = numpy.fft.irfftn(inputs_ref, axes=(1, 2), norm='ortho')
irfft_ref = numpy.fft.irfftn(inputs_ref, axes=(1, 2))
utt.assert_allclose(irfft_ref_ortho,
res_irfft, atol=1e-4, rtol=1e-4)
utt.assert_allclose(irfft_ref * N, res_irfft, atol=1e-4, rtol=1e-4)
# No normalization inverse FFT
irfft = fft.irfft(inputs, norm='no_norm')
f_irfft = theano.function([], irfft)
res_irfft = f_irfft()
utt.assert_allclose(irfft_ref_ortho * numpy.sqrt(N * N),
res_irfft, atol=1e-4, rtol=1e-4)
utt.assert_allclose(irfft_ref * N**2, res_irfft, atol=1e-4, rtol=1e-4)
def test_params(self):
inputs_val = numpy.random.random((1, N))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论