提交 bc1b615b,作者:notoraptor

Make elemwise object reusable.

上级 5b276d14
...@@ -3,7 +3,8 @@ ...@@ -3,7 +3,8 @@
cudnnTensorDescriptor_t APPLY_SPECIFIC(input); cudnnTensorDescriptor_t APPLY_SPECIFIC(input);
cudnnTensorDescriptor_t APPLY_SPECIFIC(output); cudnnTensorDescriptor_t APPLY_SPECIFIC(output);
cudnnReduceTensorDescriptor_t APPLY_SPECIFIC(red); cudnnReduceTensorDescriptor_t APPLY_SPECIFIC(red);
GpuElemwise* elemwise;
gpuelemwise_arg arg;
#section init_code_struct #section init_code_struct
...@@ -28,12 +29,18 @@ if ((APPLY_SPECIFIC(err) = cudnnCreateReduceTensorDescriptor(&APPLY_SPECIFIC(red ...@@ -28,12 +29,18 @@ if ((APPLY_SPECIFIC(err) = cudnnCreateReduceTensorDescriptor(&APPLY_SPECIFIC(red
FAIL; FAIL;
} }
elemwise = NULL;
#section cleanup_code_struct #section cleanup_code_struct
if (APPLY_SPECIFIC(input) != NULL) { cudnnDestroyTensorDescriptor(APPLY_SPECIFIC(input)); } if (APPLY_SPECIFIC(input) != NULL) { cudnnDestroyTensorDescriptor(APPLY_SPECIFIC(input)); }
if (APPLY_SPECIFIC(output) != NULL) { cudnnDestroyTensorDescriptor(APPLY_SPECIFIC(output)); } if (APPLY_SPECIFIC(output) != NULL) { cudnnDestroyTensorDescriptor(APPLY_SPECIFIC(output)); }
if (APPLY_SPECIFIC(red) != NULL) { cudnnDestroyReduceTensorDescriptor(APPLY_SPECIFIC(red)); } if (APPLY_SPECIFIC(red) != NULL) { cudnnDestroyReduceTensorDescriptor(APPLY_SPECIFIC(red)); }
if (elemwise) {
GpuElemwise_free(elemwise);
elemwise = NULL;
}
#section support_code_struct #section support_code_struct
...@@ -118,18 +125,18 @@ int APPLY_SPECIFIC(dnn_redux)(PyGpuArrayObject *input, ...@@ -118,18 +125,18 @@ int APPLY_SPECIFIC(dnn_redux)(PyGpuArrayObject *input,
case CUDNN_REDUCE_TENSOR_NORM1: case CUDNN_REDUCE_TENSOR_NORM1:
case CUDNN_REDUCE_TENSOR_NORM2: case CUDNN_REDUCE_TENSOR_NORM2:
{ {
gpuelemwise_arg arg; if (elemwise == NULL) {
arg.name = "out"; arg.name = "out";
arg.typecode = (*output)->ga.typecode; arg.typecode = (*output)->ga.typecode;
arg.flags = GE_READ | GE_WRITE; arg.flags = GE_READ | GE_WRITE;
GpuElemwise* elemwise = GpuElemwise_new(c->ctx, "", "out = (out < 0 ? -out : out)", 1, &arg, p, GE_CONVERT_F16); elemwise = GpuElemwise_new(c->ctx, "", "out = (out < 0 ? -out : out)", 1, &arg, p, GE_CONVERT_F16);
if (!elemwise) { if (!elemwise) {
PyErr_SetString(PyExc_RuntimeError, "Unable to create GpuElemwise for output."); PyErr_SetString(PyExc_RuntimeError, "Unable to create GpuElemwise for output.");
return 1; return 1;
} }
}
void* args[1] = { (void*)&(*output)->ga }; void* args[1] = { (void*)&(*output)->ga };
int err = GpuElemwise_call(elemwise, args, 0); int err = GpuElemwise_call(elemwise, args, 0);
GpuElemwise_free(elemwise);
if (err != GA_NO_ERROR) { if (err != GA_NO_ERROR) {
PyErr_SetString(PyExc_RuntimeError, "Unable to call GpuElemwise on output."); PyErr_SetString(PyExc_RuntimeError, "Unable to call GpuElemwise on output.");
return 1; return 1;
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论