提交 48f4472f authored 作者: Frederic's avatar Frederic

Don't hide the argument to Fourrier in fft() and simplify the infer_shape graph…

Don't hide the argument to Fourrier in fft() and simplify the infer_shape graph when the axis is a constant.
上级 1201efa1
...@@ -89,10 +89,10 @@ class Fourier(gof.Op): ...@@ -89,10 +89,10 @@ class Fourier(gof.Op):
shape_a = in_shapes[0] shape_a = in_shapes[0]
n = node.inputs[1] n = node.inputs[1]
axis = node.inputs[2] axis = node.inputs[2]
#if False and isinstance(axis, tensor.TensorConstant):
# out_shape = list(shape_a[0: axis]) + [n] + list(shape_a[axis + 1:])
if len(shape_a) == 1: if len(shape_a) == 1:
return (shape_a,) return (shape_a,)
elif isinstance(axis, tensor.TensorConstant):
out_shape = list(shape_a[0: axis.data]) + [n] + list(shape_a[axis.data + 1:])
else: else:
l = len(shape_a) l = len(shape_a)
shape_a = tensor.stack(*shape_a) shape_a = tensor.stack(*shape_a)
...@@ -132,10 +132,7 @@ class Fourier(gof.Op): ...@@ -132,10 +132,7 @@ class Fourier(gof.Op):
res = tensor.tensordot(grad, pow_outer, (axis, 0)) res = tensor.tensordot(grad, pow_outer, (axis, 0))
return [res, None, None] return [res, None, None]
fft = Fourier()
def fft(a, n=None, axis=None):
localop = Fourier()
return localop(a, n=None, axis=None)
import numpy import numpy
...@@ -166,8 +163,15 @@ class TestFourier(utt.InferShapeTester): ...@@ -166,8 +163,15 @@ class TestFourier(utt.InferShapeTester):
[numpy.random.rand(12)], [numpy.random.rand(12)],
self.op_class) self.op_class)
a = tensor.dmatrix() a = tensor.dmatrix()
self._compile_and_check([a], [self.op(a, 16, 1)], for var in [self.op(a, 16, 1), self.op(a, None, 1),
[numpy.random.rand(3, 12)], self.op(a, 16, None), self.op(a, None, None)]:
self._compile_and_check([a], [var],
[numpy.random.rand(12, 4)],
self.op_class)
b = tensor.iscalar()
for var in [self.op(a, 16, b), self.op(a, None, b)]:
self._compile_and_check([a, b], [var],
[numpy.random.rand(12, 4), 0],
self.op_class) self.op_class)
def test_gradient(self): def test_gradient(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论