Change CTC lifter optimization to return apply node's outputs

上级 e24fddf0
......@@ -2299,10 +2299,7 @@ def local_gpu_magma_svd(op, context_name, inputs, outputs):
@register_opt2([ConnectionistTemporalClassification], 'ctc', 'fast_compile')
def local_gpu_ctc(op, context_name, inputs, outputs):
op = GpuConnectionistTemporalClassification(compute_grad=op.compute_grad)
if op.compute_grad:
# Circumvent assert error on condition len(outputs) == len(node.outputs)
op.default_output = None
return op
return op.make_node(*inputs).outputs
# Do not register in fast_run or fast_compile.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论