提交 3b77021a authored 作者: Frederic's avatar Frederic

fix gh-3371, disable cudnn v3 rc version and enable it for softmax grad

上级 43cd64bc
...@@ -81,20 +81,28 @@ if ((err = cudnnCreate(&_handle)) != CUDNN_STATUS_SUCCESS) { ...@@ -81,20 +81,28 @@ if ((err = cudnnCreate(&_handle)) != CUDNN_STATUS_SUCCESS) {
" from one version, but we link with" " from one version, but we link with"
" a different version %s" % str(v)) " a different version %s" % str(v))
raise RuntimeError(dnn_available.msg) raise RuntimeError(dnn_available.msg)
if version() == -1: if v == -1:
dnn_available.avail = False dnn_available.avail = False
dnn_available.msg = ( dnn_available.msg = (
"CuDNN v1 detected. This version is no longer " "CuDNN v1 detected. This version is no longer "
"supported by Theano. Update your CuDNN installation " "supported by Theano. Update your CuDNN installation "
"to a more recent version") "to a more recent version")
raise RuntimeError(dnn_available.msg) raise RuntimeError(dnn_available.msg)
if version() == (20, 20): if v == (20, 20):
dnn_available.avail = False dnn_available.avail = False
dnn_available.msg = ( dnn_available.msg = (
"You have installed a release candidate of CuDNN v2." "You have installed a release candidate of CuDNN v2."
" This isn't supported anymore." " This isn't supported anymore."
" Update to CuDNN v2 final version.") " Update to CuDNN v2 final version.")
raise RuntimeError(dnn_available.msg) raise RuntimeError(dnn_available.msg)
if v[0] >= 3000 and v[0] < 3007:
# 3007 is the final release of cudnn v3
dnn_available.avail = False
dnn_available.msg = (
"You have installed a release candidate of CuDNN v3."
" This isn't supported anymore."
" Update to CuDNN v3 final version.")
raise RuntimeError(dnn_available.msg)
return dnn_available.avail return dnn_available.avail
...@@ -2380,8 +2388,12 @@ if True: ...@@ -2380,8 +2388,12 @@ if True:
isinstance(node.inputs[0].owner.op, HostFromGpu)) or isinstance(node.inputs[0].owner.op, HostFromGpu)) or
(node.inputs[1].owner and (node.inputs[1].owner and
isinstance(node.inputs[1].owner.op, HostFromGpu)))): isinstance(node.inputs[1].owner.op, HostFromGpu)))):
if not dnn_available() or version() != (2000, 2000): v = version()
if v[0] != v[1]:
return
if not dnn_available():
# Softmax grad is broken in v3 rc1 for this case # Softmax grad is broken in v3 rc1 for this case
# But we don't support cudnn v3 rc version now.
return return
ins = [] ins = []
for n in node.inputs: for n in node.inputs:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论