提交 bec03eb4 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix static shapes of outputs in TopKOp

上级 eace7f68
......@@ -414,9 +414,13 @@ class TopKOp(Op):
_check_tensor_is_scalar(kth)
outs = []
if self.return_values:
outs.append(inp.type())
outs.append(
TensorType(dtype=inp.type.dtype, shape=(None,) * inp.type.ndim)()
)
if self.return_indices:
outs.append(TensorType(dtype=self.idx_dtype, shape=inp.type.shape)())
outs.append(
TensorType(dtype=self.idx_dtype, shape=(None,) * inp.type.ndim)()
)
return Apply(self, [inp, kth], outs)
def perform(self, node, inputs, output_storage):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论