提交 fd83d922 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add a CUDNNDataType wrapper to handle all the libraries and other such stuff.

上级 38bd170b
......@@ -224,6 +224,31 @@ def dnn_available(context_name):
dnn_available.msg = None
def CUDNNDataType(name, freefunc=None):
hdirs = []
if config.dnn.include_path:
hdirs.append(config.dnn.include_path)
if config.cuda.include_path:
hdirs.append(config.cuda.include_path)
ldirs = []
if config.dnn.library_path:
ldirs.append(config.dnn.library_path)
cargs = []
if config.dnn.bin_path:
if sys.platform == 'darwin':
cargs.append('-Wl,-rpath,' + config.dnn.bin_path)
else:
cargs.append('-Wl,-rpath,"' + config.dnn.bin_path + '"')
return CDataType(name, freefunc,
headers=['cudnn.h'],
header_dirs=hdirs,
libraries=['cudnn'],
lib_dirs=ldirs,
compile_args=cargs,
version=version(raises=False))
class DnnVersion(Op):
__props__ = ()
......@@ -311,12 +336,7 @@ def version(raises=True):
return version.v
version.v = None
handle_type = CDataType('cudnnHandle_t', 'cudnnDestroy',
headers=['cudnn.h'],
header_dirs=[config.dnn.include_path],
libraries=['cudnn'],
lib_dirs=[config.dnn.library_path],
version=version(raises=False))
handle_type = CUDNNDataType('cudnnHandle_t', 'cudnnDestroy')
# Get cuDNN definitions to be used.
cudnn = cudnn_defs.get_definitions(version(raises=False))
......@@ -489,9 +509,8 @@ class GpuDnnConvDesc(COp):
kern_shape = theano.tensor.basic.cast(kern_shape, 'int64')
node = Apply(self, [kern_shape],
[CDataType("cudnnConvolutionDescriptor_t",
freefunc="cudnnDestroyConvolutionDescriptor",
version=version(raises=False))()])
[CUDNNDataType("cudnnConvolutionDescriptor_t",
freefunc="cudnnDestroyConvolutionDescriptor")()])
# DebugMode cannot compare the values of CDataType variables, so by
# default it returns False all the time. To prevent DebugMode from
# complaining because of the MergeOptimizer, we make this variable
......@@ -1301,9 +1320,8 @@ class GpuDnnPoolDesc(Op):
def make_node(self):
node = Apply(self, [],
[CDataType("cudnnPoolingDescriptor_t",
freefunc="cudnnDestroyPoolingDescriptor",
version=version(raises=False))()])
[CUDNNDataType("cudnnPoolingDescriptor_t",
freefunc="cudnnDestroyPoolingDescriptor")()])
# DebugMode cannot compare the values of CDataType variables, so by
# default it returns False all the time. To prevent DebugMode from
# complaining because of the MergeOptimizer, we make this variable
......@@ -1961,9 +1979,8 @@ class GpuDnnBatchNormGrad(DnnBase):
return [shape[0], shape[2], shape[2]]
gpudata_type = CDataType('gpudata *', 'gpudata_release')
dropoutdesc_type = CDataType('cudnnDropoutDescriptor_t',
'cudnnDestroyDropoutDescriptor',
version=version(raises=False))
dropoutdesc_type = CUDNNDataType('cudnnDropoutDescriptor_t',
'cudnnDestroyDropoutDescriptor')
class GpuDnnDropoutOp(DnnBase):
......@@ -2031,9 +2048,8 @@ def dropout(x, dropout=0.0, seed=4242):
y, odesc = GpuDnnDropoutOp()(x, desc)
return y, desc, odesc, states
rnndesc_type = CDataType('cudnnRNNDescriptor_t',
'cudnnDestroyRNNDescriptor',
version=version(raises=False))
rnndesc_type = CUDNNDataType('cudnnRNNDescriptor_t',
'cudnnDestroyRNNDescriptor')
def as_i32(v):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论