提交 1f2a52e2 authored 作者: Eric Larsen's avatar Eric Larsen 提交者: Frederic

prise en compte troncature, extension et axe interne

上级 4308a8d1
...@@ -65,8 +65,7 @@ class Fourier(gof.Op): ...@@ -65,8 +65,7 @@ class Fourier(gof.Op):
(axis.data < 0 or axis.data > a.ndim - 1)): (axis.data < 0 or axis.data > a.ndim - 1)):
raise TypeError('%s: index of the transformed axis must be' raise TypeError('%s: index of the transformed axis must be'
' a scalar not smaller than 0 and smaller than' ' a scalar not smaller than 0 and smaller than'
' dimension of array' % self.__class__.__name__) ' dimension of array' % self.__class__.__name__)
if n is None: if n is None:
n = a.shape[axis] n = a.shape[axis]
n = tensor.as_tensor_variable(n) n = tensor.as_tensor_variable(n)
...@@ -81,7 +80,6 @@ class Fourier(gof.Op): ...@@ -81,7 +80,6 @@ class Fourier(gof.Op):
raise TypeError('%s: length of the transformed axis must be a' raise TypeError('%s: length of the transformed axis must be a'
' strictly positive scalar' ' strictly positive scalar'
% self.__class__.__name__) % self.__class__.__name__)
return gof.Apply(self, [a, n, axis], [tensor.TensorType('complex128', return gof.Apply(self, [a, n, axis], [tensor.TensorType('complex128',
a.type.broadcastable)()]) a.type.broadcastable)()])
...@@ -92,7 +90,8 @@ class Fourier(gof.Op): ...@@ -92,7 +90,8 @@ class Fourier(gof.Op):
if len(shape_a) == 1: if len(shape_a) == 1:
return [(n,)] return [(n,)]
elif isinstance(axis, tensor.TensorConstant): elif isinstance(axis, tensor.TensorConstant):
out_shape = list(shape_a[0: axis.data]) + [n] + list(shape_a[axis.data + 1:]) out_shape = list(shape_a[0: axis.data]) + [n] +\
list(shape_a[axis.data + 1:])
else: else:
l = len(shape_a) l = len(shape_a)
shape_a = tensor.stack(*shape_a) shape_a = tensor.stack(*shape_a)
...@@ -123,15 +122,33 @@ class Fourier(gof.Op): ...@@ -123,15 +122,33 @@ class Fourier(gof.Op):
' only for axis being a Theano constant' ' only for axis being a Theano constant'
% self.__class__.__name__) % self.__class__.__name__)
axis = int(axis.data) axis = int(axis.data)
# notice that the number of actual elements in wrto is independent of # notice that the number of actual elements in wrto is independent of
# possible padding or truncation: # possible padding or truncation:
ele = tensor.arange(0, tensor.shape(a)[axis], 1) elem = tensor.arange(0, tensor.shape(a)[axis], 1)
outer = tensor.outer(ele, ele) # accounts for padding:
freq = tensor.arange(0, n, 1)
outer = tensor.outer(freq, elem)
pow_outer = tensor.exp(((-2 * math.pi * 1j) * outer) / (1. * n)) pow_outer = tensor.exp(((-2 * math.pi * 1j) * outer) / (1. * n))
res = tensor.tensordot(grad, pow_outer, (axis, 0)) res = tensor.tensordot(grad, pow_outer, (axis, 0))
# This would be simpler but not implemented by theano:
# res = tensor.switch(tensor.lt(n, tensor.shape(a)[axis]),
# tensor.set_subtensor(res[...,n::], 0, False, False), res)
# Instead we resort to that to account for truncation:
flip_shape = list(numpy.arange(0, a.ndim)[::-1])
res = res.dimshuffle(flip_shape)
res = tensor.switch(tensor.lt(n, tensor.shape(a)[axis]),
tensor.set_subtensor(res[n::, ], 0, False, False), res)
res = res.dimshuffle(flip_shape)
# insures that gradient shape conforms to input shape:
out_shape = list(numpy.arange(0, axis)) + [a.ndim - 1] +\
list(numpy.arange(axis, a.ndim - 1))
res = res.dimshuffle(*out_shape)
return [res, None, None] return [res, None, None]
fft = Fourier() fft = Fourier()
...@@ -175,27 +192,27 @@ class TestFourier(utt.InferShapeTester): ...@@ -175,27 +192,27 @@ class TestFourier(utt.InferShapeTester):
self.op_class) self.op_class)
def test_gradient(self): def test_gradient(self):
def fft_test1(a): def fft_test1(a):
return self.op(a, None, None) return self.op(a, None, None)
def fft_test3(a): def fft_test2(a):
return self.op(a, None, 1) return self.op(a, None, 0)
def fft_test2(a): def fft_test3(a):
return self.op(a, 3, None) return self.op(a, 4, None)
def fft_test4(a): def fft_test4(a):
return self.op(a, 8, 1) return self.op(a, 4, 0)
pts = [numpy.random.rand(2, 5, 4, 3), pts = [numpy.random.rand(7, 2, 4, 3),
numpy.random.rand(2, 5, 4), numpy.random.rand(5, 5, 4),
numpy.random.rand(2, 5), numpy.random.rand(2, 9),
numpy.random.rand(5)] numpy.random.rand(5)]
for fft_test in [fft_test1, fft_test2, fft_test3, fft_test4]: for fft_test in [fft_test1, fft_test2, fft_test3, fft_test4]:
for pt in pts: for pt in pts:
theano.gradient.verify_grad(fft_test, [pt], theano.gradient.verify_grad(fft_test, [pt],
n_tests=1, rng=TestFourier.rng, n_tests=1, rng=TestFourier.rng,
out_type='complex64') out_type='complex64')
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论