Simplify name of GPU CTC optimization to disable gradients

上级 e4fe9d51
...@@ -171,7 +171,7 @@ def gpu_ctc(activations, labels, input_lengths): ...@@ -171,7 +171,7 @@ def gpu_ctc(activations, labels, input_lengths):
@register_canonicalize @register_canonicalize
@register_stabilize @register_stabilize
@local_optimizer([GpuConnectionistTemporalClassification]) @local_optimizer([GpuConnectionistTemporalClassification])
def local_GpuConnectionistTemporalClassification_no_grad(node): def local_gpu_ctc_no_grad(node):
if isinstance(node.op, GpuConnectionistTemporalClassification): if isinstance(node.op, GpuConnectionistTemporalClassification):
if len(node.outputs) > 1: if len(node.outputs) > 1:
if len(node.outputs[1].clients) == 0: # gradient is not used if len(node.outputs[1].clients) == 0: # gradient is not used
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论