提交 d68f53f8 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Add more precise output type info to RFFT Op

上级 e9f58c9c
...@@ -14,13 +14,13 @@ class RFFTOp(Op): ...@@ -14,13 +14,13 @@ class RFFTOp(Op):
def output_type(self, inp): def output_type(self, inp):
# add extra dim for real/imag # add extra dim for real/imag
return TensorType(inp.dtype, shape=(None,) * (inp.type.ndim + 1)) return TensorType(inp.dtype, shape=((None,) * inp.type.ndim) + (2,))
def make_node(self, a, s=None): def make_node(self, a, s=None):
a = as_tensor_variable(a) a = as_tensor_variable(a)
if a.ndim < 2: if a.ndim < 2:
raise TypeError( raise TypeError(
f"{self.__class__.__name__}: input must have dimension > 2, with first dimension batches" f"{self.__class__.__name__}: input must have dimension >= 2, with first dimension batches"
) )
if s is None: if s is None:
...@@ -39,9 +39,10 @@ class RFFTOp(Op): ...@@ -39,9 +39,10 @@ class RFFTOp(Op):
a = inputs[0] a = inputs[0]
s = inputs[1] s = inputs[1]
# FIXME: This call is deprecated in numpy 2.0
# axis must be provided when s is not None
A = np.fft.rfftn(a, s=tuple(s)) A = np.fft.rfftn(a, s=tuple(s))
# Format output with two extra dimensions for real and imaginary # Format output with two extra dimensions for real and imaginary parts.
# parts.
out = np.zeros((*A.shape, 2), dtype=a.dtype) out = np.zeros((*A.shape, 2), dtype=a.dtype)
out[..., 0], out[..., 1] = np.real(A), np.imag(A) out[..., 0], out[..., 1] = np.real(A), np.imag(A)
output_storage[0][0] = out output_storage[0][0] = out
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论