提交 34223240 authored 作者: --global's avatar --global

Add support for 'log' mode in GpuDnnSoftmax

上级 bab0640a
...@@ -1313,8 +1313,9 @@ class GpuDnnSoftmaxBase(DnnBase): ...@@ -1313,8 +1313,9 @@ class GpuDnnSoftmaxBase(DnnBase):
Op for the cuDNN Softmax. Op for the cuDNN Softmax.
:param tensor_format: Whether the data format is 'bc01' or 'b01c'. :param tensor_format: Whether the data format is 'bc01' or 'b01c'.
:param algo: 'fast' or 'accurate' indicating whether computations should be :param algo: 'fast', 'accurate' or 'log' indicating whether, respectively,
optimized for speed or accuracy respectively. computations should be optimized for speed, for accuracy, or if CuDNN
should rather compute the log-softmax instead.
:param mode: 'instance' or 'channel' indicating whether the softmax should :param mode: 'instance' or 'channel' indicating whether the softmax should
be computed per image across 'c01' or per spatial location '01' per be computed per image across 'c01' or per spatial location '01' per
image across 'c'. image across 'c'.
...@@ -1327,7 +1328,7 @@ class GpuDnnSoftmaxBase(DnnBase): ...@@ -1327,7 +1328,7 @@ class GpuDnnSoftmaxBase(DnnBase):
DnnBase.__init__(self) DnnBase.__init__(self)
self.tensor_format = tensor_format self.tensor_format = tensor_format
assert(algo in ('fast', 'accurate')) assert(algo in ('fast', 'accurate', 'log'))
self.algo = algo self.algo = algo
assert(mode in ('instance', 'channel')) assert(mode in ('instance', 'channel'))
...@@ -1401,6 +1402,8 @@ cudnnStatus_t err%(name)s; ...@@ -1401,6 +1402,8 @@ cudnnStatus_t err%(name)s;
if self.algo == 'fast': if self.algo == 'fast':
algo = 1 algo = 1
elif self.algo == "log":
algo = 2
else: else:
algo = 0 algo = 0
...@@ -1414,6 +1417,8 @@ if (%(tensor_format)d == 1) ...@@ -1414,6 +1417,8 @@ if (%(tensor_format)d == 1)
cudnnSoftmaxAlgorithm_t algo%(name)s = CUDNN_SOFTMAX_ACCURATE; cudnnSoftmaxAlgorithm_t algo%(name)s = CUDNN_SOFTMAX_ACCURATE;
if (%(algo)d == 1) if (%(algo)d == 1)
algo%(name)s = CUDNN_SOFTMAX_FAST; algo%(name)s = CUDNN_SOFTMAX_FAST;
if (%(algo)d == 2)
algo%(name)s = CUDNN_SOFTMAX_LOG;
cudnnSoftmaxMode_t mode%(name)s = CUDNN_SOFTMAX_MODE_CHANNEL; cudnnSoftmaxMode_t mode%(name)s = CUDNN_SOFTMAX_MODE_CHANNEL;
if (%(mode)d == 1) if (%(mode)d == 1)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论