提交 da939705，作者：carriepl

Add an optimization to convert LogSoftmax to GpuDnnSoftmax (log mode) in the gpuarray backend

上级 99012195
...@@ -12,7 +12,7 @@ from theano.gof.cmodule import GCC_compiler ...@@ -12,7 +12,7 @@ from theano.gof.cmodule import GCC_compiler
from theano.gof.type import CDataType, Generic from theano.gof.type import CDataType, Generic
from theano.compile import optdb from theano.compile import optdb
from theano.compile.ops import shape_i from theano.compile.ops import shape_i
from theano.tensor.nnet import SoftmaxGrad from theano.tensor.nnet import LogSoftmax, SoftmaxGrad
from theano.tensor.nnet.abstract_conv import (AbstractConv2d, from theano.tensor.nnet.abstract_conv import (AbstractConv2d,
AbstractConv2d_gradWeights, AbstractConv2d_gradWeights,
AbstractConv2d_gradInputs, AbstractConv2d_gradInputs,
...@@ -1459,7 +1459,7 @@ def local_softmax_dnn(node): ...@@ -1459,7 +1459,7 @@ def local_softmax_dnn(node):
@register_opt('cudnn') @register_opt('cudnn')
@local_optimizer([GpuElemwise]) @local_optimizer([GpuElemwise, LogSoftmax])
def local_log_softmax_dnn(node): def local_log_softmax_dnn(node):
if version() < 3000: if version() < 3000:
# No log-softmax before cudnn v3 # No log-softmax before cudnn v3
...@@ -1474,6 +1474,19 @@ def local_log_softmax_dnn(node): ...@@ -1474,6 +1474,19 @@ def local_log_softmax_dnn(node):
new_softmax = GpuDnnSoftmax('log', softmax_node.op.mode) new_softmax = GpuDnnSoftmax('log', softmax_node.op.mode)
return [new_softmax(softmax_node.inputs[0])] return [new_softmax(softmax_node.inputs[0])]
elif (isinstance(node.op, LogSoftmax) and node.inputs[0].owner and
isinstance(node.inputs[0].owner.op, HostFromGpu)):
# Transform the input in the format expected by GpuDnnSoftmax
inp = node.inputs[0].owner.inputs[0]
if inp.ndim != 2:
return
inp = inp.dimshuffle(0, 1, 'x', 'x')
# Apply GpuDnnSoftmax and return the result
out = GpuDnnSoftmax('log', 'channel')(gpu_contiguous(inp))
return [out.dimshuffle(0, 1)]
class NoCuDNNRaise(Optimizer): class NoCuDNNRaise(Optimizer):
def apply(self, fgraph): def apply(self, fgraph):
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论