Fix gradient disabler optimization in CTC CPU Op to avoid modifying the node inplace

上级 627b3280
...@@ -168,5 +168,5 @@ def local_ctc_no_grad(node): ...@@ -168,5 +168,5 @@ def local_ctc_no_grad(node):
if isinstance(node.op, ConnectionistTemporalClassification): if isinstance(node.op, ConnectionistTemporalClassification):
if len(node.outputs) > 1: if len(node.outputs) > 1:
if len(node.outputs[1].clients) == 0: # gradient is not used if len(node.outputs[1].clients) == 0: # gradient is not used
node.op = ConnectionistTemporalClassification(compute_grad=False) node.op.compute_grad = False
node.outputs = node.outputs[:1] # costs only return [ConnectionistTemporalClassification(compute_grad=False)(*node.inputs)]
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论