提交 fd83d922 作者: Arnaud Bergeron

Add a CUDNNDataType wrapper to handle all the libraries and other such stuff.

上级 38bd170b
...@@ -224,6 +224,31 @@ def dnn_available(context_name): ...@@ -224,6 +224,31 @@ def dnn_available(context_name):
dnn_available.msg = None dnn_available.msg = None
def CUDNNDataType(name, freefunc=None):
    """Create a ``CDataType`` configured to build against cuDNN.

    Gathers the header directories, library directory and rpath linker
    flag from the ``config.dnn`` / ``config.cuda`` settings so callers do
    not have to repeat them for every cuDNN type they declare.

    Parameters
    ----------
    name
        Forwarded unchanged as the first argument of ``CDataType``
        (call sites in this file pass C type names such as
        ``'cudnnHandle_t'``).
    freefunc
        Forwarded unchanged as the second argument of ``CDataType``
        (call sites pass names such as ``'cudnnDestroy'``); defaults to
        ``None``.

    Returns
    -------
    The object returned by ``CDataType``.
    """
    # Only configured (truthy) paths are passed on; unset options are
    # dropped rather than forwarded as empty entries.
    include_dirs = [path
                    for path in (config.dnn.include_path,
                                 config.cuda.include_path)
                    if path]
    library_dirs = [path for path in (config.dnn.library_path,) if path]

    rpath_args = []
    bin_dir = config.dnn.bin_path
    if bin_dir:
        # Every platform except darwin gets the path wrapped in double
        # quotes; darwin receives it bare.
        if sys.platform != 'darwin':
            bin_dir = '"' + bin_dir + '"'
        rpath_args.append('-Wl,-rpath,' + bin_dir)

    return CDataType(name, freefunc,
                     headers=['cudnn.h'],
                     header_dirs=include_dirs,
                     libraries=['cudnn'],
                     lib_dirs=library_dirs,
                     compile_args=rpath_args,
                     version=version(raises=False))
class DnnVersion(Op): class DnnVersion(Op):
__props__ = () __props__ = ()
...@@ -311,12 +336,7 @@ def version(raises=True): ...@@ -311,12 +336,7 @@ def version(raises=True):
return version.v return version.v
version.v = None version.v = None
handle_type = CDataType('cudnnHandle_t', 'cudnnDestroy', handle_type = CUDNNDataType('cudnnHandle_t', 'cudnnDestroy')
headers=['cudnn.h'],
header_dirs=[config.dnn.include_path],
libraries=['cudnn'],
lib_dirs=[config.dnn.library_path],
version=version(raises=False))
# Get cuDNN definitions to be used. # Get cuDNN definitions to be used.
cudnn = cudnn_defs.get_definitions(version(raises=False)) cudnn = cudnn_defs.get_definitions(version(raises=False))
...@@ -489,9 +509,8 @@ class GpuDnnConvDesc(COp): ...@@ -489,9 +509,8 @@ class GpuDnnConvDesc(COp):
kern_shape = theano.tensor.basic.cast(kern_shape, 'int64') kern_shape = theano.tensor.basic.cast(kern_shape, 'int64')
node = Apply(self, [kern_shape], node = Apply(self, [kern_shape],
[CDataType("cudnnConvolutionDescriptor_t", [CUDNNDataType("cudnnConvolutionDescriptor_t",
freefunc="cudnnDestroyConvolutionDescriptor", freefunc="cudnnDestroyConvolutionDescriptor")()])
version=version(raises=False))()])
# DebugMode cannot compare the values of CDataType variables, so by # DebugMode cannot compare the values of CDataType variables, so by
# default it returns False all the time. To prevent DebugMode from # default it returns False all the time. To prevent DebugMode from
# complaining because of the MergeOptimizer, we make this variable # complaining because of the MergeOptimizer, we make this variable
...@@ -1301,9 +1320,8 @@ class GpuDnnPoolDesc(Op): ...@@ -1301,9 +1320,8 @@ class GpuDnnPoolDesc(Op):
def make_node(self): def make_node(self):
node = Apply(self, [], node = Apply(self, [],
[CDataType("cudnnPoolingDescriptor_t", [CUDNNDataType("cudnnPoolingDescriptor_t",
freefunc="cudnnDestroyPoolingDescriptor", freefunc="cudnnDestroyPoolingDescriptor")()])
version=version(raises=False))()])
# DebugMode cannot compare the values of CDataType variables, so by # DebugMode cannot compare the values of CDataType variables, so by
# default it returns False all the time. To prevent DebugMode from # default it returns False all the time. To prevent DebugMode from
# complaining because of the MergeOptimizer, we make this variable # complaining because of the MergeOptimizer, we make this variable
...@@ -1961,9 +1979,8 @@ class GpuDnnBatchNormGrad(DnnBase): ...@@ -1961,9 +1979,8 @@ class GpuDnnBatchNormGrad(DnnBase):
return [shape[0], shape[2], shape[2]] return [shape[0], shape[2], shape[2]]
gpudata_type = CDataType('gpudata *', 'gpudata_release') gpudata_type = CDataType('gpudata *', 'gpudata_release')
dropoutdesc_type = CDataType('cudnnDropoutDescriptor_t', dropoutdesc_type = CUDNNDataType('cudnnDropoutDescriptor_t',
'cudnnDestroyDropoutDescriptor', 'cudnnDestroyDropoutDescriptor')
version=version(raises=False))
class GpuDnnDropoutOp(DnnBase): class GpuDnnDropoutOp(DnnBase):
...@@ -2031,9 +2048,8 @@ def dropout(x, dropout=0.0, seed=4242): ...@@ -2031,9 +2048,8 @@ def dropout(x, dropout=0.0, seed=4242):
y, odesc = GpuDnnDropoutOp()(x, desc) y, odesc = GpuDnnDropoutOp()(x, desc)
return y, desc, odesc, states return y, desc, odesc, states
rnndesc_type = CDataType('cudnnRNNDescriptor_t', rnndesc_type = CUDNNDataType('cudnnRNNDescriptor_t',
'cudnnDestroyRNNDescriptor', 'cudnnDestroyRNNDescriptor')
version=version(raises=False))
def as_i32(v): def as_i32(v):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论