提交 d68f53f8 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Add more precise output type info to RFFT Op

上级 e9f58c9c
......@@ -14,13 +14,13 @@ class RFFTOp(Op):
def output_type(self, inp):
# add extra dim for real/imag
return TensorType(inp.dtype, shape=(None,) * (inp.type.ndim + 1))
return TensorType(inp.dtype, shape=((None,) * inp.type.ndim) + (2,))
def make_node(self, a, s=None):
a = as_tensor_variable(a)
if a.ndim < 2:
raise TypeError(
f"{self.__class__.__name__}: input must have dimension > 2, with first dimension batches"
f"{self.__class__.__name__}: input must have dimension >= 2, with first dimension batches"
)
if s is None:
......@@ -39,9 +39,10 @@ class RFFTOp(Op):
a = inputs[0]
s = inputs[1]
# FIXME: This call is deprecated in numpy 2.0
# axis must be provided when s is not None
A = np.fft.rfftn(a, s=tuple(s))
# Format output with two extra dimensions for real and imaginary
# parts.
# Format output with two extra dimensions for real and imaginary parts.
out = np.zeros((*A.shape, 2), dtype=a.dtype)
out[..., 0], out[..., 1] = np.real(A), np.imag(A)
output_storage[0][0] = out
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论