提交 5f89d1ec authored 作者: Frederic's avatar Frederic

Make Fourrier.grad work with the default inputs and test it more.

上级 0e72be06
...@@ -115,12 +115,14 @@ class Fourier(gof.Op): ...@@ -115,12 +115,14 @@ class Fourier(gof.Op):
raise NotImplementedError('%s: gradient is currently implemented' raise NotImplementedError('%s: gradient is currently implemented'
' only for axis being a Theano constant' ' only for axis being a Theano constant'
% self.__class__.__name__) % self.__class__.__name__)
axis = int(axis.data)
# notice that the number of actual elements in wrto is independent of # notice that the number of actual elements in wrto is independent of
# possible padding or truncation: # possible padding or truncation:
ele = tensor.arange(0, tensor.shape(a)[2], 1) ele = tensor.arange(0, tensor.shape(a)[axis], 1)
outer = tensor.outer(ele, ele) outer = tensor.outer(ele, ele)
pow_outer = tensor.exp(((-2 * math.pi * 1j) * outer) / (1. * n)) pow_outer = tensor.exp(((-2 * math.pi * 1j) * outer) / (1. * n))
res = tensor.tensordot(grad, pow_outer, (2, 0)) res = tensor.tensordot(grad, pow_outer, (axis, 0))
return [res, None, None] return [res, None, None]
...@@ -162,14 +164,27 @@ class TestFourier(utt.InferShapeTester): ...@@ -162,14 +164,27 @@ class TestFourier(utt.InferShapeTester):
self.op_class) self.op_class)
def test_gradient(self): def test_gradient(self):
def fft_test(a): def fft_test1(a):
return self.op(a, 16, 1) return self.op(a, 8, 1)
pt = [numpy.random.rand(2, 12, 24)] def fft_test2(a):
return self.op(a, None, 1)
theano.gradient.verify_grad(fft_test, pt, n_tests=1, rng=TestFourier.rng,
eps=None, out_type='complex64', abs_tol=None, def fft_test3(a):
rel_tol=None, mode=None, cast_to_output_type=False) return self.op(a, None, None)
def fft_test4(a):
return self.op(a, 3, None)
pts = [numpy.random.rand(2, 5, 4, 3),
numpy.random.rand(2, 5, 4),
numpy.random.rand(2, 5),
numpy.random.rand(5)]
for fft_test in [fft_test1, fft_test2, fft_test3, fft_test4]:
for pt in pts:
theano.gradient.verify_grad(fft_test, [pt],
n_tests=1, rng=TestFourier.rng,
out_type='complex64')
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论