提交 81eba6e7 作者: Frederic

Use CNMeM more frequently

上级 3e40d56a
...@@ -208,22 +208,28 @@ static int SparseBlockGemv_copy(PyArrayObject *a, npy_intp *b) { ...@@ -208,22 +208,28 @@ static int SparseBlockGemv_copy(PyArrayObject *a, npy_intp *b) {
static int %(n)s_prep(int b, int i, int j, int outsize) { static int %(n)s_prep(int b, int i, int j, int outsize) {
int s = b*i*j; int s = b*i*j;
if (%(n)s_list_len < s) { if (%(n)s_list_len < s) {
cudaFree(%(n)s_inp_list); device_free(%(n)s_inp_list);
cudaFree(%(n)s_out_list); device_free(%(n)s_out_list);
cudaFree(%(n)s_W_list); device_free(%(n)s_W_list);
if (cudaMalloc(&%(n)s_inp_list, s*sizeof(float *)) != cudaSuccess) return -1; %(n)s_inp_list = (const float **) device_malloc(s*sizeof(float *));
if (cudaMalloc(&%(n)s_out_list, s*sizeof(float *)) != cudaSuccess) return -1; if (%(n)s_inp_list == NULL) return -1;
if (cudaMalloc(&%(n)s_W_list, s*sizeof(float *)) != cudaSuccess) return -1; %(n)s_out_list = (float **) device_malloc(s*sizeof(float *));
if (%(n)s_out_list == NULL) return -1;
%(n)s_W_list = (const float **) device_malloc(s*sizeof(float *));
if (%(n)s_W_list == NULL) return -1;
%(n)s_list_len = s; %(n)s_list_len = s;
} }
if (%(n)s_iIdx_len < b*i) { if (%(n)s_iIdx_len < b*i) {
cudaFree(%(n)s_iIdx); device_free(%(n)s_iIdx);
if (cudaMalloc(&%(n)s_iIdx, b*i*sizeof(npy_intp)) != cudaSuccess) return -1; %(n)s_iIdx = (npy_intp*) device_malloc(b*i*sizeof(npy_intp));
if (%(n)s_iIdx == NULL) return -1;
%(n)s_iIdx_len = b*i; %(n)s_iIdx_len = b*i;
} }
if (%(n)s_oIdx_len < b*j) { if (%(n)s_oIdx_len < b*j) {
cudaFree(%(n)s_oIdx); device_free(%(n)s_oIdx);
if (cudaMalloc(&%(n)s_oIdx, b*j*sizeof(npy_intp)) != cudaSuccess) return -1; %(n)s_oIdx = (npy_intp*) device_malloc(b*j*sizeof(npy_intp));
if (%(n)s_oIdx == NULL) return -1;
%(n)s_oIdx_len = b*j; %(n)s_oIdx_len = b*j;
} }
return 0; return 0;
...@@ -326,7 +332,7 @@ CudaNdarray_HOST_STRIDES(%(out)s)[0], CudaNdarray_HOST_STRIDES(%(out)s)[1], ...@@ -326,7 +332,7 @@ CudaNdarray_HOST_STRIDES(%(out)s)[0], CudaNdarray_HOST_STRIDES(%(out)s)[1],
W=W, fail=sub['fail'], name=nodename) W=W, fail=sub['fail'], name=nodename)
def c_code_cache_version(self): def c_code_cache_version(self):
return (11,) return (12,)
def grad(self, inputs, grads): def grad(self, inputs, grads):
o, W, h, inputIdx, outputIdx = inputs o, W, h, inputIdx, outputIdx = inputs
...@@ -509,24 +515,27 @@ static size_t %(n)s_yIdx_len; ...@@ -509,24 +515,27 @@ static size_t %(n)s_yIdx_len;
static int %(n)s_prep(int b, int i, int j) { static int %(n)s_prep(int b, int i, int j) {
int s = b*i*j; int s = b*i*j;
if (%(n)s_list_len < s) { if (%(n)s_list_len < s) {
cudaFree(%(n)s_x_list); device_free(%(n)s_x_list);
cudaFree(%(n)s_y_list); device_free(%(n)s_y_list);
cudaFree(%(n)s_out_list); device_free(%(n)s_out_list);
if (cudaMalloc(&%(n)s_x_list, s*sizeof(float *)) != cudaSuccess) return -1; %(n)s_x_list = (const float **) device_malloc(s*sizeof(float *));
if (cudaMalloc(&%(n)s_y_list, s*sizeof(float *)) != cudaSuccess) return -1; if (%(n)s_x_list == NULL) return -1;
if (cudaMalloc(&%(n)s_out_list, s*sizeof(float *)) != cudaSuccess) return -1; %(n)s_y_list = (const float **) device_malloc(s*sizeof(float *));
if (%(n)s_y_list == NULL) return -1;
%(n)s_out_list = (float **) device_malloc(s*sizeof(float *));
if (%(n)s_out_list == NULL) return -1;
%(n)s_list_len = s; %(n)s_list_len = s;
} }
if (%(n)s_xIdx_len < b*i) { if (%(n)s_xIdx_len < b*i) {
cudaFree(%(n)s_xIdx); device_free(%(n)s_xIdx);
if (cudaMalloc(&%(n)s_xIdx, b*i*sizeof(npy_intp)) != cudaSuccess) %(n)s_xIdx = (npy_intp*) device_malloc(b*i*sizeof(npy_intp));
return -1; if (%(n)s_xIdx == NULL) return -1;
%(n)s_xIdx_len = b*i; %(n)s_xIdx_len = b*i;
} }
if (%(n)s_yIdx_len < b*j) { if (%(n)s_yIdx_len < b*j) {
cudaFree(%(n)s_yIdx); device_free(%(n)s_yIdx);
if (cudaMalloc(&%(n)s_yIdx, b*j*sizeof(npy_intp)) != cudaSuccess) %(n)s_yIdx = (npy_intp*) device_malloc(b*j*sizeof(npy_intp));
return -1; if (%(n)s_yIdx == NULL) return -1;
%(n)s_yIdx_len = b*j; %(n)s_yIdx_len = b*j;
} }
return 0; return 0;
...@@ -626,7 +635,7 @@ CudaNdarray_HOST_STRIDES(%(out)s)[0], CudaNdarray_HOST_STRIDES(%(out)s)[1], ...@@ -626,7 +635,7 @@ CudaNdarray_HOST_STRIDES(%(out)s)[0], CudaNdarray_HOST_STRIDES(%(out)s)[1],
alpha=alpha, fail=sub['fail']) alpha=alpha, fail=sub['fail'])
def c_code_cache_version(self): def c_code_cache_version(self):
return (10,) return (11,)
sparse_block_outer_ss = SparseBlockOuterSS(False) sparse_block_outer_ss = SparseBlockOuterSS(False)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论