提交 f44963cb authored 作者: Frederic's avatar Frederic

fix infer_shape in Fourrier when the parameter n is specified.

上级 48f4472f
......@@ -90,7 +90,7 @@ class Fourier(gof.Op):
n = node.inputs[1]
axis = node.inputs[2]
if len(shape_a) == 1:
return (shape_a,)
return [(n,)]
elif isinstance(axis, tensor.TensorConstant):
out_shape = list(shape_a[0: axis.data]) + [n] + list(shape_a[axis.data + 1:])
else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论