Fix lifter optimization of CTC Op to return the apply node's outputs

上级 d904c2e5
...@@ -2301,7 +2301,8 @@ def local_gpu_ctc(op, context_name, inputs, outputs): ...@@ -2301,7 +2301,8 @@ def local_gpu_ctc(op, context_name, inputs, outputs):
if not config.ctc.enabled: if not config.ctc.enabled:
return return
op = GpuConnectionistTemporalClassification(compute_grad=op.compute_grad) op = GpuConnectionistTemporalClassification(compute_grad=op.compute_grad)
return list(op(*inputs)) apply_node = op.make_node(*inputs)
return apply_node.outputs
# Do not register in fast_run or fast_compile. # Do not register in fast_run or fast_compile.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论