提交 6cd85677 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Change version() to return a single value and do the check for inconsistency itself.

上级 cf1aa9bf
......@@ -80,13 +80,7 @@ if ((err = cudnnCreate(&_handle)) != CUDNN_STATUS_SUCCESS) {
else:
# If we can compile, check that we can import and run.
v = version()
if isinstance(v, tuple) and v[0] != v[1]:
dnn_available.avail = False
dnn_available.msg = ("Mixed dnn version. The header is"
" from one version, but we link with"
" a different version %s" % str(v))
raise RuntimeError(dnn_available.msg)
if version() == (20, 20):
if version() == 20:
dnn_available.avail = False
dnn_available.msg = (
"You have installed a release candidate of CuDNN v2."
......@@ -198,11 +192,7 @@ class DnnVersion(Op):
def c_code(self, node, name, inputs, outputs, sub):
o = outputs[0]
return """
#if defined(CUDNN_VERSION)
%(o)s = PyTuple_Pack(2, PyInt_FromLong(CUDNN_VERSION), PyInt_FromLong(cudnnGetVersion()));
#else
%(o)s = PyInt_FromLong(-1);
#endif
""" % locals()
def do_constant_folding(self, node):
......@@ -216,11 +206,9 @@ class DnnVersion(Op):
def version():
"""
Return the current cuDNN version we compile with.
This return a tuple with the header version and the library version we link
with. For older cudnn version without version information, we return -1.
Return the current cuDNN version we link with.
This also does a check that the header version matches the runtime version.
"""
if not dnn_available():
raise Exception(
......@@ -231,7 +219,11 @@ def version():
f = theano.function([], DnnVersion()(),
theano.Mode(optimizer=None),
profile=False)
version.v = f()
v = f()
if v[0] != v[1]:
raise RuntimeError("Mixed dnn version. The header is version %s "
"while the library is version %s." % v)
version.v = v[1]
return version.v
version.v = None
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论