提交 3b77021a authored 作者: Frederic's avatar Frederic

fix gh-3371, disable cudnn v3 rc version and enable it for softmax grad

上级 43cd64bc
......@@ -81,20 +81,28 @@ if ((err = cudnnCreate(&_handle)) != CUDNN_STATUS_SUCCESS) {
" from one version, but we link with"
" a different version %s" % str(v))
raise RuntimeError(dnn_available.msg)
if version() == -1:
if v == -1:
dnn_available.avail = False
dnn_available.msg = (
"CuDNN v1 detected. This version is no longer "
"supported by Theano. Update your CuDNN installation "
"to a more recent version")
raise RuntimeError(dnn_available.msg)
if version() == (20, 20):
if v == (20, 20):
dnn_available.avail = False
dnn_available.msg = (
"You have installed a release candidate of CuDNN v2."
" This isn't supported anymore."
" Update to CuDNN v2 final version.")
raise RuntimeError(dnn_available.msg)
if v[0] >= 3000 and v[0] < 3007:
# 3007 is the final release of cudnn v3
dnn_available.avail = False
dnn_available.msg = (
"You have installed a release candidate of CuDNN v3."
" This isn't supported anymore."
" Update to CuDNN v3 final version.")
raise RuntimeError(dnn_available.msg)
return dnn_available.avail
......@@ -2380,8 +2388,12 @@ if True:
isinstance(node.inputs[0].owner.op, HostFromGpu)) or
(node.inputs[1].owner and
isinstance(node.inputs[1].owner.op, HostFromGpu)))):
if not dnn_available() or version() != (2000, 2000):
v = version()
if v[0] != v[1]:
return
if not dnn_available():
# Softmax grad is broken in v3 rc1 for this case
# But we don't support cudnn v3 rc version now.
return
ins = []
for n in node.inputs:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论