提交 ae3abc3d authored 作者: notoraptor's avatar notoraptor

Add C code for CPU for tensor.Split.

Method: for each split, a view is created and then copied into an output. To come up: * Add C code for GPU. * Check if `self.len_splits` can be defined as a param op.
上级 4e29b2f7
......@@ -3778,6 +3778,149 @@ class Split(Op):
return [None for i in self.len_splits]
return self.make_node(eval_points[0], *inputs[1:]).outputs
def c_code_cache_version(self):
return (1,)
def c_support_code(self):
return """
/* Return 1 if output has the correct shape. */
int split_output_shape_is_correct (
PyArrayObject* output, PyArrayObject* array_to_split, int axis_to_split, npy_intp split_size
) {
if (PyArray_NDIM(output) == PyArray_NDIM(array_to_split)) {
int i;
for (i = 0; i < axis_to_split; ++i) {
if (PyArray_DIM(output, i) != PyArray_DIM(array_to_split, i)) {
return 0;
}
}
for (i = axis_to_split + 1; i < PyArray_NDIM(array_to_split); ++i) {
if (PyArray_DIM(output, i) != PyArray_DIM(array_to_split, i)) {
return 0;
}
}
return split_size == PyArray_DIM(output, axis_to_split);
}
return 0;
}
"""
def c_code(self, node, name, inputs, outputs, sub):
x, axis, splits = inputs
fail = sub['fail']
x_typenum = numpy.dtype(node.inputs[0].dtype).num
x_itemsize = numpy.dtype(node.inputs[0].dtype).itemsize
axis_dtype = node.inputs[1].type.dtype_specs()[1]
splits_dtype = node.inputs[2].type.dtype_specs()[1]
expected_splits_count = self.len_splits
codes_for_checking_outputs = [] # filled later.
codes_for_splitting = [] # filled later.
full_code_for_checking_outputs = '' # defined later.
full_code_for_splitting = '' # defined later.
main_code = """
int ndim = PyArray_NDIM(%(x)s);
int axis = (int)(*(%(axis_dtype)s*)PyArray_GETPTR1(%(axis)s, 0));
int splits_count = PyArray_SIZE(%(splits)s);
npy_intp len_along_axis, sum_of_splits = 0, current_split_length = 0, current_split_start = 0;
npy_intp* split_dims = NULL;
PyObject* split_view = NULL;
npy_intp data_offset;
int i;
/* Check inputs. */
if (splits_count != %(expected_splits_count)s) {
PyErr_Format(PyExc_ValueError,
"Split: splits count (%%d) != expected count (%%d).", splits_count, %(expected_splits_count)s);
%(fail)s
}
if (axis < 0) {
axis += ndim;
}
if (axis < 0 || axis >= ndim) {
PyErr_Format(PyExc_IndexError, "Split: invalid axis %%d for a %%d-D array.", axis, ndim);
%(fail)s
}
len_along_axis = PyArray_DIM(%(x)s, axis);
for (i = 0; i < splits_count; ++i) {
current_split_length = (npy_intp)(*(%(splits_dtype)s*)PyArray_GETPTR1(%(splits)s, i));
if (current_split_length < 0) {
PyErr_Format(PyExc_ValueError,
"Split: you try to take a negative number (%%ld) of elements.", current_split_length);
%(fail)s
}
sum_of_splits += current_split_length;
}
if (sum_of_splits != len_along_axis) {
PyErr_Format(PyExc_ValueError, "Split: the splits sums to %%ld, expected %%ld.", sum_of_splits, len_along_axis);
%(fail)s
}
/* Check outputs. */
split_dims = (npy_intp*) malloc(ndim * sizeof(npy_intp));
memcpy(split_dims, PyArray_DIMS(%(x)s), ndim * sizeof(npy_intp));
%(full_code_for_checking_outputs)s
/* Compute split. */
%(full_code_for_splitting)s
free(split_dims);
"""
for split_index, output in enumerate(outputs):
codes_for_checking_outputs.append("""
current_split_length = (npy_intp) (* (%(splits_dtype)s*) PyArray_GETPTR1(%(splits)s, %(split_index)s) );
if (%(output)s == NULL || !split_output_shape_is_correct(%(output)s, %(x)s, axis, current_split_length)) {
Py_XDECREF(%(output)s);
split_dims[axis] = current_split_length;
%(output)s = (PyArrayObject*)PyArray_EMPTY(ndim, split_dims, %(x_typenum)s, PyArray_IS_F_CONTIGUOUS(%(x)s));
if (%(output)s == NULL) {
PyErr_Format(PyExc_RuntimeError, "Split: unable to allocate an output.");
free(split_dims);
%(fail)s
}
}
""" % locals())
codes_for_splitting.append("""
current_split_length = (npy_intp) (* (%(splits_dtype)s*) PyArray_GETPTR1(%(splits)s, %(split_index)s));
data_offset = PyArray_STRIDE(%(x)s, axis) * current_split_start;
split_dims[axis] = current_split_length;
split_view = PyArray_New(&PyArray_Type,
ndim, split_dims,
%(x_typenum)s,
PyArray_STRIDES(%(x)s),
PyArray_DATA(%(x)s) + data_offset,
%(x_itemsize)s,
PyArray_FLAGS(%(x)s),
NULL);
if (split_view == NULL) {
PyErr_Format(PyExc_RuntimeError, "Split: unable to create a view for a split.");
free(split_dims);
%(fail)s
}
if (PyArray_CopyInto(%(output)s, (PyArrayObject*)split_view) != 0) {
PyErr_Format(PyExc_RuntimeError, "Split: unable to copy a split view into the output.");
Py_XDECREF(split_view);
free(split_dims);
%(fail)s
}
Py_XDECREF(split_view);
current_split_start += current_split_length;
""" % locals())
full_code_for_checking_outputs = '\r\n'.join(codes_for_checking_outputs)
full_code_for_splitting = '\r\n'.join(codes_for_splitting)
return main_code % locals()
def addbroadcast(x, *axes):
"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论