Add optimization to lift CPU CTC Op to GPU

上级 f95fd562
...@@ -34,6 +34,7 @@ from theano.tensor.nnet.abstract_conv import (BaseAbstractConv, ...@@ -34,6 +34,7 @@ from theano.tensor.nnet.abstract_conv import (BaseAbstractConv,
AbstractConv3d_gradWeights, AbstractConv3d_gradWeights,
AbstractConv3d_gradInputs) AbstractConv3d_gradInputs)
from theano.tensor.nnet.neighbours import Images2Neibs from theano.tensor.nnet.neighbours import Images2Neibs
from theano.tensor.nnet.ctc import ConnectionistTemporalClassification
import theano.tensor.nlinalg as nlinalg import theano.tensor.nlinalg as nlinalg
import theano.tensor.signal.pool as pool import theano.tensor.signal.pool as pool
import theano.tensor.slinalg as slinalg import theano.tensor.slinalg as slinalg
...@@ -78,6 +79,7 @@ from .linalg import (GpuCusolverSolve, MATRIX_STRUCTURES_SOLVE, GpuCholesky, ...@@ -78,6 +79,7 @@ from .linalg import (GpuCusolverSolve, MATRIX_STRUCTURES_SOLVE, GpuCholesky,
cusolver_available, GpuMagmaMatrixInverse, gpu_svd, cusolver_available, GpuMagmaMatrixInverse, gpu_svd,
GpuMagmaCholesky, gpu_qr, GpuMagmaEigh) GpuMagmaCholesky, gpu_qr, GpuMagmaEigh)
from .neighbours import GpuImages2Neibs from .neighbours import GpuImages2Neibs
from .ctc import GpuConnectionistTemporalClassification
_logger = logging.getLogger("theano.gpuarray.opt") _logger = logging.getLogger("theano.gpuarray.opt")
...@@ -2278,6 +2280,14 @@ def local_gpu_magma_svd(op, context_name, inputs, outputs): ...@@ -2278,6 +2280,14 @@ def local_gpu_magma_svd(op, context_name, inputs, outputs):
out = [out.astype('float16')] out = [out.astype('float16')]
return out return out
@register_opt('fast_compile')
@op_lifter([theano.tensor.nnet.ctc.ConnectionistTemporalClassification])
@register_opt2([theano.tensor.nnet.ctc.ConnectionistTemporalClassification], 'fast_compile')
def local_gpu_ctc(op, context_name, inputs, outputs):
if not config.ctc.enabled:
return
return [GpuConnectionistTemporalClassification()(*node.inputs)]
# Do not register in fast_run or fast_compile. # Do not register in fast_run or fast_compile.
# It will be added to fast_run if the GPU is enabled. # It will be added to fast_run if the GPU is enabled.
optdb.register('gpua_scanOp_make_inplace', optdb.register('gpua_scanOp_make_inplace',
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论