Unverified 提交 d9b10859 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: GitHub

Don't run local uint constant indices in C/Python backends (#1335)

* Let numpy methods handle integer size problems in AdvancedSubtensor1. * Don't run `local_uint_constant_indices` in the C/Python backends: indices are always cast to int64 by the underlying methods. Also don't run it in specialize, to reduce the number of passes. Other rewrites may introduce temporary indexing operations (such as x.shape[i]) which always default to int64, and it's useless to optimize them immediately.
上级 afb76951
......@@ -489,7 +489,6 @@ PYTORCH = Mode(
"BlasOpt",
"fusion",
"inplace",
"local_uint_constant_indices",
"scan_save_mem_prealloc",
],
),
......
......@@ -5,7 +5,6 @@ from collections.abc import Iterable
import numpy as np
import pytensor
import pytensor.scalar.basic as ps
from pytensor import compile
from pytensor.compile import optdb
from pytensor.graph.basic import Constant, Variable
......@@ -14,8 +13,11 @@ from pytensor.graph.rewriting.basic import (
copy_stack_trace,
in2out,
node_rewriter,
out2in,
)
from pytensor.raise_op import Assert
from pytensor.scalar import Add, ScalarConstant, ScalarType
from pytensor.scalar import constant as scalar_constant
from pytensor.tensor.basic import (
Alloc,
Join,
......@@ -31,6 +33,7 @@ from pytensor.tensor.basic import (
register_infer_shape,
switch,
)
from pytensor.tensor.basic import constant as tensor_constant
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
......@@ -588,11 +591,11 @@ def local_subtensor_remove_broadcastable_index(fgraph, node):
remove_dim = []
node_inputs_idx = 1
for dim, elem in enumerate(idx):
if isinstance(elem, (ps.ScalarType)):
if isinstance(elem, ScalarType):
# The idx is a ScalarType, ie a Type. This means the actual index
# is contained in node.inputs[1]
dim_index = node.inputs[node_inputs_idx]
if isinstance(dim_index, ps.ScalarConstant):
if isinstance(dim_index, ScalarConstant):
dim_index = dim_index.value
if dim_index in (0, -1) and node.inputs[0].broadcastable[dim]:
remove_dim.append(dim)
......@@ -770,7 +773,7 @@ def local_subtensor_make_vector(fgraph, node):
(idx,) = idxs
if isinstance(idx, ps.ScalarType | TensorType):
if isinstance(idx, ScalarType | TensorType):
old_idx, idx = idx, node.inputs[1]
assert idx.type.is_super(old_idx)
elif isinstance(node.op, AdvancedSubtensor1):
......@@ -895,7 +898,7 @@ def local_set_to_inc_subtensor(fgraph, node):
and node.op.set_instead_of_inc
and node.inputs[1].owner
and isinstance(node.inputs[1].owner.op, Elemwise)
and isinstance(node.inputs[1].owner.op.scalar_op, ps.Add)
and isinstance(node.inputs[1].owner.op.scalar_op, Add)
):
addn = node.inputs[1].owner
subn = None
......@@ -1789,7 +1792,6 @@ def local_join_subtensors(fgraph, node):
return [merged_subtensors]
@register_specialize
@node_rewriter(
[
Subtensor,
......@@ -1850,12 +1852,10 @@ def local_uint_constant_indices(fgraph, node):
if dtype == index_val.dtype:
continue
if index_val.ndim > 0:
new_index = pytensor.tensor.as_tensor_variable(
index_val.astype(dtype), dtype=dtype
)
if isinstance(index.type, TensorType):
new_index = tensor_constant(index_val.astype(dtype), dtype=dtype)
else:
new_index = ps.constant(index_val.astype(dtype), dtype=dtype)
new_index = scalar_constant(index_val.astype(dtype), dtype=dtype)
new_indices[i] = new_index
has_new_index = True
......@@ -1877,6 +1877,20 @@ def local_uint_constant_indices(fgraph, node):
return [new_out]
# Register `local_uint_constant_indices` in `optdb` as its own `out2in` pass,
# under the rewrite's own name.
compile.optdb.register(
local_uint_constant_indices.__name__,
out2in(local_uint_constant_indices),
# Tagged for the Numba and JAX backends only.
# It is deliberately not included for the Python / C backends,
# because those always cast indices to int64 internally.
"numba",
"jax",
# Positioned after specialization and uncanonicalization.
# Other rewrites don't worry about the dtype of the indices they introduce,
# so running this earlier can cause unnecessary passes of this optimization,
# such as x.shape[np.int(0)] -> x.shape[np.uint(0)]
position=4,
)
@register_canonicalize("shape_unsafe")
@register_stabilize("shape_unsafe")
@register_specialize("shape_unsafe")
......
......@@ -3,7 +3,6 @@ import sys
import warnings
from collections.abc import Callable, Iterable, Sequence
from itertools import chain, groupby
from textwrap import dedent
from typing import cast, overload
import numpy as np
......@@ -19,7 +18,7 @@ from pytensor.graph.type import Type
from pytensor.graph.utils import MethodNotDefined
from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType
from pytensor.npy_2_compat import npy_2_compat_header, numpy_version, using_numpy_2
from pytensor.npy_2_compat import numpy_version, using_numpy_2
from pytensor.printing import Printer, pprint, set_precedence
from pytensor.scalar.basic import ScalarConstant, ScalarVariable
from pytensor.tensor import (
......@@ -2130,24 +2129,6 @@ class AdvancedSubtensor1(COp):
else:
o = None
# If i.dtype is more precise than numpy.intp (int32 on 32-bit machines,
# int64 on 64-bit machines), numpy may raise the following error:
# TypeError: array cannot be safely cast to required type.
# We need to check if values in i can fit in numpy.intp, because
# if they don't, that should be an error (no array can have that
# many elements on a 32-bit arch).
if i.dtype != np.intp:
i_ = np.asarray(i, dtype=np.intp)
if not np.can_cast(i.dtype, np.intp):
# Check if there was actually an incorrect conversion
if np.any(i != i_):
raise IndexError(
"index contains values that are bigger "
"than the maximum array size on this system.",
i,
)
i = i_
out[0] = x.take(i, axis=0, out=o)
def connection_pattern(self, node):
......@@ -2187,16 +2168,6 @@ class AdvancedSubtensor1(COp):
x, ilist = ishapes
return [ilist + x[1:]]
def c_support_code(self, **kwargs):
# In some versions of numpy, NPY_MIN_INTP is defined as MIN_LONG,
# which is not defined. It should be NPY_MIN_LONG instead in that case.
return npy_2_compat_header() + dedent(
"""\
#ifndef MIN_LONG
#define MIN_LONG NPY_MIN_LONG
#endif"""
)
def c_code(self, node, name, input_names, output_names, sub):
if self.__class__ is not AdvancedSubtensor1:
raise MethodNotDefined(
......@@ -2207,61 +2178,16 @@ class AdvancedSubtensor1(COp):
output_name = output_names[0]
fail = sub["fail"]
return f"""
PyArrayObject *indices;
int i_type = PyArray_TYPE({i_name});
if (i_type != NPY_INTP) {{
// Cast {i_name} to NPY_INTP (expected by PyArray_TakeFrom),
// if all values fit.
if (!PyArray_CanCastSafely(i_type, NPY_INTP) &&
PyArray_SIZE({i_name}) > 0) {{
npy_int64 min_val, max_val;
PyObject* py_min_val = PyArray_Min({i_name}, NPY_RAVEL_AXIS,
NULL);
if (py_min_val == NULL) {{
{fail};
}}
min_val = PyLong_AsLongLong(py_min_val);
Py_DECREF(py_min_val);
if (min_val == -1 && PyErr_Occurred()) {{
{fail};
}}
PyObject* py_max_val = PyArray_Max({i_name}, NPY_RAVEL_AXIS,
NULL);
if (py_max_val == NULL) {{
{fail};
}}
max_val = PyLong_AsLongLong(py_max_val);
Py_DECREF(py_max_val);
if (max_val == -1 && PyErr_Occurred()) {{
{fail};
}}
if (min_val < NPY_MIN_INTP || max_val > NPY_MAX_INTP) {{
PyErr_SetString(PyExc_IndexError,
"Index contains values "
"that are bigger than the maximum array "
"size on this system.");
{fail};
}}
}}
indices = (PyArrayObject*) PyArray_Cast({i_name}, NPY_INTP);
if (indices == NULL) {{
{fail};
}}
}}
else {{
indices = {i_name};
Py_INCREF(indices);
}}
if ({output_name} != NULL) {{
npy_intp nd, i, *shape;
nd = PyArray_NDIM({a_name}) + PyArray_NDIM(indices) - 1;
nd = PyArray_NDIM({a_name}) + PyArray_NDIM({i_name}) - 1;
if (PyArray_NDIM({output_name}) != nd) {{
Py_CLEAR({output_name});
}}
else {{
shape = PyArray_DIMS({output_name});
for (i = 0; i < PyArray_NDIM(indices); i++) {{
if (shape[i] != PyArray_DIMS(indices)[i]) {{
for (i = 0; i < PyArray_NDIM({i_name}); i++) {{
if (shape[i] != PyArray_DIMS({i_name})[i]) {{
Py_CLEAR({output_name});
break;
}}
......@@ -2269,7 +2195,7 @@ class AdvancedSubtensor1(COp):
if ({output_name} != NULL) {{
for (; i < nd; i++) {{
if (shape[i] != PyArray_DIMS({a_name})[
i-PyArray_NDIM(indices)+1]) {{
i-PyArray_NDIM({i_name})+1]) {{
Py_CLEAR({output_name});
break;
}}
......@@ -2278,13 +2204,12 @@ class AdvancedSubtensor1(COp):
}}
}}
{output_name} = (PyArrayObject*)PyArray_TakeFrom(
{a_name}, (PyObject*)indices, 0, {output_name}, NPY_RAISE);
Py_DECREF(indices);
{a_name}, (PyObject*){i_name}, 0, {output_name}, NPY_RAISE);
if ({output_name} == NULL) {fail};
"""
def c_code_cache_version(self):
return (0, 1, 2, 3)
return (4,)
advanced_subtensor1 = AdvancedSubtensor1()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论