提交 6af9d9a0 authored 作者: Frederic's avatar Frederic

small code refactoring

上级 5036835a
......@@ -1107,28 +1107,15 @@ if (%(mode)d == 1)
# Validate the input and build the input variables.
for input_idx, input_name in enumerate(self.softmax_inputs):
result += """
if (!CudaNdarray_is_c_contiguous(%(ins)s)) {
PyErr_SetString(PyExc_ValueError, "Only contiguous inputs are supported.");
%(fail)s
}
result += c_set_tensor4d(ins[input_idx], input_name + "_" + name,
"err" + name, sub['fail'])
err%(name)s = cudnnSetTensor4dDescriptor(
%(input_name)s_%(name)s,
format%(name)s,
CUDNN_DATA_FLOAT,
CudaNdarray_HOST_DIMS(%(ins)s)[0],
CudaNdarray_HOST_DIMS(%(ins)s)[1],
CudaNdarray_HOST_DIMS(%(ins)s)[2],
CudaNdarray_HOST_DIMS(%(ins)s)[3]
);
if (err%(name)s != CUDNN_STATUS_SUCCESS) {
PyErr_Format(PyExc_RuntimeError, "could not set tensor4d descriptor: %%%%s",
cudnnGetErrorString(err%(name)s));
%(fail)s
}
""" % dict(name=name, input_name=input_name,
ins=ins[input_idx], fail=sub['fail'])
subs = dict(ins=ins[-1], outs=outs, fail=sub['fail'],
name=name)
for idx, softmax_input in enumerate(self.softmax_inputs):
subs['name%d' % idx] = softmax_input
subs['ins%d' % idx] = inputs[idx]
# Build and prepare the output variable.
result += """
......@@ -1136,34 +1123,15 @@ if (CudaNdarray_prep_output(&%(outs)s, 4, CudaNdarray_HOST_DIMS(%(ins)s)) != 0)
{
%(fail)s
}
err%(name)s = cudnnSetTensor4dDescriptor(
softmax_output_%(name)s,
format%(name)s,
CUDNN_DATA_FLOAT,
CudaNdarray_HOST_DIMS(%(outs)s)[0],
CudaNdarray_HOST_DIMS(%(outs)s)[1],
CudaNdarray_HOST_DIMS(%(outs)s)[2],
CudaNdarray_HOST_DIMS(%(outs)s)[3]
);
if (err%(name)s != CUDNN_STATUS_SUCCESS) {
PyErr_Format(PyExc_RuntimeError, "could not set out descriptor: %%%%s",
cudnnGetErrorString(err%(name)s));
%(fail)s
}
"""
""" % subs
result += c_set_tensor4d(outs,
"softmax_output_" + name,
"err" + name, sub['fail'])
# Add on a call to the method that does the actual work.
result += self.method()
result += self.method() % subs
subs = dict(ins=ins[-1], outs=outs, fail=sub['fail'],
name=name)
for idx, softmax_input in enumerate(self.softmax_inputs):
subs['name%d' % idx] = softmax_input
subs['ins%d' % idx] = inputs[idx]
return result % subs
return result
def c_code_cache_version(self):
return (0, 6)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论