提交 82d5a37f authored 作者: notoraptor's avatar notoraptor

Put the output variables into a C array and iterate over them with C loops, instead of having Python generate the same C code once per output.

Done in both CPU and GPU code.
上级 efa8dbc9
......@@ -1408,15 +1408,13 @@ class GpuSplit(HideC, Split):
# There are no outputs, then nothing to do.
return ''
# outputs_pointers lists the addresses of the pointers to the outputs.
outputs_pointers = '&' + (', &'.join(outputs))
x, axis, splits = inputs
fail = sub['fail']
splits_dtype = node.inputs[2].type.dtype_specs()[1]
axis_dtype = node.inputs[1].type.dtype_specs()[1]
expected_splits_count = self.len_splits
code_for_updating_output = [] # filled later.
code_if_sync = [] # filled later.
full_code_for_updating_output = '' # defined later.
full_code_if_sync = '' # defined later.
main_code = """
int ndim = PyGpuArray_NDIM(%(x)s);
......@@ -1427,7 +1425,8 @@ class GpuSplit(HideC, Split):
size_t* split_points = NULL;
GpuArray* split_views = NULL;
GpuArray** split_views_pointers = NULL;
int i;
int i, j;
PyGpuArrayObject** outputs[] = {%(outputs_pointers)s};
/* Check inputs. */
......@@ -1493,52 +1492,47 @@ class GpuSplit(HideC, Split):
}
/* Put split views into outputs. */
%(full_code_for_updating_output)s
/* Free memory. */
for (i = 0; i < splits_count; ++i) {
GpuArray_clear(split_views_pointers[i]);
}
free(split_views_pointers);
free(split_views);
free(split_points);
/* Code added if synchronization is enabled. */
%(full_code_if_sync)s
"""
for split_index, output in enumerate(outputs):
code_for_updating_output.append("""
Py_XDECREF(%(output)s);
%(output)s = pygpu_fromgpudata(
split_views[%(split_index)s].data,
split_views[%(split_index)s].offset,
split_views[%(split_index)s].typecode,
split_views[%(split_index)s].nd,
split_views[%(split_index)s].dimensions,
split_views[%(split_index)s].strides,
PyGpuArrayObject** output = outputs[i];
Py_XDECREF(*output);
*output = pygpu_fromgpudata(
split_views[i].data,
split_views[i].offset,
split_views[i].typecode,
split_views[i].nd,
split_views[i].dimensions,
split_views[i].strides,
%(x)s->context,
1, // output is writable
Py_None, Py_None
);
if (%(output)s == NULL) {
if (*output == NULL) {
PyErr_SetString(PyExc_RuntimeError, "GpuSplit: unable to update an output from a split view.");
for (i = 0; i < splits_count; ++i) {
GpuArray_clear(split_views_pointers[i]);
for (j = 0; j < splits_count; ++j) {
GpuArray_clear(split_views_pointers[j]);
}
free(split_views_pointers);
free(split_views);
free(split_points);
%(fail)s
}
""" % locals())
}
if config.gpuarray.sync:
code_if_sync.append("GpuArray_sync(&%(output)s->ga);" % locals())
/* Free memory. */
for (i = 0; i < splits_count; ++i) {
GpuArray_clear(split_views_pointers[i]);
}
free(split_views_pointers);
free(split_views);
free(split_points);
"""
full_code_for_updating_output = '\n'.join(code_for_updating_output)
full_code_if_sync = '\n'.join(code_if_sync)
if config.gpuarray.sync:
main_code += """
for (i = 0; i < splits_count; ++i) {
GpuArray_sync(&((*outputs[i])->ga));
}
"""
return main_code % locals()
......
......@@ -3808,6 +3808,8 @@ class Split(Op):
# There are no outputs, then nothing to do.
return ''
# outputs_pointers lists the addresses of the pointers to the outputs.
outputs_pointers = '&' + (', &'.join(outputs))
x, axis, splits = inputs
fail = sub['fail']
x_typenum = numpy.dtype(node.inputs[0].dtype).num
......@@ -3815,12 +3817,8 @@ class Split(Op):
axis_dtype = node.inputs[1].type.dtype_specs()[1]
splits_dtype = node.inputs[2].type.dtype_specs()[1]
expected_splits_count = self.len_splits
codes_for_checking_outputs = [] # filled later.
codes_for_splitting = [] # filled later.
full_code_for_checking_outputs = '' # defined later.
full_code_for_splitting = '' # defined later.
main_code = """
return """
int ndim = PyArray_NDIM(%(x)s);
int axis = (int)(*(%(axis_dtype)s*)PyArray_GETPTR1(%(axis)s, 0));
int splits_count = PyArray_DIM(%(splits)s, 0);
......@@ -3829,6 +3827,7 @@ class Split(Op):
PyObject* split_view = NULL;
npy_intp data_offset;
int i;
PyArrayObject** outputs[] = {%(outputs_pointers)s};
/* Check inputs. */
......@@ -3871,33 +3870,25 @@ class Split(Op):
memcpy(split_dims, PyArray_DIMS(%(x)s), ndim * sizeof(npy_intp));
%(full_code_for_checking_outputs)s
/* Compute split. */
%(full_code_for_splitting)s
free(split_dims);
"""
for split_index, output in enumerate(outputs):
codes_for_checking_outputs.append("""
current_split_length = (npy_intp) (* (%(splits_dtype)s*) PyArray_GETPTR1(%(splits)s, %(split_index)s));
if (%(output)s == NULL || !split_output_shape_is_correct(%(output)s, %(x)s, axis, current_split_length)) {
Py_XDECREF(%(output)s);
for (i = 0; i < splits_count; ++i) {
PyArrayObject** output = outputs[i];
current_split_length = (npy_intp) (* (%(splits_dtype)s*) PyArray_GETPTR1(%(splits)s, i));
if (*output == NULL || !split_output_shape_is_correct(*output, %(x)s, axis, current_split_length)) {
Py_XDECREF(*output);
split_dims[axis] = current_split_length;
%(output)s = (PyArrayObject*)PyArray_EMPTY(ndim, split_dims, %(x_typenum)s, PyArray_IS_F_CONTIGUOUS(%(x)s));
if (%(output)s == NULL) {
*output = (PyArrayObject*)PyArray_EMPTY(ndim, split_dims, %(x_typenum)s, PyArray_IS_F_CONTIGUOUS(%(x)s));
if (outputs == NULL) {
PyErr_SetString(PyExc_RuntimeError, "Split: unable to allocate an output.");
free(split_dims);
%(fail)s
}
}
""" % locals())
}
codes_for_splitting.append("""
current_split_length = (npy_intp) (* (%(splits_dtype)s*) PyArray_GETPTR1(%(splits)s, %(split_index)s));
/* Compute split. */
for (i = 0; i < splits_count; ++i) {
current_split_length = (npy_intp) (* (%(splits_dtype)s*) PyArray_GETPTR1(%(splits)s, i));
data_offset = PyArray_STRIDE(%(x)s, axis) * current_split_start;
split_dims[axis] = current_split_length;
split_view = PyArray_New(&PyArray_Type,
......@@ -3913,7 +3904,7 @@ class Split(Op):
free(split_dims);
%(fail)s
}
if (PyArray_CopyInto(%(output)s, (PyArrayObject*)split_view) != 0) {
if (PyArray_CopyInto(*outputs[i], (PyArrayObject*)split_view) != 0) {
PyErr_SetString(PyExc_RuntimeError, "Split: unable to copy a split view into the output.");
Py_XDECREF(split_view);
free(split_dims);
......@@ -3921,12 +3912,10 @@ class Split(Op):
}
Py_XDECREF(split_view);
current_split_start += current_split_length;
""" % locals())
full_code_for_checking_outputs = '\n'.join(codes_for_checking_outputs)
full_code_for_splitting = '\n'.join(codes_for_splitting)
}
return main_code % locals()
free(split_dims);
""" % locals()
def addbroadcast(x, *axes):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论