Commit 4311f893 authored by ricardoV94, committed by Ricardo Vieira

Specialize AdvancedSubtensor1 mode for compile time valid indices

Parent d351b09d
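The gist of the change, as a minimal NumPy sketch (not part of the commit): `np.take` with `mode="raise"` (NPY_RAISE) bounds-checks every index, while `mode="wrap"` (NPY_WRAP) skips the check and wraps indices modulo the axis length, so the two agree whenever the indices are already known to be in bounds.

```python
import numpy as np

x = np.random.normal(size=85)
idxs = np.arange(85).repeat(11)  # all indices are in-bounds for x

# When indices are known valid at compile time, "wrap" returns the same
# values as "raise" but avoids the per-index bounds checking.
np.testing.assert_array_equal(
    np.take(x, idxs, mode="raise"),
    np.take(x, idxs, mode="wrap"),
)

# An out-of-bounds index only errors under mode="raise".
np.take(x, [85], mode="wrap")  # wraps to x[0], no error
try:
    np.take(x, [85], mode="raise")
except IndexError as exc:
    print(exc)
```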
@@ -2120,16 +2120,12 @@ class AdvancedSubtensor1(COp):
         out_shape = (ilist_.type.shape[0], *x_.type.shape[1:])
         return Apply(self, [x_, ilist_], [TensorType(dtype=x.dtype, shape=out_shape)()])

-    def perform(self, node, inp, out_):
+    def perform(self, node, inp, output_storage):
         x, i = inp
-        (out,) = out_
-        # Copy always implied by numpy advanced indexing semantic.
-        if out[0] is not None and out[0].shape == (len(i),) + x.shape[1:]:
-            o = out[0]
-        else:
-            o = None
-        out[0] = x.take(i, axis=0, out=o)
+        # Numpy take is always slower when out is provided
+        # https://github.com/numpy/numpy/issues/28636
+        output_storage[0][0] = x.take(i, axis=0, out=None)

     def connection_pattern(self, node):
         rval = [[True], *([False] for _ in node.inputs[1:])]
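For context on the `perform` change above, a rough timing sketch of the `out=` slow path referenced in numpy/numpy#28636 (not from the commit; numbers are machine- and NumPy-version-dependent):

```python
import timeit

import numpy as np

x = np.random.normal(size=(85, 4))
i = np.arange(85).repeat(11)
buf = np.empty((i.shape[0], 4))

# Passing a pre-allocated `out` buffer sends np.take down a slower path,
# which is why the new perform() always lets NumPy allocate the output.
t_alloc = timeit.timeit(lambda: x.take(i, axis=0, out=None), number=10_000)
t_reuse = timeit.timeit(lambda: x.take(i, axis=0, out=buf), number=10_000)
print(f"fresh allocation: {t_alloc:.3f}s, reusing out=: {t_reuse:.3f}s")
```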
@@ -2174,42 +2170,83 @@ class AdvancedSubtensor1(COp):
                 "c_code defined for AdvancedSubtensor1, not for child class",
                 type(self),
             )
+        x, idxs = node.inputs
+        if self._idx_may_be_invalid(x, idxs):
+            mode = "NPY_RAISE"
+        else:
+            # We can know ahead of time that all indices are valid, so we can use a faster mode
+            mode = "NPY_WRAP"  # This seems to be faster than NPY_CLIP
         a_name, i_name = input_names[0], input_names[1]
         output_name = output_names[0]
         fail = sub["fail"]
-        return f"""
-            if ({output_name} != NULL) {{
-                npy_intp nd, i, *shape;
-                nd = PyArray_NDIM({a_name}) + PyArray_NDIM({i_name}) - 1;
-                if (PyArray_NDIM({output_name}) != nd) {{
-                    Py_CLEAR({output_name});
-                }}
-                else {{
-                    shape = PyArray_DIMS({output_name});
-                    for (i = 0; i < PyArray_NDIM({i_name}); i++) {{
-                        if (shape[i] != PyArray_DIMS({i_name})[i]) {{
-                            Py_CLEAR({output_name});
-                            break;
-                        }}
-                    }}
-                    if ({output_name} != NULL) {{
-                        for (; i < nd; i++) {{
-                            if (shape[i] != PyArray_DIMS({a_name})[
-                                                i-PyArray_NDIM({i_name})+1]) {{
-                                Py_CLEAR({output_name});
-                                break;
-                            }}
-                        }}
-                    }}
-                }}
-            }}
+        if mode == "NPY_RAISE":
+            # numpy_take always makes an intermediate copy if NPY_RAISE which is slower than just allocating a new buffer
+            # We can remove this special case after https://github.com/numpy/numpy/issues/28636
+            manage_pre_allocated_out = f"""
+                if ({output_name} != NULL) {{
+                    // Numpy TakeFrom is always slower when copying
+                    // https://github.com/numpy/numpy/issues/28636
+                    Py_CLEAR({output_name});
+                }}
+            """
+        else:
+            manage_pre_allocated_out = f"""
+                if ({output_name} != NULL) {{
+                    npy_intp nd = PyArray_NDIM({a_name}) + PyArray_NDIM({i_name}) - 1;
+                    if (PyArray_NDIM({output_name}) != nd) {{
+                        Py_CLEAR({output_name});
+                    }}
+                    else {{
+                        int i;
+                        npy_intp* shape = PyArray_DIMS({output_name});
+                        for (i = 0; i < PyArray_NDIM({i_name}); i++) {{
+                            if (shape[i] != PyArray_DIMS({i_name})[i]) {{
+                                Py_CLEAR({output_name});
+                                break;
+                            }}
+                        }}
+                        if ({output_name} != NULL) {{
+                            for (; i < nd; i++) {{
+                                if (shape[i] != PyArray_DIMS({a_name})[i-PyArray_NDIM({i_name})+1]) {{
+                                    Py_CLEAR({output_name});
+                                    break;
+                                }}
+                            }}
+                        }}
+                    }}
+                }}
+            """
+        return f"""
+            {manage_pre_allocated_out}
             {output_name} = (PyArrayObject*)PyArray_TakeFrom(
-                        {a_name}, (PyObject*){i_name}, 0, {output_name}, NPY_RAISE);
+                        {a_name}, (PyObject*){i_name}, 0, {output_name}, {mode});
             if ({output_name} == NULL) {fail};
         """

     def c_code_cache_version(self):
-        return (4,)
+        return (5,)
+
+    @staticmethod
+    def _idx_may_be_invalid(x, idx) -> bool:
+        if idx.type.shape[0] == 0:
+            # Empty index is always valid
+            return False
+        if x.type.shape[0] is None:
+            # We can't know if an index is valid if we don't know the length of x
+            return True
+        if not isinstance(idx, Constant):
+            # This is conservative, but we don't try to infer lower/upper bound symbolically
+            return True
+        shape0 = x.type.shape[0]
+        min_idx, max_idx = idx.data.min(), idx.data.max()
+        return not (
+            (min_idx >= 0 or min_idx >= -shape0) and (max_idx < 0 or max_idx < shape0)
+        )


 advanced_subtensor1 = AdvancedSubtensor1()
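A hypothetical standalone mirror of the bounds check in `_idx_may_be_invalid`, only to illustrate when the Op can emit NPY_WRAP (the helper name and numbers are illustrative, not part of the commit):

```python
import numpy as np


def idx_within_bounds(idx_values: np.ndarray, length: int) -> bool:
    """Constant indices are safe when every entry falls in [-length, length)."""
    if idx_values.size == 0:
        return True  # an empty index can never be out of bounds
    min_idx, max_idx = idx_values.min(), idx_values.max()
    return min_idx >= -length and max_idx < length


# With x.type.shape == (85,) and constant indices 0..84 the Op can emit
# NPY_WRAP; an index of 85 (or -86) forces it back to NPY_RAISE.
print(idx_within_bounds(np.arange(85).repeat(11), 85))  # True  -> NPY_WRAP
print(idx_within_bounds(np.array([0, 85]), 85))         # False -> NPY_RAISE
```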
...
@@ -3003,3 +3003,28 @@ def test_flip(size: tuple[int]):
     z = flip(x_pt, axis=list(axes))
     f = pytensor.function([x_pt], z, mode="FAST_COMPILE")
     np.testing.assert_allclose(expected, f(x), atol=ATOL, rtol=RTOL)
+
+
+class TestBenchmarks:
+    @pytest.mark.parametrize(
+        "static_shape", (False, True), ids=lambda x: f"static_shape={x}"
+    )
+    @pytest.mark.parametrize("gc", (False, True), ids=lambda x: f"gc={x}")
+    def test_advanced_subtensor1(self, static_shape, gc, benchmark):
+        x = vector("x", shape=(85 if static_shape else None,))
+        x_values = np.random.normal(size=(85,))
+        idxs_values = np.arange(85).repeat(11)
+
+        # With static shape and constant indices we know all idxs are valid
+        # and can use the faster mode in numpy.take
+        out = x[idxs_values]
+
+        fn = pytensor.function(
+            [x],
+            pytensor.Out(out, borrow=True),
+            on_unused_input="ignore",
+            trust_input=True,
+        )
+        fn.vm.allow_gc = gc
+        benchmark(fn, x_values, idxs_values)
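The new test presumably uses the pytest-benchmark `benchmark` fixture; a rough interactive equivalent (assumed usage, timings machine-dependent) built only from the API already shown in the diff would be:

```python
import timeit

import numpy as np
import pytensor
from pytensor.tensor import vector

x_values = np.random.normal(size=(85,))
idxs_values = np.arange(85).repeat(11)

# With a static shape and constant indices the compiled graph can take the
# faster path; without a static shape it must keep the NPY_RAISE mode.
for static_shape in (False, True):
    x = vector("x", shape=(85 if static_shape else None,))
    fn = pytensor.function([x], x[idxs_values], trust_input=True)
    t = timeit.timeit(lambda: fn(x_values), number=10_000)
    print(f"static_shape={static_shape}: {t:.3f}s")
```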