提交 f6958407 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Reuse output buffer in C-impl of Join

上级 3de303d2
...@@ -2541,7 +2541,7 @@ class Join(COp): ...@@ -2541,7 +2541,7 @@ class Join(COp):
) )
def c_code_cache_version(self): def c_code_cache_version(self):
return (6,) return (7,)
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
axis, *arrays = inputs axis, *arrays = inputs
...@@ -2580,16 +2580,86 @@ class Join(COp): ...@@ -2580,16 +2580,86 @@ class Join(COp):
code = f""" code = f"""
int axis = {axis_def} int axis = {axis_def}
PyArrayObject* arrays[{n}] = {{{','.join(arrays)}}}; PyArrayObject* arrays[{n}] = {{{','.join(arrays)}}};
PyObject* arrays_tuple = PyTuple_New({n}); int out_is_valid = {out} != NULL;
{axis_check} {axis_check}
Py_XDECREF({out}); if (out_is_valid) {{
{copy_arrays_to_tuple} // Check if we can reuse output
{out} = (PyArrayObject *)PyArray_Concatenate(arrays_tuple, axis); npy_intp join_size = 0;
Py_DECREF(arrays_tuple); npy_intp out_shape[{ndim}];
if(!{out}){{ npy_intp *shape = PyArray_SHAPE(arrays[0]);
{fail}
for (int i = 0; i < {n}; i++) {{
if (PyArray_NDIM(arrays[i]) != {ndim}) {{
PyErr_SetString(PyExc_ValueError, "Input to join has wrong ndim");
{fail}
}}
join_size += PyArray_SHAPE(arrays[i])[axis];
if (i > 0){{
for (int j = 0; j < {ndim}; j++) {{
if ((j != axis) && (PyArray_SHAPE(arrays[i])[j] != shape[j])) {{
PyErr_SetString(PyExc_ValueError, "Arrays shape must match along non join axis");
{fail}
}}
}}
}}
}}
memcpy(out_shape, shape, {ndim} * sizeof(npy_intp));
out_shape[axis] = join_size;
for (int i = 0; i < {ndim}; i++) {{
out_is_valid &= (PyArray_SHAPE({out})[i] == out_shape[i]);
}}
}}
if (!out_is_valid) {{
// Use PyArray_Concatenate
Py_XDECREF({out});
PyObject* arrays_tuple = PyTuple_New({n});
{copy_arrays_to_tuple}
{out} = (PyArrayObject *)PyArray_Concatenate(arrays_tuple, axis);
Py_DECREF(arrays_tuple);
if(!{out}){{
{fail}
}}
}}
else {{
// Copy the data to the pre-allocated output buffer
// Create view into output buffer
PyArrayObject_fields *view;
// PyArray_NewFromDescr steals a reference to descr, so we need to increase it
Py_INCREF(PyArray_DESCR({out}));
view = (PyArrayObject_fields *)PyArray_NewFromDescr(&PyArray_Type,
PyArray_DESCR({out}),
{ndim},
PyArray_SHAPE(arrays[0]),
PyArray_STRIDES({out}),
PyArray_DATA({out}),
NPY_ARRAY_WRITEABLE,
NULL);
if (view == NULL) {{
{fail}
}}
// Copy data into output buffer
for (int i = 0; i < {n}; i++) {{
view->dimensions[axis] = PyArray_SHAPE(arrays[i])[axis];
if (PyArray_CopyInto((PyArrayObject*)view, arrays[i]) != 0) {{
Py_DECREF(view);
{fail}
}}
view->data += (view->dimensions[axis] * view->strides[axis]);
}}
Py_DECREF(view);
}} }}
""" """
return code return code
......
...@@ -117,6 +117,7 @@ from pytensor.tensor.type import ( ...@@ -117,6 +117,7 @@ from pytensor.tensor.type import (
ivector, ivector,
lscalar, lscalar,
lvector, lvector,
matrices,
matrix, matrix,
row, row,
scalar, scalar,
...@@ -1762,7 +1763,7 @@ class TestJoinAndSplit: ...@@ -1762,7 +1763,7 @@ class TestJoinAndSplit:
got = f(-2) got = f(-2)
assert np.allclose(got, want) assert np.allclose(got, want)
with pytest.raises(IndexError): with pytest.raises(ValueError):
f(-3) f(-3)
@pytest.mark.parametrize("py_impl", (False, True)) @pytest.mark.parametrize("py_impl", (False, True))
...@@ -1805,7 +1806,7 @@ class TestJoinAndSplit: ...@@ -1805,7 +1806,7 @@ class TestJoinAndSplit:
got = f() got = f()
assert np.allclose(got, want) assert np.allclose(got, want)
with pytest.raises(IndexError): with pytest.raises(ValueError):
join(-3, a, b) join(-3, a, b)
with impl_ctxt: with impl_ctxt:
...@@ -2152,6 +2153,32 @@ class TestJoinAndSplit: ...@@ -2152,6 +2153,32 @@ class TestJoinAndSplit:
assert np.allclose(r, expected) assert np.allclose(r, expected)
assert r.base is x_test assert r.base is x_test
@pytest.mark.parametrize("gc", (True, False), ids=lambda x: f"gc={x}")
@pytest.mark.parametrize("memory_layout", ["C-contiguous", "F-contiguous", "Mixed"])
@pytest.mark.parametrize("axis", (0, 1), ids=lambda x: f"axis={x}")
@pytest.mark.parametrize("ndim", (1, 2), ids=["vector", "matrix"])
@config.change_flags(cmodule__warn_no_version=False)
def test_join_performance(self, ndim, axis, memory_layout, gc, benchmark):
    """Benchmark ``join`` across memory layouts, axes, and gc settings.

    Builds a compiled function joining six symbolic inputs, sanity-checks
    the output shape once, then hands the call to ``benchmark``.

    Parameters
    ----------
    ndim : int
        1 (vectors) or 2 (matrices) — rank of the joined inputs.
    axis : int
        Join axis; only 0 is meaningful for vectors.
    memory_layout : str
        "C-contiguous", "F-contiguous", or "Mixed" — layout of test arrays.
    gc : bool
        Whether the VM is allowed to garbage-collect intermediates.
    benchmark
        pytest-benchmark fixture.
    """
    if ndim == 1 and not (memory_layout == "C-contiguous" and axis == 0):
        # For vectors, transposition is a no-op and axis 1 does not exist.
        pytest.skip("Redundant parametrization")
    n = 64
    inputs = vectors("abcdef") if ndim == 1 else matrices("abcdef")
    out = join(axis, *inputs)
    # borrow=True + trust_input=True: measure the Op, not copying/validation.
    fn = pytensor.function(inputs, Out(out, borrow=True), trust_input=True)
    fn.vm.allow_gc = gc
    test_values = [np.zeros((n, n)[:ndim], dtype=inputs[0].dtype) for _ in inputs]
    if memory_layout == "C-contiguous":
        pass
    elif memory_layout == "F-contiguous":
        test_values = [t.T for t in test_values]
    elif memory_layout == "Mixed":
        # Alternate C- and F-contiguous inputs.
        test_values = [t if i % 2 else t.T for i, t in enumerate(test_values)]
    else:
        raise ValueError(f"Unknown memory layout: {memory_layout}")

    # BUG FIX: the original wrote
    #   assert fn(...).shape == (n * 6, n)[:ndim] if axis == 0 else (n, n * 6)
    # which parses as `assert (cmp if axis == 0 else (n, n * 6))`, so for
    # axis == 1 the assert was over a non-empty (truthy) tuple literal and
    # never verified the shape.  Parenthesize the conditional expression so
    # the comparison is what gets asserted on both branches.
    expected_shape = (n * 6, n)[:ndim] if axis == 0 else (n, n * 6)
    assert fn(*test_values).shape == expected_shape
    benchmark(fn, *test_values)
def test_TensorFromScalar(): def test_TensorFromScalar():
s = ps.constant(56) s = ps.constant(56)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论