Fix CTC lifter optimization

上级 d47162d4
......@@ -2280,13 +2280,14 @@ def local_gpu_magma_svd(op, context_name, inputs, outputs):
out = [out.astype('float16')]
return out
@register_opt('fast_compile')
@register_opt('ctc', 'fast_compile')
@op_lifter([theano.tensor.nnet.ctc.ConnectionistTemporalClassification])
@register_opt2([theano.tensor.nnet.ctc.ConnectionistTemporalClassification], 'fast_compile')
@register_opt2([theano.tensor.nnet.ctc.ConnectionistTemporalClassification], 'ctc', 'fast_compile')
def local_gpu_ctc(op, context_name, inputs, outputs):
if not config.ctc.enabled:
return
return [GpuConnectionistTemporalClassification()(*node.inputs)]
op = GpuConnectionistTemporalClassification(compute_grad=op.compute_grad)
return list(op(*inputs))
# Do not register in fast_run or fast_compile.
# It will be added to fast_run if the GPU is enabled.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论