提交 3fdad2b7，作者：--global

Make the dnn softmax op compatible with CuDNN v2

上级 2ef334c7
...@@ -1728,6 +1728,10 @@ class GpuDnnSoftmaxBase(DnnBase): ...@@ -1728,6 +1728,10 @@ class GpuDnnSoftmaxBase(DnnBase):
DnnBase.__init__(self) DnnBase.__init__(self)
self.tensor_format = tensor_format self.tensor_format = tensor_format
if algo == 'log' and version() < (3000, 3000):
raise RuntimeError("CuDNN's log-softmax implementation is only "
"supported starting at CuDNN v3")
assert(algo in ('fast', 'accurate', 'log')) assert(algo in ('fast', 'accurate', 'log'))
self.algo = algo self.algo = algo
...@@ -1801,11 +1805,11 @@ cudnnStatus_t err%(name)s; ...@@ -1801,11 +1805,11 @@ cudnnStatus_t err%(name)s;
mode = 0 mode = 0
if self.algo == 'fast': if self.algo == 'fast':
algo = 1 algo = "CUDNN_SOFTMAX_FAST"
elif self.algo == "log": elif self.algo == "log":
algo = 2 algo = "CUDNN_SOFTMAX_LOG"
else: else:
algo = 0 algo = "CUDNN_SOFTMAX_ACCURATE"
# Setup configuration variables. # Setup configuration variables.
result = """ result = """
...@@ -1814,11 +1818,7 @@ cudnnTensorFormat_t format%(name)s = CUDNN_TENSOR_NCHW; ...@@ -1814,11 +1818,7 @@ cudnnTensorFormat_t format%(name)s = CUDNN_TENSOR_NCHW;
if (%(tensor_format)d == 1) if (%(tensor_format)d == 1)
format%(name)s = CUDNN_TENSOR_NHWC; format%(name)s = CUDNN_TENSOR_NHWC;
cudnnSoftmaxAlgorithm_t algo%(name)s = CUDNN_SOFTMAX_ACCURATE; cudnnSoftmaxAlgorithm_t algo%(name)s = %(algo)s;
if (%(algo)d == 1)
algo%(name)s = CUDNN_SOFTMAX_FAST;
if (%(algo)d == 2)
algo%(name)s = CUDNN_SOFTMAX_LOG;
cudnnSoftmaxMode_t mode%(name)s = CUDNN_SOFTMAX_MODE_CHANNEL; cudnnSoftmaxMode_t mode%(name)s = CUDNN_SOFTMAX_MODE_CHANNEL;
if (%(mode)d == 1) if (%(mode)d == 1)
...@@ -2202,7 +2202,8 @@ if True: ...@@ -2202,7 +2202,8 @@ if True:
@register_opt('cudnn') @register_opt('cudnn')
@local_optimizer([GpuElemwise]) @local_optimizer([GpuElemwise])
def local_log_softmax_dnn(node): def local_log_softmax_dnn(node):
if not dnn_available(): # The log-softmax implementation is only available starting at CuDNN V3.
if not dnn_available() or version() < (3000, 3000):
return return
if (isinstance(node.op, GpuElemwise) and if (isinstance(node.op, GpuElemwise) and
isinstance(node.op.scalar_op, Log) and isinstance(node.op.scalar_op, Log) and
......
...@@ -468,7 +468,9 @@ def test_log_softmax(): ...@@ -468,7 +468,9 @@ def test_log_softmax():
def test_log_softmax_opt(): def test_log_softmax_opt():
if not cuda.dnn.dnn_available(): # This is a test for an optimization that depends on CuDNN v3 or
# more recent. Don't test if the CuDNN version is too old.
if not cuda.dnn.dnn_available() or cuda.dnn.version() < (3000, 3000):
raise SkipTest(cuda.dnn.dnn_available.msg) raise SkipTest(cuda.dnn.dnn_available.msg)
x = T.ftensor4() x = T.ftensor4()
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论