提交 f44963cb authored 作者: Frederic's avatar Frederic

fix infer_shape in Fourrier when the parameter n is specified.

上级 48f4472f
...@@ -90,7 +90,7 @@ class Fourier(gof.Op): ...@@ -90,7 +90,7 @@ class Fourier(gof.Op):
n = node.inputs[1] n = node.inputs[1]
axis = node.inputs[2] axis = node.inputs[2]
if len(shape_a) == 1: if len(shape_a) == 1:
return (shape_a,) return [(n,)]
elif isinstance(axis, tensor.TensorConstant): elif isinstance(axis, tensor.TensorConstant):
out_shape = list(shape_a[0: axis.data]) + [n] + list(shape_a[axis.data + 1:]) out_shape = list(shape_a[0: axis.data]) + [n] + list(shape_a[axis.data + 1:])
else: else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论