提交 34223240 authored 作者: --global's avatar --global

Add support for 'log' mode in GpuDnnSoftmax

上级 bab0640a
......@@ -1313,8 +1313,9 @@ class GpuDnnSoftmaxBase(DnnBase):
Op for the cuDNN Softmax.
:param tensor_format: Whether the data format is 'bc01' or 'b01c'.
:param algo: 'fast' or 'accurate' indicating whether computations should be
optimized for speed or accuracy respectively.
:param algo: 'fast', 'accurate' or 'log' indicating whether, respectively,
computations should be optimized for speed, for accuracy, or if CuDNN
should rather compute the log-softmax instead.
:param mode: 'instance' or 'channel' indicating whether the softmax should
be computed per image across 'c01' or per spatial location '01' per
image across 'c'.
......@@ -1327,7 +1328,7 @@ class GpuDnnSoftmaxBase(DnnBase):
DnnBase.__init__(self)
self.tensor_format = tensor_format
assert(algo in ('fast', 'accurate'))
assert(algo in ('fast', 'accurate', 'log'))
self.algo = algo
assert(mode in ('instance', 'channel'))
......@@ -1401,6 +1402,8 @@ cudnnStatus_t err%(name)s;
if self.algo == 'fast':
algo = 1
elif self.algo == "log":
algo = 2
else:
algo = 0
......@@ -1414,6 +1417,8 @@ if (%(tensor_format)d == 1)
cudnnSoftmaxAlgorithm_t algo%(name)s = CUDNN_SOFTMAX_ACCURATE;
if (%(algo)d == 1)
algo%(name)s = CUDNN_SOFTMAX_FAST;
if (%(algo)d == 2)
algo%(name)s = CUDNN_SOFTMAX_LOG;
cudnnSoftmaxMode_t mode%(name)s = CUDNN_SOFTMAX_MODE_CHANNEL;
if (%(mode)d == 1)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论