Commit efa8dbc9 authored by notoraptor

Back to `GpuArray_split()`

+ some other code updates.
Parent: 0f4192fb
@@ -1395,7 +1395,7 @@ class GpuSplit(HideC, Split):
     # we reuse the perform of the CPU op, which is suitable
 
     def c_code_cache_version(self):
-        return (3,)
+        return (1,)
 
     def c_headers(self):
         return ['<numpy_compat.h>', '<gpuarray_helper.h>']
@@ -1403,32 +1403,30 @@ class GpuSplit(HideC, Split):
     def c_header_dirs(self):
         return [pygpu.get_include(), os.path.dirname(__file__)]
 
-    def c_support_code(self):
-        # We want to hide the definition of this method from tensor.Split.
-        return ''
-
     def c_code(self, node, name, inputs, outputs, sub):
+        if self.len_splits == 0:
+            # There are no outputs, then nothing to do.
+            return ''
         x, axis, splits = inputs
         fail = sub['fail']
-        x_typecode = pygpu.gpuarray.dtype_to_typecode(node.inputs[0].dtype)
         splits_dtype = node.inputs[2].type.dtype_specs()[1]
         axis_dtype = node.inputs[1].type.dtype_specs()[1]
         expected_splits_count = self.len_splits
-        codes_for_checking_outputs = [] # filled later.
-        code_for_splitting = [] # filled later.
+        code_for_updating_output = [] # filled later.
         code_if_sync = [] # filled later.
-        full_code_for_checking_outputs = '' # defined later.
-        full_code_for_splitting = '' # defined later.
+        full_code_for_updating_output = '' # defined later.
         full_code_if_sync = '' # defined later.
         main_code = """
-        ga_order x_order = (%(x)s->ga.flags & GA_C_ORDER) ? GA_C_ORDER : GA_F_ORDER;
         int ndim = PyGpuArray_NDIM(%(x)s);
         int axis = (int)(*(%(axis_dtype)s*)PyArray_GETPTR1(%(axis)s, 0));
-        int splits_count = PyArray_SIZE(%(splits)s);
-        size_t len_along_axis, sum_of_splits = 0, current_split_start = 0;
-        %(splits_dtype)s current_split_length = 0;
-        GpuArray view;
+        int splits_count = PyArray_DIM(%(splits)s, 0);
+        size_t len_along_axis, sum_of_splits = 0;
+        %(splits_dtype)s current_split_length;
+        size_t* split_points = NULL;
+        GpuArray* split_views = NULL;
+        GpuArray** split_views_pointers = NULL;
         int i;
 
         /* Check inputs. */
@@ -1460,54 +1458,87 @@ class GpuSplit(HideC, Split):
             %(fail)s
         }
 
-        /* Check outputs. */
-        %(full_code_for_checking_outputs)s
-
-        /* Compute splits. */
-        if (GpuArray_view(&view, &%(x)s->ga) != GA_NO_ERROR) {
-            PyErr_SetString(PyExc_RuntimeError, "GpuSplit: unable to create a view of the input.");
+        /* Compute splits views. */
+        split_points = (size_t*) malloc((splits_count - 1) * sizeof(size_t));
+        if (split_points == NULL) {
+            PyErr_NoMemory();
+            %(fail)s
+        }
+        split_points[0] = (size_t) (* (%(splits_dtype)s*) PyArray_GETPTR1(%(splits)s, 0) );
+        for(i = 1; i < splits_count - 1; ++i) {
+            split_points[i] = split_points[i - 1] + (size_t) (* (%(splits_dtype)s*) PyArray_GETPTR1(%(splits)s, i) );
+        }
+        split_views = (GpuArray*) malloc(splits_count * sizeof(GpuArray));
+        split_views_pointers = (GpuArray**) malloc(splits_count * sizeof(GpuArray*));
+        if (split_views == NULL || split_views_pointers == NULL) {
+            PyErr_NoMemory();
+            free(split_views_pointers);
+            free(split_views);
+            free(split_points);
+            %(fail)s
+        }
+        for (i = 0; i < splits_count; ++i) {
+            split_views_pointers[i] = split_views + i;
+        }
+        if (GpuArray_split(split_views_pointers, &%(x)s->ga, splits_count - 1, split_points, axis) != GA_NO_ERROR) {
+            PyErr_SetString(PyExc_RuntimeError, "GpuSplit: unable to compute split.");
+            for (i = 0; i < splits_count; ++i) {
+                GpuArray_clear(split_views_pointers[i]);
+            }
+            free(split_views_pointers);
+            free(split_views);
+            free(split_points);
             %(fail)s
         }
-        %(full_code_for_splitting)s
+
+        /* Put split views into outputs. */
+        %(full_code_for_updating_output)s
 
         /* Free memory. */
-        GpuArray_clear(&view);
+        for (i = 0; i < splits_count; ++i) {
+            GpuArray_clear(split_views_pointers[i]);
+        }
+        free(split_views_pointers);
+        free(split_views);
+        free(split_points);
 
         /* Code added if synchronization is enabled. */
         %(full_code_if_sync)s
         """
         for split_index, output in enumerate(outputs):
-            # When checking output, we allocate a PyGpuArrayObject with 0 dims
-            # (that is, an object with as few memory as possible), as its GpuArray field
-            # will be cleared and re-used as a view during split operations.
-            codes_for_checking_outputs.append("""
-            if (theano_prep_output(&%(output)s, 0, NULL, %(x_typecode)s, x_order, %(x)s->context) != 0) {
-                PyErr_SetString(PyExc_RuntimeError, "GpuSplit: unable to prepare an output.");
-                %(fail)s
-            }
-            """ % locals())
-            code_for_splitting.append("""
-            current_split_length = * (%(splits_dtype)s*) PyArray_GETPTR1(%(splits)s, %(split_index)s);
-            view.offset = PyGpuArray_STRIDE(%(x)s, axis) * current_split_start;
-            view.dimensions[axis] = current_split_length;
-            GpuArray_fix_flags(&view);
-            GpuArray_clear(&%(output)s->ga);
-            if (GpuArray_view(&%(output)s->ga, &view) != GA_NO_ERROR) {
-                PyErr_SetString(PyExc_RuntimeError, "GpuSplit: unable to transfer a view into an output.");
-                GpuArray_clear(&view);
-                %(fail)s
-            };
-            current_split_start += current_split_length;
-            """ % locals())
+            code_for_updating_output.append("""
+            Py_XDECREF(%(output)s);
+            %(output)s = pygpu_fromgpudata(
+                split_views[%(split_index)s].data,
+                split_views[%(split_index)s].offset,
+                split_views[%(split_index)s].typecode,
+                split_views[%(split_index)s].nd,
+                split_views[%(split_index)s].dimensions,
+                split_views[%(split_index)s].strides,
+                %(x)s->context,
+                1, // output is writable
+                Py_None, Py_None
+            );
+            if (%(output)s == NULL) {
+                PyErr_SetString(PyExc_RuntimeError, "GpuSplit: unable to update an output from a split view.");
+                for (i = 0; i < splits_count; ++i) {
+                    GpuArray_clear(split_views_pointers[i]);
+                }
+                free(split_views_pointers);
+                free(split_views);
+                free(split_points);
+                %(fail)s
+            }
+            """ % locals())
             if config.gpuarray.sync:
                 code_if_sync.append("GpuArray_sync(&%(output)s->ga);" % locals())
-        full_code_for_checking_outputs = '\r\n'.join(codes_for_checking_outputs)
-        full_code_for_splitting = '\r\n'.join(code_for_splitting)
-        full_code_if_sync = '\r\n'.join(code_if_sync)
+        full_code_for_updating_output = '\n'.join(code_for_updating_output)
+        full_code_if_sync = '\n'.join(code_if_sync)
         return main_code % locals()
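
For reference, here is a minimal NumPy sketch (not part of the patch) of how the generated C code above turns the per-output sizes in `splits` into the cumulative split points that `GpuArray_split()` consumes. The split sizes used are hypothetical example values.

```python
import numpy as np

# Hypothetical split sizes, one per output; any sizes summing to the
# length of the split axis would do.
splits = np.asarray([2, 3, 4], dtype='int64')

# Mirrors the C loop: split_points[i] = split_points[i - 1] + splits[i],
# i.e. a cumulative sum of all sizes except the last one.
split_points = np.cumsum(splits[:-1])   # -> array([2, 5])

# GpuArray_split() then receives len(splits) - 1 points and yields
# len(splits) views covering [0:2], [2:5] and [5:9] along the chosen axis.
print(split_points)
```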
@@ -3779,7 +3779,7 @@ class Split(Op):
         return self.make_node(eval_points[0], *inputs[1:]).outputs
 
     def c_code_cache_version(self):
-        return (2,)
+        return (1,)
 
     def c_support_code(self):
         return """
@@ -3787,25 +3787,27 @@ class Split(Op):
         int split_output_shape_is_correct (
             PyArrayObject* output, PyArrayObject* array_to_split, int axis_to_split, npy_intp split_size
         ) {
-            if (PyArray_NDIM(output) == PyArray_NDIM(array_to_split)) {
-                int i;
-                for (i = 0; i < axis_to_split; ++i) {
-                    if (PyArray_DIM(output, i) != PyArray_DIM(array_to_split, i)) {
-                        return 0;
-                    }
-                }
-                for (i = axis_to_split + 1; i < PyArray_NDIM(array_to_split); ++i) {
-                    if (PyArray_DIM(output, i) != PyArray_DIM(array_to_split, i)) {
-                        return 0;
-                    }
-                }
-                return split_size == PyArray_DIM(output, axis_to_split);
-            }
-            return 0;
+            return
+                PyArray_NDIM(output) == PyArray_NDIM(array_to_split)
+                && memcmp(
+                    PyArray_DIMS(output),
+                    PyArray_DIMS(array_to_split),
+                    axis_to_split * sizeof(npy_intp)
+                ) == 0
+                && memcmp(
+                    PyArray_DIMS(output) + axis_to_split + 1,
+                    PyArray_DIMS(array_to_split) + axis_to_split + 1,
+                    (PyArray_NDIM(array_to_split) - axis_to_split - 1) * sizeof(npy_intp)
+                ) == 0
+                && split_size == PyArray_DIM(output, axis_to_split);
         }
         """
 
     def c_code(self, node, name, inputs, outputs, sub):
+        if self.len_splits == 0:
+            # There are no outputs, then nothing to do.
+            return ''
         x, axis, splits = inputs
         fail = sub['fail']
         x_typenum = numpy.dtype(node.inputs[0].dtype).num
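
For context, a small NumPy sketch (not part of the patch) of the property that `split_output_shape_is_correct()` verifies: an output must have the input's shape everywhere except on the split axis, where its length must equal the expected split size. The arrays and the helper name below are illustrative only.

```python
import numpy as np

def output_shape_is_correct(out, x, axis, split_size):
    # Same check as the C helper: identical rank, identical dims outside
    # `axis`, and the expected split size on `axis`.
    expected = list(x.shape)
    expected[axis] = split_size
    return list(out.shape) == expected

x = np.zeros((4, 9, 5))
out = np.zeros((4, 3, 5))
assert output_shape_is_correct(out, x, axis=1, split_size=3)
```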
@@ -3821,7 +3823,7 @@ class Split(Op):
         main_code = """
         int ndim = PyArray_NDIM(%(x)s);
         int axis = (int)(*(%(axis_dtype)s*)PyArray_GETPTR1(%(axis)s, 0));
-        int splits_count = PyArray_SIZE(%(splits)s);
+        int splits_count = PyArray_DIM(%(splits)s, 0);
         npy_intp len_along_axis, sum_of_splits = 0, current_split_length = 0, current_split_start = 0;
         npy_intp* split_dims = NULL;
         PyObject* split_view = NULL;
@@ -3862,6 +3864,11 @@ class Split(Op):
         /* Check outputs. */
 
         split_dims = (npy_intp*) malloc(ndim * sizeof(npy_intp));
+        if (split_dims == NULL) {
+            PyErr_NoMemory();
+            %(fail)s
+        }
         memcpy(split_dims, PyArray_DIMS(%(x)s), ndim * sizeof(npy_intp));
         %(full_code_for_checking_outputs)s
@@ -3916,8 +3923,8 @@ class Split(Op):
             current_split_start += current_split_length;
             """ % locals())
-        full_code_for_checking_outputs = '\r\n'.join(codes_for_checking_outputs)
-        full_code_for_splitting = '\r\n'.join(codes_for_splitting)
+        full_code_for_checking_outputs = '\n'.join(codes_for_checking_outputs)
+        full_code_for_splitting = '\n'.join(codes_for_splitting)
         return main_code % locals()
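
As a rough behavioural reference (again not part of the patch), the result both the CPU and GPU code paths are expected to produce matches `numpy.split` applied at the cumulative split points; the data below is a hypothetical example.

```python
import numpy as np

x = np.arange(2 * 9).reshape(2, 9)
splits = np.asarray([2, 3, 4])

# Split x along axis 1 at the cumulative boundaries 2 and 5.
outputs = np.split(x, np.cumsum(splits)[:-1], axis=1)

assert [o.shape[1] for o in outputs] == [2, 3, 4]
```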