提交 efa8dbc9 作者: notoraptor

Back to `GpuArray_split()`

+ some other code updates.
上级 0f4192fb
......@@ -1395,7 +1395,7 @@ class GpuSplit(HideC, Split):
# we reuse the perform of the CPU op, which is suitable
def c_code_cache_version(self):
# Reports a one-element version tuple (presumably the key used to cache this
# op's compiled C code -- confirm against the Op base class).
# NOTE(review): two consecutive returns appear below, so only the first one
# can ever execute as written. This looks like a diff rendering in which
# `(3,)` is the removed line and `(1,)` is its replacement -- confirm which
# value is intended before relying on either.
return (3,)
return (1,)
def c_headers(self):
# Lists the two header names this op asks for: the NumPy compatibility header
# and the gpuarray helper header. Both are written with angle brackets, i.e.
# they are expected to be found on the include path rather than next to the
# generated source.
return ['<numpy_compat.h>', '<gpuarray_helper.h>']
......@@ -1403,32 +1403,30 @@ class GpuSplit(HideC, Split):
def c_header_dirs(self):
# Returns two directories: the include directory reported by pygpu, and the
# directory that contains this module file (presumably where
# gpuarray_helper.h lives -- confirm).
return [pygpu.get_include(), os.path.dirname(__file__)]
def c_support_code(self):
# GpuSplit subclasses Split, whose c_support_code() returns a C helper
# (split_output_shape_is_correct) written against PyArrayObject. Override it
# with an empty string so that CPU-side helper is not inherited by this op.
return ''
def c_code(self, node, name, inputs, outputs, sub):
if self.len_splits == 0:
# There are no outputs, then nothing to do.
return ''
x, axis, splits = inputs
fail = sub['fail']
x_typecode = pygpu.gpuarray.dtype_to_typecode(node.inputs[0].dtype)
splits_dtype = node.inputs[2].type.dtype_specs()[1]
axis_dtype = node.inputs[1].type.dtype_specs()[1]
expected_splits_count = self.len_splits
codes_for_checking_outputs = [] # filled later.
code_for_splitting = [] # filled later.
code_for_updating_output = [] # filled later.
code_if_sync = [] # filled later.
full_code_for_checking_outputs = '' # defined later.
full_code_for_splitting = '' # defined later.
full_code_for_updating_output = '' # defined later.
full_code_if_sync = '' # defined later.
main_code = """
ga_order x_order = (%(x)s->ga.flags & GA_C_ORDER) ? GA_C_ORDER : GA_F_ORDER;
int ndim = PyGpuArray_NDIM(%(x)s);
int axis = (int)(*(%(axis_dtype)s*)PyArray_GETPTR1(%(axis)s, 0));
int splits_count = PyArray_SIZE(%(splits)s);
size_t len_along_axis, sum_of_splits = 0, current_split_start = 0;
%(splits_dtype)s current_split_length = 0;
GpuArray view;
int splits_count = PyArray_DIM(%(splits)s, 0);
size_t len_along_axis, sum_of_splits = 0;
%(splits_dtype)s current_split_length;
size_t* split_points = NULL;
GpuArray* split_views = NULL;
GpuArray** split_views_pointers = NULL;
int i;
/* Check inputs. */
......@@ -1460,54 +1458,87 @@ class GpuSplit(HideC, Split):
%(fail)s
}
/* Check outputs. */
%(full_code_for_checking_outputs)s
/* Compute splits views. */
/* Compute splits. */
if (GpuArray_view(&view, &%(x)s->ga) != GA_NO_ERROR) {
PyErr_SetString(PyExc_RuntimeError, "GpuSplit: unable to create a view of the input.");
split_points = (size_t*) malloc((splits_count - 1) * sizeof(size_t));
if (split_points == NULL) {
PyErr_NoMemory();
%(fail)s
}
split_points[0] = (size_t) (* (%(splits_dtype)s*) PyArray_GETPTR1(%(splits)s, 0) );
for(i = 1; i < splits_count - 1; ++i) {
split_points[i] = split_points[i - 1] + (size_t) (* (%(splits_dtype)s*) PyArray_GETPTR1(%(splits)s, i) );
}
split_views = (GpuArray*) malloc(splits_count * sizeof(GpuArray));
split_views_pointers = (GpuArray**) malloc(splits_count * sizeof(GpuArray*));
if (split_views == NULL || split_views_pointers == NULL) {
PyErr_NoMemory();
free(split_views_pointers);
free(split_views);
free(split_points);
%(fail)s
}
for (i = 0; i < splits_count; ++i) {
split_views_pointers[i] = split_views + i;
}
if (GpuArray_split(split_views_pointers, &%(x)s->ga, splits_count - 1, split_points, axis) != GA_NO_ERROR) {
PyErr_SetString(PyExc_RuntimeError, "GpuSplit: unable to compute split.");
for (i = 0; i < splits_count; ++i) {
GpuArray_clear(split_views_pointers[i]);
}
free(split_views_pointers);
free(split_views);
free(split_points);
%(fail)s
}
%(full_code_for_splitting)s
/* Put split views into outputs. */
%(full_code_for_updating_output)s
/* Free memory. */
GpuArray_clear(&view);
for (i = 0; i < splits_count; ++i) {
GpuArray_clear(split_views_pointers[i]);
}
free(split_views_pointers);
free(split_views);
free(split_points);
/* Code added if synchronization is enabled. */
%(full_code_if_sync)s
"""
for split_index, output in enumerate(outputs):
# When checking output, we allocate a PyGpuArrayObject with 0 dims
# (that is, an object with as few memory as possible), as its GpuArray field
# will be cleared and re-used as a view during split operations.
codes_for_checking_outputs.append("""
if (theano_prep_output(&%(output)s, 0, NULL, %(x_typecode)s, x_order, %(x)s->context) != 0) {
PyErr_SetString(PyExc_RuntimeError, "GpuSplit: unable to prepare an output.");
%(fail)s
}
""" % locals())
code_for_splitting.append("""
current_split_length = * (%(splits_dtype)s*) PyArray_GETPTR1(%(splits)s, %(split_index)s);
view.offset = PyGpuArray_STRIDE(%(x)s, axis) * current_split_start;
view.dimensions[axis] = current_split_length;
GpuArray_fix_flags(&view);
GpuArray_clear(&%(output)s->ga);
if (GpuArray_view(&%(output)s->ga, &view) != GA_NO_ERROR) {
PyErr_SetString(PyExc_RuntimeError, "GpuSplit: unable to transfer a view into an output.");
GpuArray_clear(&view);
code_for_updating_output.append("""
Py_XDECREF(%(output)s);
%(output)s = pygpu_fromgpudata(
split_views[%(split_index)s].data,
split_views[%(split_index)s].offset,
split_views[%(split_index)s].typecode,
split_views[%(split_index)s].nd,
split_views[%(split_index)s].dimensions,
split_views[%(split_index)s].strides,
%(x)s->context,
1, // output is writable
Py_None, Py_None
);
if (%(output)s == NULL) {
PyErr_SetString(PyExc_RuntimeError, "GpuSplit: unable to update an output from a split view.");
for (i = 0; i < splits_count; ++i) {
GpuArray_clear(split_views_pointers[i]);
}
free(split_views_pointers);
free(split_views);
free(split_points);
%(fail)s
};
current_split_start += current_split_length;
}
""" % locals())
if config.gpuarray.sync:
code_if_sync.append("GpuArray_sync(&%(output)s->ga);" % locals())
full_code_for_checking_outputs = '\r\n'.join(codes_for_checking_outputs)
full_code_for_splitting = '\r\n'.join(code_for_splitting)
full_code_if_sync = '\r\n'.join(code_if_sync)
full_code_for_updating_output = '\n'.join(code_for_updating_output)
full_code_if_sync = '\n'.join(code_if_sync)
return main_code % locals()
......
......@@ -3779,7 +3779,7 @@ class Split(Op):
return self.make_node(eval_points[0], *inputs[1:]).outputs
def c_code_cache_version(self):
# Reports a one-element version tuple (presumably the key used to cache this
# op's compiled C code -- confirm against the Op base class).
# NOTE(review): two consecutive returns appear below, so only the first one
# can ever execute as written. This looks like a diff rendering in which
# `(2,)` is the removed line and `(1,)` is its replacement -- confirm which
# value is intended before relying on either.
return (2,)
return (1,)
def c_support_code(self):
return """
......@@ -3787,25 +3787,27 @@ class Split(Op):
int split_output_shape_is_correct (
PyArrayObject* output, PyArrayObject* array_to_split, int axis_to_split, npy_intp split_size
) {
if (PyArray_NDIM(output) == PyArray_NDIM(array_to_split)) {
int i;
for (i = 0; i < axis_to_split; ++i) {
if (PyArray_DIM(output, i) != PyArray_DIM(array_to_split, i)) {
return 0;
}
}
for (i = axis_to_split + 1; i < PyArray_NDIM(array_to_split); ++i) {
if (PyArray_DIM(output, i) != PyArray_DIM(array_to_split, i)) {
return 0;
}
}
return split_size == PyArray_DIM(output, axis_to_split);
}
return 0;
return
PyArray_NDIM(output) == PyArray_NDIM(array_to_split)
&& memcmp(
PyArray_DIMS(output),
PyArray_DIMS(array_to_split),
axis_to_split * sizeof(npy_intp)
) == 0
&& memcmp(
PyArray_DIMS(output) + axis_to_split + 1,
PyArray_DIMS(array_to_split) + axis_to_split + 1,
(PyArray_NDIM(array_to_split) - axis_to_split - 1) * sizeof(npy_intp)
) == 0
&& split_size == PyArray_DIM(output, axis_to_split);
}
"""
def c_code(self, node, name, inputs, outputs, sub):
if self.len_splits == 0:
# There are no outputs, then nothing to do.
return ''
x, axis, splits = inputs
fail = sub['fail']
x_typenum = numpy.dtype(node.inputs[0].dtype).num
......@@ -3821,7 +3823,7 @@ class Split(Op):
main_code = """
int ndim = PyArray_NDIM(%(x)s);
int axis = (int)(*(%(axis_dtype)s*)PyArray_GETPTR1(%(axis)s, 0));
int splits_count = PyArray_SIZE(%(splits)s);
int splits_count = PyArray_DIM(%(splits)s, 0);
npy_intp len_along_axis, sum_of_splits = 0, current_split_length = 0, current_split_start = 0;
npy_intp* split_dims = NULL;
PyObject* split_view = NULL;
......@@ -3862,6 +3864,11 @@ class Split(Op):
/* Check outputs. */
split_dims = (npy_intp*) malloc(ndim * sizeof(npy_intp));
if (split_dims == NULL) {
PyErr_NoMemory();
%(fail)s
}
memcpy(split_dims, PyArray_DIMS(%(x)s), ndim * sizeof(npy_intp));
%(full_code_for_checking_outputs)s
......@@ -3916,8 +3923,8 @@ class Split(Op):
current_split_start += current_split_length;
""" % locals())
full_code_for_checking_outputs = '\r\n'.join(codes_for_checking_outputs)
full_code_for_splitting = '\r\n'.join(codes_for_splitting)
full_code_for_checking_outputs = '\n'.join(codes_for_checking_outputs)
full_code_for_splitting = '\n'.join(codes_for_splitting)
return main_code % locals()
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论