Commit 36c30634 authored by Vivek Kulkarni

Addressing code review comments. Moving the code to c_support_code_apply and handling the launch configuration so that the number of threads and blocks launched is the minimum of the threshold and what the problem requires.
Parent 9b67359d
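The launch rule described in the message pairs with grid-stride loops inside the kernel, so clamping the grid only affects performance, never coverage. The compute-capability gate added below is needed because single-precision atomicAdd() on global memory only exists on devices of compute capability 2.0 and higher. A minimal standalone sketch of the clamping pattern follows; MAX_BLOCKS and MAX_THREADS_PER_BLOCK are illustrative stand-ins for NUM_VECTOR_OP_BLOCKS and NUM_VECTOR_OP_THREADS_PER_BLOCK used in the diff, and k_sketch/launch_sketch are hypothetical names, not Theano code:

    #include <algorithm>

    // Illustrative ceilings; stand-ins for the NUM_VECTOR_OP_* constants.
    #define MAX_BLOCKS 4096
    #define MAX_THREADS_PER_BLOCK 256

    // Grid-stride over rows, block-stride over columns: every element is
    // visited no matter how few blocks/threads were actually launched.
    __global__ void k_sketch(float *data, int rows, int cols)
    {
        for (int i = blockIdx.x; i < rows; i += gridDim.x)
            for (int j = threadIdx.x; j < cols; j += blockDim.x)
                data[i * cols + j] += 1.0f;
    }

    void launch_sketch(float *data, unsigned int rows, unsigned int cols)
    {
        if (rows == 0 || cols == 0)
            return;  // nothing to do; a zero-sized launch would be invalid
        // Launch the minimum of the threshold and what is required: never
        // more blocks than rows, never more threads per block than columns.
        unsigned int n_blocks  = std::min(rows, (unsigned int)MAX_BLOCKS);
        unsigned int n_threads = std::min(cols, (unsigned int)MAX_THREADS_PER_BLOCK);
        k_sketch<<<n_blocks, n_threads>>>(data, (int)rows, (int)cols);
    }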
@@ -2446,9 +2446,12 @@ class GpuAdvancedIncSubtensor1(tensor.AdvancedIncSubtensor1, GpuOp):
         return (1,)
 
     def c_code(self, node, name, inputs, outputs, sub):
+        active_device_no = theano.sandbox.cuda.active_device_number()
+        compute_capability = theano.sandbox.cuda.device_properties(active_device_no)['major']
         if (self.set_instead_of_inc) or \
-           (node.inputs[0].ndim != node.inputs[1].ndim):
+           (node.inputs[0].ndim != node.inputs[1].ndim) or \
+           (compute_capability < 2):
             raise NotImplementedError("This case does not have C code yet.")
         x = inputs[0]
         y = inputs[1]
@@ -2473,6 +2476,81 @@ class GpuAdvancedIncSubtensor1(tensor.AdvancedIncSubtensor1, GpuOp):
         }
         """ %locals()
+
+    def c_support_code_apply(self, node, nodename):
+        return """
+        __global__ void k_vector_add_fast(int numRowsX,
+                                          int numColsX,
+                                          int stridesX0,
+                                          int stridesX1,
+                                          float *X,
+                                          int numRowsY,
+                                          int numColsY,
+                                          int stridesY0,
+                                          int stridesY1,
+                                          float *Y,
+                                          long *d_indices_arr,
+                                          int num)
+        {
+            // Grid-stride loop over rows and block-stride loop over columns,
+            // so the clamped launch configuration still covers all the work.
+            for (int i = blockIdx.x; i < num; i += gridDim.x)
+            {
+                // Row of X selected by the i-th index; row i of Y is added to it.
+                int x_row = d_indices_arr[i];
+                int y_row = i;
+                for (int j = threadIdx.x; j < numColsX; j += blockDim.x)
+                {
+                    // atomicAdd because the same x_row may occur several times
+                    // in d_indices_arr; a plain += would race across blocks.
+                    atomicAdd(&X[(x_row * stridesX0) + (j * stridesX1)],
+                              Y[(y_row * stridesY0) + (j * stridesY1)]);
+                }
+            }
+            return;
+        }
+
+        void CudaNdarray_vector_add_fast(CudaNdarray* py_self,
+                                         CudaNdarray* py_other,
+                                         PyArrayObject *indices_arr)
+        {
+            const int *shapeX = CudaNdarray_HOST_DIMS(py_self);
+            const int *shapeY = CudaNdarray_HOST_DIMS(py_other);
+            const int *strX = CudaNdarray_HOST_STRIDES(py_self);
+            const int *strY = CudaNdarray_HOST_STRIDES(py_other);
+            unsigned int size = (unsigned int)PyArray_SIZE(indices_arr);
+            unsigned int numcolsX = shapeX[1];
+
+            // Launch the minimum of the threshold and what is required:
+            // no more threads than columns, no more blocks than indices.
+            unsigned int num_threads_per_block = std::min(numcolsX,
+                (unsigned int)NUM_VECTOR_OP_THREADS_PER_BLOCK);
+            unsigned int num_blocks = std::min(size,
+                (unsigned int)NUM_VECTOR_OP_BLOCKS);
+            dim3 n_blocks(num_blocks);
+            dim3 n_threads(num_threads_per_block);
+
+            // PyArray_GETCONTIGUOUS returns a new reference to a C-contiguous
+            // array (a copy if needed), so it must be released after the copy.
+            long *d_indices_arr = NULL;
+            PyArrayObject *cpu_indices_arr = PyArray_GETCONTIGUOUS(indices_arr);
+            assert(cpu_indices_arr);
+            d_indices_arr = (long *)device_malloc(PyArray_NBYTES(cpu_indices_arr));
+            assert(d_indices_arr);
+            cudaError_t err = cudaMemcpy(d_indices_arr,
+                                         PyArray_DATA(cpu_indices_arr),
+                                         PyArray_NBYTES(cpu_indices_arr),
+                                         cudaMemcpyHostToDevice);
+            assert(err == cudaSuccess);
+            Py_DECREF(cpu_indices_arr);
+
+            k_vector_add_fast<<<n_blocks, n_threads>>>(shapeX[0],
+                                                       shapeX[1],
+                                                       strX[0],
+                                                       strX[1],
+                                                       CudaNdarray_DEV_DATA(py_self),
+                                                       shapeY[0],
+                                                       shapeY[1],
+                                                       strY[0],
+                                                       strY[1],
+                                                       CudaNdarray_DEV_DATA(py_other),
+                                                       d_indices_arr,
+                                                       PyArray_SIZE(indices_arr));
+            device_free(d_indices_arr);
+            return;
+        }
+        """ %locals()
 
 class GpuIncSubtensor(tensor.IncSubtensor, GpuOp):
     """
     Implement IncSubtensor on the gpu.
...
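For reference, a usage sketch (not part of this commit) of a graph that reaches GpuAdvancedIncSubtensor1, assuming a CUDA-enabled Theano installation with float32 data:

    import numpy
    import theano
    import theano.tensor as T

    # float32 matrices and an int64 index vector; on the GPU backend the
    # inc_subtensor below is optimized to GpuAdvancedIncSubtensor1, whose
    # C implementation this commit supplies.
    x = T.fmatrix('x')
    y = T.fmatrix('y')
    idx = T.lvector('idx')
    f = theano.function([x, y, idx], T.inc_subtensor(x[idx], y))

    xv = numpy.zeros((5, 3), dtype='float32')
    yv = numpy.ones((2, 3), dtype='float32')
    # The repeated index 1 means row 1 of x is incremented twice -- exactly
    # the concurrent-update case the kernel's atomicAdd handles.
    print(f(xv, yv, numpy.array([1, 1], dtype='int64')))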
@@ -1326,32 +1326,6 @@ CudaNdarray_add(PyObject* py_self, PyObject * py_other)
     return (PyObject *) rval;
 }
-
-__global__ void k_vector_add_fast(int numRowsX,
-                                  int numColsX,
-                                  int stridesX0,
-                                  int stridesX1,
-                                  float *X,
-                                  int numRowsY,
-                                  int numColsY,
-                                  int stridesY0,
-                                  int stridesY1,
-                                  float *Y,
-                                  long *d_indices_arr,
-                                  int num)
-{
-    for (int i = (blockIdx.x); i < num; i += gridDim.x)
-    {
-        for (int j = (threadIdx.x); j < numColsX; j += blockDim.x)
-        {
-            int x_row = d_indices_arr[i];
-            int y_row = i;
-            atomicAdd(&X[(x_row * stridesX0) + (j * stridesX1)], Y[(y_row * stridesY0) + (j * stridesY1)]);
-        }
-    }
-    return;
-}
 template <int operator_num>
 __global__ void k_ielem_3(const int d0, const int d1, const int d2,
                           float* a, const int sA0, const int sA1, const int sA2,
@@ -1802,56 +1776,6 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t
     return 0;
 }
-
-void CudaNdarray_vector_add_fast(CudaNdarray* py_self, CudaNdarray* py_other, PyArrayObject *indices_arr)
-{
-    const int *shapeX = CudaNdarray_HOST_DIMS(py_self);
-    const int *shapeY = CudaNdarray_HOST_DIMS(py_other);
-    const int *strX = CudaNdarray_HOST_STRIDES(py_self);
-    const int *strY = CudaNdarray_HOST_STRIDES(py_other);
-    unsigned int size = (unsigned int)PyArray_SIZE(indices_arr);
-    unsigned int numcolsX = shapeX[1];
-    unsigned int num_threads_per_block = std::min(numcolsX, (unsigned int)NUM_VECTOR_OP_THREADS_PER_BLOCK);
-    unsigned int num_blocks = std::min(size, (unsigned int)NUM_VECTOR_OP_BLOCKS);
-    dim3 n_blocks(num_blocks);
-    dim3 n_threads(num_threads_per_block);
-    long *d_indices_arr = NULL;
-    long *cpu_indices_arr = (long*)malloc(sizeof(long) * PyArray_SIZE(indices_arr));
-    assert(cpu_indices_arr);
-    d_indices_arr = (long *)device_malloc(sizeof(long) * PyArray_SIZE(indices_arr));
-    assert(d_indices_arr);
-    for (int j = 0; j < size; j++)
-    {
-        long *el = (long*)PyArray_GETPTR1(indices_arr, j);
-        cpu_indices_arr[j] = *el;
-    }
-    cudaError_t err = cudaMemcpy(d_indices_arr,
-                                 cpu_indices_arr,
-                                 sizeof(long) * size,
-                                 cudaMemcpyHostToDevice);
-    k_vector_add_fast<<<n_blocks, n_threads>>>(shapeX[0],
-                                               shapeX[1],
-                                               strX[0],
-                                               strX[1],
-                                               CudaNdarray_DEV_DATA(py_self),
-                                               shapeY[0],
-                                               shapeY[1],
-                                               strY[0],
-                                               strY[1],
-                                               CudaNdarray_DEV_DATA(py_other),
-                                               d_indices_arr,
-                                               PyArray_SIZE(indices_arr));
-    device_free(d_indices_arr);
-    return;
-}
 /*
  * We need this inplace Add to support IncSubTensor
...