提交 b77115c3，作者：Frédéric Bastien

Merge pull request #3790 from hidasib/gpu_set_subtensor_2d

set_subtensor GPU realizations
...@@ -2710,7 +2710,7 @@ class GpuAdvancedSubtensor1(tensor.AdvancedSubtensor1, GpuOp): ...@@ -2710,7 +2710,7 @@ class GpuAdvancedSubtensor1(tensor.AdvancedSubtensor1, GpuOp):
raise TypeError('cannot index into a scalar') raise TypeError('cannot index into a scalar')
# c code suppose it is int64 # c code suppose it is int64
if x.ndim in [2, 3] and ilist_.dtype in [ if x.ndim in [1, 2, 3] and ilist_.dtype in [
'int8', 'int16', 'int32', 'uint8', 'uint16', 'uint32']: 'int8', 'int16', 'int32', 'uint8', 'uint16', 'uint32']:
ilist_ = tensor.cast(ilist_, 'int64') ilist_ = tensor.cast(ilist_, 'int64')
...@@ -2776,7 +2776,7 @@ class GpuAdvancedSubtensor1(tensor.AdvancedSubtensor1, GpuOp): ...@@ -2776,7 +2776,7 @@ class GpuAdvancedSubtensor1(tensor.AdvancedSubtensor1, GpuOp):
x, idx = inputs x, idx = inputs
out, = outputs out, = outputs
fail = sub['fail'] fail = sub['fail']
if node.inputs[0].ndim not in [2, 3]: if node.inputs[0].ndim not in [1, 2, 3]:
raise NotImplementedError("This case does not have C code yet.") raise NotImplementedError("This case does not have C code yet.")
if node.inputs[1].dtype != 'int64': if node.inputs[1].dtype != 'int64':
raise Exception("Index should have dtype int64. Check this node make_node().") raise Exception("Index should have dtype int64. Check this node make_node().")
...@@ -2888,11 +2888,10 @@ class GpuAdvancedIncSubtensor1(tensor.AdvancedIncSubtensor1, GpuOp): ...@@ -2888,11 +2888,10 @@ class GpuAdvancedIncSubtensor1(tensor.AdvancedIncSubtensor1, GpuOp):
out[0] = x out[0] = x
def c_code_cache_version(self): def c_code_cache_version(self):
return (6,) return (7,)
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
if (self.set_instead_of_inc) or \ if (node.inputs[0].ndim != node.inputs[1].ndim):
(node.inputs[0].ndim != node.inputs[1].ndim):
raise NotImplementedError("This case does not have C code yet.") raise NotImplementedError("This case does not have C code yet.")
x = inputs[0] x = inputs[0]
...@@ -2901,6 +2900,7 @@ class GpuAdvancedIncSubtensor1(tensor.AdvancedIncSubtensor1, GpuOp): ...@@ -2901,6 +2900,7 @@ class GpuAdvancedIncSubtensor1(tensor.AdvancedIncSubtensor1, GpuOp):
out = outputs[0] out = outputs[0]
fail = sub['fail'] fail = sub['fail']
inplace = int(self.inplace) inplace = int(self.inplace)
set_instead_of_inc = int(self.set_instead_of_inc)
return """ return """
PyObject *row_x, *row_y; PyObject *row_x, *row_y;
...@@ -2960,8 +2960,11 @@ class GpuAdvancedIncSubtensor1(tensor.AdvancedIncSubtensor1, GpuOp): ...@@ -2960,8 +2960,11 @@ class GpuAdvancedIncSubtensor1(tensor.AdvancedIncSubtensor1, GpuOp):
Py_XDECREF(x_rowind_obj); Py_XDECREF(x_rowind_obj);
%(fail)s; %(fail)s;
} }
if (%(set_instead_of_inc)s) {
ret = CudaNdarray_inplace_elemwise(row_x, row_y, IADD); ret = CudaNdarray_CopyFromCudaNdarray((CudaNdarray *) row_x, (CudaNdarray *) row_y);
} else {
ret = CudaNdarray_inplace_elemwise(row_x, row_y, IADD);
}
if (ret != 0) { if (ret != 0) {
Py_XDECREF(row_y); Py_XDECREF(row_y);
Py_XDECREF(row_x); Py_XDECREF(row_x);
...@@ -3030,13 +3033,12 @@ class GpuAdvancedIncSubtensor1_dev20(GpuAdvancedIncSubtensor1): ...@@ -3030,13 +3033,12 @@ class GpuAdvancedIncSubtensor1_dev20(GpuAdvancedIncSubtensor1):
return Apply(self, [x_, y_, ilist_], [x_.type()]) return Apply(self, [x_, y_, ilist_], [x_.type()])
def c_code_cache_version(self): def c_code_cache_version(self):
return (6,) return (7,)
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
active_device_no = theano.sandbox.cuda.active_device_number() active_device_no = theano.sandbox.cuda.active_device_number()
compute_capability = device_properties(active_device_no)['major'] compute_capability = device_properties(active_device_no)['major']
if ((self.set_instead_of_inc) or if ((node.inputs[0].ndim != node.inputs[1].ndim) or
(node.inputs[0].ndim != node.inputs[1].ndim) or
(node.inputs[0].ndim != 2) or (node.inputs[0].ndim != 2) or
(compute_capability < 2)): (compute_capability < 2)):
raise NotImplementedError("This case does not have C code yet.") raise NotImplementedError("This case does not have C code yet.")
...@@ -3047,6 +3049,7 @@ class GpuAdvancedIncSubtensor1_dev20(GpuAdvancedIncSubtensor1): ...@@ -3047,6 +3049,7 @@ class GpuAdvancedIncSubtensor1_dev20(GpuAdvancedIncSubtensor1):
out = outputs[0] out = outputs[0]
fail = sub['fail'] fail = sub['fail']
inplace = int(self.inplace) inplace = int(self.inplace)
set_instead_of_inc = int(self.set_instead_of_inc)
return """ return """
Py_XDECREF(%(out)s); Py_XDECREF(%(out)s);
if (!%(inplace)s) { if (!%(inplace)s) {
...@@ -3056,7 +3059,7 @@ class GpuAdvancedIncSubtensor1_dev20(GpuAdvancedIncSubtensor1): ...@@ -3056,7 +3059,7 @@ class GpuAdvancedIncSubtensor1_dev20(GpuAdvancedIncSubtensor1):
Py_XINCREF(%(out)s); Py_XINCREF(%(out)s);
} }
if (CudaNdarray_vector_add_fast(%(out)s, %(y)s, %(ind)s) != 0){ if (CudaNdarray_vector_add_or_replace_fast(%(out)s, %(y)s, %(ind)s, %(set_instead_of_inc)s) != 0){
%(fail)s %(fail)s
} }
...@@ -3068,7 +3071,7 @@ class GpuAdvancedIncSubtensor1_dev20(GpuAdvancedIncSubtensor1): ...@@ -3068,7 +3071,7 @@ class GpuAdvancedIncSubtensor1_dev20(GpuAdvancedIncSubtensor1):
def c_support_code_apply(self, node, nodename): def c_support_code_apply(self, node, nodename):
return """ return """
__global__ void k_vector_add_fast(int numRowsX, __global__ void k_vector_add_or_replace_fast(int numRowsX,
int numColsX, int numColsX,
int stridesX0, int stridesX0,
int stridesX1, int stridesX1,
...@@ -3080,6 +3083,7 @@ class GpuAdvancedIncSubtensor1_dev20(GpuAdvancedIncSubtensor1): ...@@ -3080,6 +3083,7 @@ class GpuAdvancedIncSubtensor1_dev20(GpuAdvancedIncSubtensor1):
float *Y , float *Y ,
long *d_indices_arr, long *d_indices_arr,
int num, int num,
const int set_instead_of_inc,
int* err) int* err)
{ {
for (int i = (blockIdx.x); i < num; i += gridDim.x) for (int i = (blockIdx.x); i < num; i += gridDim.x)
...@@ -3091,8 +3095,13 @@ class GpuAdvancedIncSubtensor1_dev20(GpuAdvancedIncSubtensor1): ...@@ -3091,8 +3095,13 @@ class GpuAdvancedIncSubtensor1_dev20(GpuAdvancedIncSubtensor1):
x_row += numRowsX; x_row += numRowsX;
int y_row = i; int y_row = i;
if(x_row < numRowsX && x_row >= 0){ if(x_row < numRowsX && x_row >= 0){
atomicAdd(&X[(x_row * stridesX0) + (j * stridesX1)], if(set_instead_of_inc){
atomicExch(&X[(x_row * stridesX0) + (j * stridesX1)],
Y[(y_row * stridesY0) + (j * stridesY1)]);
} else{
atomicAdd(&X[(x_row * stridesX0) + (j * stridesX1)],
Y[(y_row * stridesY0) + (j * stridesY1)]); Y[(y_row * stridesY0) + (j * stridesY1)]);
}
} else { } else {
*err = 1; *err = 1;
} }
...@@ -3101,8 +3110,9 @@ class GpuAdvancedIncSubtensor1_dev20(GpuAdvancedIncSubtensor1): ...@@ -3101,8 +3110,9 @@ class GpuAdvancedIncSubtensor1_dev20(GpuAdvancedIncSubtensor1):
return; return;
} }
int CudaNdarray_vector_add_fast(CudaNdarray* py_self, int CudaNdarray_vector_add_or_replace_fast(CudaNdarray* py_self,
CudaNdarray* py_other, PyArrayObject *indices_arr) CudaNdarray* py_other, PyArrayObject *indices_arr,
const int set_instead_of_inc)
{ {
if(init_err_var()!= 0) return -1; if(init_err_var()!= 0) return -1;
...@@ -3144,7 +3154,7 @@ class GpuAdvancedIncSubtensor1_dev20(GpuAdvancedIncSubtensor1): ...@@ -3144,7 +3154,7 @@ class GpuAdvancedIncSubtensor1_dev20(GpuAdvancedIncSubtensor1):
return -1; return -1;
} }
k_vector_add_fast<<<n_blocks, n_threads>>>( k_vector_add_or_replace_fast<<<n_blocks, n_threads>>>(
shapeX[0], shapeX[0],
shapeX[1], shapeX[1],
strX[0], strX[0],
...@@ -3157,6 +3167,7 @@ class GpuAdvancedIncSubtensor1_dev20(GpuAdvancedIncSubtensor1): ...@@ -3157,6 +3167,7 @@ class GpuAdvancedIncSubtensor1_dev20(GpuAdvancedIncSubtensor1):
CudaNdarray_DEV_DATA(py_other), CudaNdarray_DEV_DATA(py_other),
d_indices_arr, d_indices_arr,
PyArray_SIZE(indices_arr), PyArray_SIZE(indices_arr),
set_instead_of_inc,
err_var err_var
); );
int index_err = check_err_var(); int index_err = check_err_var();
......
...@@ -1108,6 +1108,41 @@ def test_advinc_subtensor1(): ...@@ -1108,6 +1108,41 @@ def test_advinc_subtensor1():
rep[[0, 2]] += yval rep[[0, 2]] += yval
utt.assert_allclose(rval, rep) utt.assert_allclose(rval, rep)
def test_advset_subtensor1():
    """Check advanced_set_subtensor1 on a vector when compiled for the GPU.

    The compiled graph must contain exactly one GpuAdvancedIncSubtensor1
    node, and its output must match plain NumPy index assignment.
    """
    shape = (10,)
    make_shared = cuda.shared_constructor

    # Host-side reference data: values 1..10 and five distinct target rows.
    x_host = numpy.arange(shape[0], dtype='float32').reshape(shape) + 1
    indices = numpy.array([0, 2, 5, 7, 3], dtype='int32')
    y_host = numpy.ones(len(indices), dtype='float32') * 10

    x_var = make_shared(x_host, name='x')
    y_var = T.tensor(dtype='float32',
                     broadcastable=(False,) * len(shape),
                     name='y')
    graph = T.advanced_set_subtensor1(x_var, y_var, indices)
    fn = theano.function([y_var], graph, mode=mode_with_gpu)

    # Exactly one node of the compiled graph must be the GPU op.
    gpu_nodes = [apply_node for apply_node in fn.maker.fgraph.toposort()
                 if isinstance(apply_node.op, cuda.GpuAdvancedIncSubtensor1)]
    assert len(gpu_nodes) == 1

    result = fn(y_host)

    # Expected value computed with NumPy index assignment.
    expected = x_host.copy()
    expected[indices] = y_host
    utt.assert_allclose(result, expected)
def test_advset_subtensor1_2d():
    """Check advanced_set_subtensor1 on a matrix when compiled for the GPU.

    The compiled graph must contain exactly one node whose op is an
    instance of GpuAdvancedIncSubtensor1 (per the original note, the
    _dev20 variant is used when compute capability >= 2.0), and its output
    must match plain NumPy index assignment.
    """
    shape = (10, 5)
    make_shared = cuda.shared_constructor

    # Host-side reference data: a 10x5 matrix of 1..50 and five target rows.
    x_host = numpy.arange(numpy.prod(shape), dtype='float32').reshape(shape) + 1
    indices = numpy.array([0, 2, 5, 7, 3], dtype='int32')
    y_host = numpy.ones((len(indices), shape[1]), dtype='float32') * 10

    x_var = make_shared(x_host, name='x')
    y_var = T.tensor(dtype='float32',
                     broadcastable=(False,) * len(shape),
                     name='y')
    graph = T.advanced_set_subtensor1(x_var, y_var, indices)
    fn = theano.function([y_var], graph, mode=mode_with_gpu)

    # Exactly one node of the compiled graph must be the GPU op.
    gpu_nodes = [apply_node for apply_node in fn.maker.fgraph.toposort()
                 if isinstance(apply_node.op, cuda.GpuAdvancedIncSubtensor1)]
    assert len(gpu_nodes) == 1

    result = fn(y_host)

    # Expected value computed with NumPy index assignment.
    expected = x_host.copy()
    expected[indices] = y_host
    utt.assert_allclose(result, expected)
def test_inc_subtensor(): def test_inc_subtensor():
shared = cuda.shared_constructor shared = cuda.shared_constructor
...@@ -1341,5 +1376,7 @@ def speed_reduce10(): ...@@ -1341,5 +1376,7 @@ def speed_reduce10():
if __name__ == '__main__': if __name__ == '__main__':
test_many_arg_elemwise() #test_many_arg_elemwise()
test_gpujoin_assert_cndas() #test_gpujoin_assert_cndas()
test_advset_subtensor1()
test_advset_subtensor1_2d()
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论