Remove setting the previous node's compute_grad prop to False

上级 f08af27a
...@@ -176,6 +176,5 @@ def local_gpu_ctc_no_grad(node): ...@@ -176,6 +176,5 @@ def local_gpu_ctc_no_grad(node):
if isinstance(node.op, GpuConnectionistTemporalClassification): if isinstance(node.op, GpuConnectionistTemporalClassification):
if len(node.outputs) > 1: if len(node.outputs) > 1:
if len(node.outputs[1].clients) == 0: # gradient is not used if len(node.outputs[1].clients) == 0: # gradient is not used
node.op.compute_grad = False
return [GpuConnectionistTemporalClassification(compute_grad=False)(*node.inputs), None] return [GpuConnectionistTemporalClassification(compute_grad=False)(*node.inputs), None]
return False return False
\ No newline at end of file
...@@ -163,6 +163,5 @@ def local_ctc_no_grad(node): ...@@ -163,6 +163,5 @@ def local_ctc_no_grad(node):
if isinstance(node.op, ConnectionistTemporalClassification): if isinstance(node.op, ConnectionistTemporalClassification):
if len(node.outputs) > 1: if len(node.outputs) > 1:
if len(node.outputs[1].clients) == 0: # gradient is not used if len(node.outputs[1].clients) == 0: # gradient is not used
node.op.compute_grad = False
return [ConnectionistTemporalClassification(compute_grad=False)(*node.inputs), None] return [ConnectionistTemporalClassification(compute_grad=False)(*node.inputs), None]
return False return False
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论