提交 78b41288 authored 作者: Frederic's avatar Frederic 提交者: Arnaud Bergeron

Add cuDNN version information.

上级 5895d5a2
......@@ -3,7 +3,7 @@
#include <cudnn.h>
static inline const char *cudnnGetErrorString(cudnnStatus_t err) {
inline const char *cudnnGetErrorString(cudnnStatus_t err) {
switch (err) {
case CUDNN_STATUS_SUCCESS:
return "The operation completed successfully.";
......@@ -28,4 +28,14 @@ static inline const char *cudnnGetErrorString(cudnnStatus_t err) {
}
}
/* Return the cuDNN version this code is compiled against.
 *
 * Yields the value of the CUDNN_VERSION macro from the cudnn.h header
 * seen at compile time, or -1 when the macro is undefined (which is
 * the case for the cuDNN R1 release).
 *
 * Note: this reports the header (compile-time) version only, not the
 * version of the cuDNN library actually linked at run time.
 *
 * A `const` qualifier on a by-value return type is meaningless (the
 * caller gets a copy) and triggers -Wignored-qualifiers, so the return
 * type is plain `int`.
 */
static inline int cudnnVersionMacro() {
#ifdef CUDNN_VERSION
  return CUDNN_VERSION;
#else
  /* CUDNN_VERSION undefined: probably the cuDNN R1 release. */
  return -1;
#endif
}
#endif
......@@ -3,7 +3,7 @@ import os
import theano
from theano import Apply, gof, tensor
from theano.gof import Optimizer, local_optimizer
from theano.gof.type import CDataType
from theano.gof.type import CDataType, Generic
from theano.compat import PY3
from theano.tensor.nnet import SoftmaxGrad
from theano.sandbox.cuda.type import CudaNdarrayType
......@@ -121,6 +121,39 @@ if ((err = cudnnCreate(&_handle)) != CUDNN_STATUS_SUCCESS) {
}""" % (error_out,)]
class DnnVersion(DnnBase):
    """Op returning the cuDNN version Theano was compiled with.

    The value is the ``CUDNN_VERSION`` macro from the ``cudnn.h`` header
    used at compile time, or -1 when that macro is undefined (cuDNN R1).
    The output is a plain Python int wrapped in a ``Generic`` variable.
    """

    def c_compiler(self):
        # Must be compiled with nvcc so the cuDNN header is found.
        return NVCC_compiler

    def make_node(self):
        # No inputs; a single Generic output holding the version int.
        return Apply(self, [], [Generic()()])

    def c_code(self, node, name, inputs, outputs, sub):
        o = outputs[0]
        # PyInt_FromLong does not exist under Python 3, so select the
        # proper int constructor at C-preprocessing time (the file
        # already targets Py3: it imports PY3 from theano.compat).
        return """
#if PY_MAJOR_VERSION >= 3
        %(o)s = PyLong_FromLong(cudnnVersionMacro());
#else
        %(o)s = PyInt_FromLong(cudnnVersionMacro());
#endif
        """ % locals()

    def do_constant_folding(self, node):
        # Needed as we do not want to cache this information.
        return False

    def c_code_cache_version(self):
        # Not needed, but make it clear that we do not want to cache this.
        return None
def version():
    """Return the cuDNN version Theano was compiled with, as an int.

    This only checks the header version (the ``CUDNN_VERSION`` macro),
    not the version of the library we actually link with at run time.
    Returns -1 when the macro is undefined, i.e. for cuDNN R1.
    """
    # Compile a tiny graph with optimizations disabled: the result must
    # not be constant-folded or cached (see DnnVersion), since it
    # reflects the current compilation environment.
    f = theano.function([], DnnVersion()(),
                        theano.Mode(optimizer=None))
    return f()
class GpuDnnConvDesc(GpuOp):
"""This Op builds a convolution descriptor for use in the other
convolution operations.
......
......@@ -192,3 +192,9 @@ def test_dnn_tag():
assert cuda.dnn.dnn_available()
assert any([isinstance(n.op, cuda.dnn.GpuDnnPool)
for n in f.maker.fgraph.toposort()])
def test_version():
    """Check that cuda.dnn.version() reports an integer version."""
    # Querying the cuDNN version only makes sense when cuDNN is usable.
    if not cuda.dnn.dnn_available():
        raise SkipTest(cuda.dnn.dnn_available.msg)
    reported = cuda.dnn.version()
    assert isinstance(reported, int)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论