提交 20f27b69 authored 作者: f0k's avatar f0k

Move tensor descriptor handling out of GpuDnnSoftmaxBase to foster reuse

上级 9ea0f3f2
......@@ -40,6 +40,23 @@ from theano.tensor.nnet.abstract_conv import (AbstractConv2d,
AbstractConv2d_gradInputs)
def c_define_tensor_desc(desc):
return """
cudnnTensorDescriptor_t %(desc)s;
""" % dict(desc=desc)
def c_init_tensor_desc(desc, err, fail):
return """
%(desc)s = NULL;
if ((%(err)s = cudnnCreateTensorDescriptor(&%(desc)s)) != CUDNN_STATUS_SUCCESS) {
PyErr_Format(PyExc_MemoryError, "could not allocate tensor descriptor "
": %%s", cudnnGetErrorString(%(err)s));
%(fail)s
}
""" % dict(desc=desc, err=err, fail=fail)
def c_set_tensor4d(var, desc, err, fail):
return """
{
......@@ -73,6 +90,13 @@ if (%(err)s != CUDNN_STATUS_SUCCESS) {
""" % dict(var=var, err=err, desc=desc, fail=fail)
def c_clean_tensor_desc(desc):
return """
if(%(desc)s!= NULL)
cudnnDestroyTensorDescriptor(%(desc)s);
""" % dict(desc=desc)
class DnnBase(GpuOp, COp):
"""
Creates a handle for cudnn and pulls in the cudnn libraries and headers.
......@@ -2025,31 +2049,10 @@ class GpuDnnSoftmaxBase(DnnBase):
else:
return [shape[1]]
def _define_tensor4d_desc(self, name, id):
return """
cudnnTensorDescriptor_t %(id)s_%(name)s;
""" % dict(name=name, id=id)
def _init_tensor4d_desc(self, name, id, fail):
return """
%(id)s_%(name)s = NULL;
if ((err%(name)s = cudnnCreateTensorDescriptor(&%(id)s_%(name)s)) != CUDNN_STATUS_SUCCESS) {
PyErr_Format(PyExc_MemoryError, "could not allocate tensor descriptor "
": %%s", cudnnGetErrorString(err%(name)s));
%(fail)s
}
""" % dict(name=name, id=id, fail=fail)
def _clean_tensor4d_desc(self, name, id):
return """
if(%(id)s_%(name)s!= NULL)
cudnnDestroyTensorDescriptor(%(id)s_%(name)s);
""" % dict(name=name, id=id)
def c_support_code_struct(self, node, name):
result = ''
for id in self.tensor_4d_descs:
result += self._define_tensor4d_desc(name, id)
result += c_define_tensor_desc('%s_%s' % (id, name))
return result
def c_init_code_struct(self, node, name, sub):
......@@ -2058,13 +2061,13 @@ cudnnStatus_t err%(name)s;
""" % dict(name=name)
for id in self.tensor_4d_descs:
result += self._init_tensor4d_desc(name, id, sub['fail'])
result += c_init_tensor_desc('%s_%s' % (id, name), 'err' + name, sub['fail'])
return result
def c_cleanup_code_struct(self, node, name):
result = ''
for id in self.tensor_4d_descs:
result += self._clean_tensor4d_desc(name, id)
result += c_clean_tensor_desc('%s_%s' % (id, name))
return result
def c_code(self, node, name, inputs, outputs, sub):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论