Fix CTC lifter optimization

上级 d47162d4
@@ -2280,13 +2280,14 @@ def local_gpu_magma_svd(op, context_name, inputs, outputs):
         out = [out.astype('float16')]
     return out
 
-@register_opt('fast_compile')
+@register_opt('ctc', 'fast_compile')
 @op_lifter([theano.tensor.nnet.ctc.ConnectionistTemporalClassification])
-@register_opt2([theano.tensor.nnet.ctc.ConnectionistTemporalClassification], 'fast_compile')
+@register_opt2([theano.tensor.nnet.ctc.ConnectionistTemporalClassification], 'ctc', 'fast_compile')
 def local_gpu_ctc(op, context_name, inputs, outputs):
     if not config.ctc.enabled:
         return
-    return [GpuConnectionistTemporalClassification()(*node.inputs)]
+    op = GpuConnectionistTemporalClassification(compute_grad=op.compute_grad)
+    return list(op(*inputs))
 
 # Do not register in fast_run or fast_compile.
 # It will be added to fast_run if the GPU is enabled.
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论