提交 48f4472f authored 作者: Frederic's avatar Frederic

Don't hide the argument to Fourrier in fft() and simplify the infer_shape graph…

Don't hide the argument to Fourrier in fft() and simplify the infer_shape graph when the axis is a constant.
上级 1201efa1
......@@ -89,10 +89,10 @@ class Fourier(gof.Op):
shape_a = in_shapes[0]
n = node.inputs[1]
axis = node.inputs[2]
#if False and isinstance(axis, tensor.TensorConstant):
# out_shape = list(shape_a[0: axis]) + [n] + list(shape_a[axis + 1:])
if len(shape_a) == 1:
return (shape_a,)
elif isinstance(axis, tensor.TensorConstant):
out_shape = list(shape_a[0: axis.data]) + [n] + list(shape_a[axis.data + 1:])
else:
l = len(shape_a)
shape_a = tensor.stack(*shape_a)
......@@ -132,10 +132,7 @@ class Fourier(gof.Op):
res = tensor.tensordot(grad, pow_outer, (axis, 0))
return [res, None, None]
def fft(a, n=None, axis=None):
localop = Fourier()
return localop(a, n=None, axis=None)
fft = Fourier()
import numpy
......@@ -166,9 +163,16 @@ class TestFourier(utt.InferShapeTester):
[numpy.random.rand(12)],
self.op_class)
a = tensor.dmatrix()
self._compile_and_check([a], [self.op(a, 16, 1)],
[numpy.random.rand(3, 12)],
self.op_class)
for var in [self.op(a, 16, 1), self.op(a, None, 1),
self.op(a, 16, None), self.op(a, None, None)]:
self._compile_and_check([a], [var],
[numpy.random.rand(12, 4)],
self.op_class)
b = tensor.iscalar()
for var in [self.op(a, 16, b), self.op(a, None, b)]:
self._compile_and_check([a, b], [var],
[numpy.random.rand(12, 4), 0],
self.op_class)
def test_gradient(self):
def fft_test1(a):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论