Remove decorator register_stabilize from CTC optimization to disable gradients

上级 811ba84b
...@@ -165,7 +165,6 @@ def gpu_ctc(activations, labels, input_lengths): ...@@ -165,7 +165,6 @@ def gpu_ctc(activations, labels, input_lengths):
# Disable gradient computation if not needed # Disable gradient computation if not needed
@register_canonicalize @register_canonicalize
@register_stabilize
@local_optimizer([GpuConnectionistTemporalClassification]) @local_optimizer([GpuConnectionistTemporalClassification])
def local_gpu_ctc_no_grad(node): def local_gpu_ctc_no_grad(node):
if isinstance(node.op, GpuConnectionistTemporalClassification): if isinstance(node.op, GpuConnectionistTemporalClassification):
......
...@@ -227,7 +227,6 @@ def ctc(activations, labels, input_lengths): ...@@ -227,7 +227,6 @@ def ctc(activations, labels, input_lengths):
# Disable gradient computation if not needed # Disable gradient computation if not needed
@register_canonicalize @register_canonicalize
@register_stabilize
@local_optimizer([ConnectionistTemporalClassification]) @local_optimizer([ConnectionistTemporalClassification])
def local_ctc_no_grad(node): def local_ctc_no_grad(node):
if isinstance(node.op, ConnectionistTemporalClassification): if isinstance(node.op, ConnectionistTemporalClassification):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论