提交 4e59f21a 作者: ricardoV94 提交者: Ricardo Vieira

Make Split C-impl return a view

上级 7f6676d6
...@@ -2171,8 +2171,6 @@ class Split(COp): ...@@ -2171,8 +2171,6 @@ class Split(COp):
array([3, 4]) array([3, 4])
>>> c >>> c
array([5]) array([5])
TODO: Don't make a copy in C impl
""" """
len_splits = None len_splits = None
...@@ -2285,75 +2283,63 @@ class Split(COp): ...@@ -2285,75 +2283,63 @@ class Split(COp):
return self.make_node(eval_points[0], *inputs[1:]).outputs return self.make_node(eval_points[0], *inputs[1:]).outputs
def c_code_cache_version(self): def c_code_cache_version(self):
return (2,) return (3,)
def c_support_code(self, **kwargs):
return """
/* Return 1 if output has the correct shape. */
int split_output_shape_is_correct (
PyArrayObject* output, PyArrayObject* array_to_split, int axis_to_split, npy_intp split_size
) {
return
PyArray_NDIM(output) == PyArray_NDIM(array_to_split)
&& memcmp(
PyArray_DIMS(output),
PyArray_DIMS(array_to_split),
axis_to_split * sizeof(npy_intp)
) == 0
&& memcmp(
PyArray_DIMS(output) + axis_to_split + 1,
PyArray_DIMS(array_to_split) + axis_to_split + 1,
(PyArray_NDIM(array_to_split) - axis_to_split - 1) * sizeof(npy_intp)
) == 0
&& split_size == PyArray_DIM(output, axis_to_split);
}
"""
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
if self.len_splits == 0: if self.len_splits == 0:
# There are no outputs, then nothing to do. # This would be a view Op, anyway shouldn't be triggered
return "" raise NotImplementedError()
# outputs_pointers lists the addresses of the pointers to the outputs. # outputs_pointers lists the addresses of the pointers to the outputs.
outputs_pointers = "&" + (", &".join(outputs)) outputs_pointers = "&" + (", &".join(outputs))
x, axis, splits = inputs x, axis, splits = inputs
fail = sub["fail"] fail = sub["fail"]
x_typenum = np.dtype(node.inputs[0].dtype).num
x_itemsize = np.dtype(node.inputs[0].dtype).itemsize
axis_dtype = node.inputs[1].type.dtype_specs()[1]
splits_dtype = node.inputs[2].type.dtype_specs()[1] splits_dtype = node.inputs[2].type.dtype_specs()[1]
expected_splits_count = self.len_splits len_splits = self.len_splits
ndim = node.inputs[0].type.ndim
# Most times axis is constant, inline it
# This is safe to do because the hash of the c_code includes the constant signature
if isinstance(node.inputs[1], Constant):
static_axis = int(node.inputs[1].data)
static_axis = normalize_axis_index(static_axis, ndim)
axis_def = f"{static_axis};"
axis_check = ""
else:
axis_dtype = node.inputs[1].type.dtype_specs()[1]
axis_def = f"(({axis_dtype} *)PyArray_DATA({axis}))[0];"
axis_check = f"""
if (axis < 0){{
axis = ndim + axis;
}}
if (axis >= ndim || axis < 0) {{
PyErr_SetString(PyExc_ValueError, "Split axis is out of bounds");
{fail}
}}
"""
return f""" return f"""
int ndim = PyArray_NDIM({x}); int ndim = {ndim};
int axis = (int)(*({axis_dtype}*)PyArray_GETPTR1({axis}, 0)); int axis = {axis_def}
int splits_count = PyArray_DIM({splits}, 0); int splits_count = PyArray_DIM({splits}, 0);
npy_intp len_along_axis, sum_of_splits = 0, current_split_length = 0, current_split_start = 0; npy_intp sum_of_splits = 0, current_split_start = 0;
npy_intp* split_dims = NULL;
PyObject* split_view = NULL;
npy_intp data_offset;
int i;
PyArrayObject** outputs[] = {{{outputs_pointers}}}; PyArrayObject** outputs[] = {{{outputs_pointers}}};
npy_intp split_dims[ndim];
/* Check inputs. */ /* Check inputs. */
if (PyArray_NDIM({x}) != ndim) {{
if (splits_count != {expected_splits_count}) {{ PyErr_Format(PyExc_ValueError, "Input to Split does not have expected ndim");
PyErr_Format(PyExc_ValueError,
"Split: splits count (%d) != expected count (%d).", splits_count, {expected_splits_count});
{fail} {fail}
}} }}
if (splits_count != {len_splits}) {{
if (axis < 0) {{ PyErr_Format(PyExc_ValueError, "Split: splits count (%d) != expected count (%d).", splits_count, {len_splits});
axis += ndim;
}}
if (axis < 0 || axis >= ndim) {{
PyErr_Format(PyExc_IndexError, "Split: invalid axis %d for a %d-D array.", axis, ndim);
{fail} {fail}
}} }}
len_along_axis = PyArray_DIM({x}, axis);
for (i = 0; i < splits_count; ++i) {{ {axis_check};
current_split_length = (npy_intp)(*({splits_dtype}*)PyArray_GETPTR1({splits}, i));
for (int i = 0; i < splits_count; ++i) {{
int current_split_length = (npy_intp)(*({splits_dtype}*)PyArray_GETPTR1({splits}, i));
if (current_split_length < 0) {{ if (current_split_length < 0) {{
PyErr_Format(PyExc_ValueError, PyErr_Format(PyExc_ValueError,
"Split: you try to take a negative number (%ld) of elements.", current_split_length); "Split: you try to take a negative number (%ld) of elements.", current_split_length);
...@@ -2361,66 +2347,43 @@ class Split(COp): ...@@ -2361,66 +2347,43 @@ class Split(COp):
}} }}
sum_of_splits += current_split_length; sum_of_splits += current_split_length;
}} }}
if (sum_of_splits != len_along_axis) {{ if (sum_of_splits != PyArray_DIM({x}, axis)) {{
PyErr_Format(PyExc_ValueError, "Split: the splits sums to %ld, expected %ld.", sum_of_splits, len_along_axis); PyErr_Format(PyExc_ValueError, "Split: the splits sums to %ld, expected %ld.", sum_of_splits, PyArray_DIM({x}, axis));
{fail}
}}
/* Check outputs. */
split_dims = (npy_intp*) malloc(ndim * sizeof(npy_intp));
if (split_dims == NULL) {{
PyErr_NoMemory();
{fail} {fail}
}} }}
/* Compute split. */
memcpy(split_dims, PyArray_DIMS({x}), ndim * sizeof(npy_intp)); memcpy(split_dims, PyArray_DIMS({x}), ndim * sizeof(npy_intp));
for (i = 0; i < splits_count; ++i) {{ for (int i = 0; i < splits_count; ++i) {{
PyArrayObject** output = outputs[i]; Py_XDECREF(*outputs[i]);
current_split_length = (npy_intp) (* ({splits_dtype}*) PyArray_GETPTR1({splits}, i));
if (*output == NULL || !split_output_shape_is_correct(*output, {x}, axis, current_split_length)) {{
Py_XDECREF(*output);
split_dims[axis] = current_split_length;
*output = (PyArrayObject*)PyArray_EMPTY(ndim, split_dims, {x_typenum}, PyArray_IS_F_CONTIGUOUS({x}));
if (outputs == NULL) {{
PyErr_SetString(PyExc_RuntimeError, "Split: unable to allocate an output.");
free(split_dims);
{fail}
}}
}}
}}
/* Compute split. */ // Create view of input
npy_intp data_offset = PyArray_STRIDE({x}, axis) * current_split_start;
for (i = 0; i < splits_count; ++i) {{ int current_split_length = (npy_intp)(*({splits_dtype}*)PyArray_GETPTR1({splits}, i));
current_split_length = (npy_intp) (* ({splits_dtype}*) PyArray_GETPTR1({splits}, i));
data_offset = PyArray_STRIDE({x}, axis) * current_split_start;
split_dims[axis] = current_split_length; split_dims[axis] = current_split_length;
split_view = PyArray_New(&PyArray_Type, PyArray_Descr *descr = PyArray_DESCR({x});
Py_INCREF(descr);
*outputs[i] = (PyArrayObject*)PyArray_NewFromDescr(&PyArray_Type,
descr, // PyArray_NewFromDescr steals this reference
ndim, split_dims, ndim, split_dims,
{x_typenum},
PyArray_STRIDES({x}), PyArray_STRIDES({x}),
PyArray_BYTES({x}) + data_offset, PyArray_BYTES({x}) + data_offset,
{x_itemsize}, PyArray_FLAGS({x}) & ~NPY_ARRAY_OWNDATA,
PyArray_FLAGS({x}),
NULL); NULL);
if (split_view == NULL) {{
if (*outputs[i] == NULL) {{
PyErr_SetString(PyExc_RuntimeError, "Split: unable to create a view for a split."); PyErr_SetString(PyExc_RuntimeError, "Split: unable to create a view for a split.");
free(split_dims);
{fail}
}}
if (PyArray_CopyInto(*outputs[i], (PyArrayObject*)split_view) != 0) {{
PyErr_SetString(PyExc_RuntimeError, "Split: unable to copy a split view into the output.");
Py_XDECREF(split_view);
free(split_dims);
{fail} {fail}
}} }}
Py_XDECREF(split_view);
// Set as a view of input
Py_INCREF((PyObject*){x});
PyArray_SetBaseObject(*outputs[i], (PyObject*){x});
// Update split slice pointer
current_split_start += current_split_length; current_split_start += current_split_length;
}} }}
free(split_dims);
""" """
......
...@@ -2172,11 +2172,7 @@ class TestJoinAndSplit: ...@@ -2172,11 +2172,7 @@ class TestJoinAndSplit:
res = f(x_test) res = f(x_test)
for r, expected in zip(res, ([], [0, 1, 2], [3, 4]), strict=True): for r, expected in zip(res, ([], [0, 1, 2], [3, 4]), strict=True):
assert np.allclose(r, expected) assert np.allclose(r, expected)
if linker == "py":
assert r.base is x_test assert r.base is x_test
else:
# C impl always makes a copy
assert r.base is not x_test
def test_TensorFromScalar(): def test_TensorFromScalar():
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论