Unverified 提交 d9b10859 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: GitHub

Don't run local uint constant indices in C/Python backends (#1335)

* Let numpy methods handle integer size problems in AdvancedSubtensor1. * Don't run `local_uint_constant_indices` in the C/Python backends: indices are always cast to int64 by the underlying methods. Also don't run it in specialize, to reduce the number of passes — other rewrites may introduce temporary indexing operations (such as x.shape[i]) which always default to int64, so optimizing them immediately is useless.
上级 afb76951
...@@ -489,7 +489,6 @@ PYTORCH = Mode( ...@@ -489,7 +489,6 @@ PYTORCH = Mode(
"BlasOpt", "BlasOpt",
"fusion", "fusion",
"inplace", "inplace",
"local_uint_constant_indices",
"scan_save_mem_prealloc", "scan_save_mem_prealloc",
], ],
), ),
......
...@@ -5,7 +5,6 @@ from collections.abc import Iterable ...@@ -5,7 +5,6 @@ from collections.abc import Iterable
import numpy as np import numpy as np
import pytensor import pytensor
import pytensor.scalar.basic as ps
from pytensor import compile from pytensor import compile
from pytensor.compile import optdb from pytensor.compile import optdb
from pytensor.graph.basic import Constant, Variable from pytensor.graph.basic import Constant, Variable
...@@ -14,8 +13,11 @@ from pytensor.graph.rewriting.basic import ( ...@@ -14,8 +13,11 @@ from pytensor.graph.rewriting.basic import (
copy_stack_trace, copy_stack_trace,
in2out, in2out,
node_rewriter, node_rewriter,
out2in,
) )
from pytensor.raise_op import Assert from pytensor.raise_op import Assert
from pytensor.scalar import Add, ScalarConstant, ScalarType
from pytensor.scalar import constant as scalar_constant
from pytensor.tensor.basic import ( from pytensor.tensor.basic import (
Alloc, Alloc,
Join, Join,
...@@ -31,6 +33,7 @@ from pytensor.tensor.basic import ( ...@@ -31,6 +33,7 @@ from pytensor.tensor.basic import (
register_infer_shape, register_infer_shape,
switch, switch,
) )
from pytensor.tensor.basic import constant as tensor_constant
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
...@@ -588,11 +591,11 @@ def local_subtensor_remove_broadcastable_index(fgraph, node): ...@@ -588,11 +591,11 @@ def local_subtensor_remove_broadcastable_index(fgraph, node):
remove_dim = [] remove_dim = []
node_inputs_idx = 1 node_inputs_idx = 1
for dim, elem in enumerate(idx): for dim, elem in enumerate(idx):
if isinstance(elem, (ps.ScalarType)): if isinstance(elem, ScalarType):
# The idx is a ScalarType, ie a Type. This means the actual index # The idx is a ScalarType, ie a Type. This means the actual index
# is contained in node.inputs[1] # is contained in node.inputs[1]
dim_index = node.inputs[node_inputs_idx] dim_index = node.inputs[node_inputs_idx]
if isinstance(dim_index, ps.ScalarConstant): if isinstance(dim_index, ScalarConstant):
dim_index = dim_index.value dim_index = dim_index.value
if dim_index in (0, -1) and node.inputs[0].broadcastable[dim]: if dim_index in (0, -1) and node.inputs[0].broadcastable[dim]:
remove_dim.append(dim) remove_dim.append(dim)
...@@ -770,7 +773,7 @@ def local_subtensor_make_vector(fgraph, node): ...@@ -770,7 +773,7 @@ def local_subtensor_make_vector(fgraph, node):
(idx,) = idxs (idx,) = idxs
if isinstance(idx, ps.ScalarType | TensorType): if isinstance(idx, ScalarType | TensorType):
old_idx, idx = idx, node.inputs[1] old_idx, idx = idx, node.inputs[1]
assert idx.type.is_super(old_idx) assert idx.type.is_super(old_idx)
elif isinstance(node.op, AdvancedSubtensor1): elif isinstance(node.op, AdvancedSubtensor1):
...@@ -895,7 +898,7 @@ def local_set_to_inc_subtensor(fgraph, node): ...@@ -895,7 +898,7 @@ def local_set_to_inc_subtensor(fgraph, node):
and node.op.set_instead_of_inc and node.op.set_instead_of_inc
and node.inputs[1].owner and node.inputs[1].owner
and isinstance(node.inputs[1].owner.op, Elemwise) and isinstance(node.inputs[1].owner.op, Elemwise)
and isinstance(node.inputs[1].owner.op.scalar_op, ps.Add) and isinstance(node.inputs[1].owner.op.scalar_op, Add)
): ):
addn = node.inputs[1].owner addn = node.inputs[1].owner
subn = None subn = None
...@@ -1789,7 +1792,6 @@ def local_join_subtensors(fgraph, node): ...@@ -1789,7 +1792,6 @@ def local_join_subtensors(fgraph, node):
return [merged_subtensors] return [merged_subtensors]
@register_specialize
@node_rewriter( @node_rewriter(
[ [
Subtensor, Subtensor,
...@@ -1850,12 +1852,10 @@ def local_uint_constant_indices(fgraph, node): ...@@ -1850,12 +1852,10 @@ def local_uint_constant_indices(fgraph, node):
if dtype == index_val.dtype: if dtype == index_val.dtype:
continue continue
if index_val.ndim > 0: if isinstance(index.type, TensorType):
new_index = pytensor.tensor.as_tensor_variable( new_index = tensor_constant(index_val.astype(dtype), dtype=dtype)
index_val.astype(dtype), dtype=dtype
)
else: else:
new_index = ps.constant(index_val.astype(dtype), dtype=dtype) new_index = scalar_constant(index_val.astype(dtype), dtype=dtype)
new_indices[i] = new_index new_indices[i] = new_index
has_new_index = True has_new_index = True
...@@ -1877,6 +1877,20 @@ def local_uint_constant_indices(fgraph, node): ...@@ -1877,6 +1877,20 @@ def local_uint_constant_indices(fgraph, node):
return [new_out] return [new_out]
# Register `local_uint_constant_indices` directly in the rewrite database
# (instead of via `register_specialize`) so it can be scoped to specific
# backends and run as a late standalone pass.
compile.optdb.register(
    local_uint_constant_indices.__name__,
    out2in(local_uint_constant_indices),
    # Only tagged for the numba / jax backends: the Python and C backends
    # always cast indices to int64 internally, so downcasting constant
    # indices to unsigned types buys nothing there.
    "numba",
    "jax",
    # Run after specialization and uncanonicalization (position=4).
    # Other rewrites don't care about index dtypes and may introduce
    # temporary indexing operations (e.g. x.shape[np.int(0)] -> x.shape[np.uint(0)]),
    # which would cause unnecessary extra passes of this rewrite if it ran earlier.
    position=4,
)
@register_canonicalize("shape_unsafe") @register_canonicalize("shape_unsafe")
@register_stabilize("shape_unsafe") @register_stabilize("shape_unsafe")
@register_specialize("shape_unsafe") @register_specialize("shape_unsafe")
......
...@@ -3,7 +3,6 @@ import sys ...@@ -3,7 +3,6 @@ import sys
import warnings import warnings
from collections.abc import Callable, Iterable, Sequence from collections.abc import Callable, Iterable, Sequence
from itertools import chain, groupby from itertools import chain, groupby
from textwrap import dedent
from typing import cast, overload from typing import cast, overload
import numpy as np import numpy as np
...@@ -19,7 +18,7 @@ from pytensor.graph.type import Type ...@@ -19,7 +18,7 @@ from pytensor.graph.type import Type
from pytensor.graph.utils import MethodNotDefined from pytensor.graph.utils import MethodNotDefined
from pytensor.link.c.op import COp from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType from pytensor.link.c.params_type import ParamsType
from pytensor.npy_2_compat import npy_2_compat_header, numpy_version, using_numpy_2 from pytensor.npy_2_compat import numpy_version, using_numpy_2
from pytensor.printing import Printer, pprint, set_precedence from pytensor.printing import Printer, pprint, set_precedence
from pytensor.scalar.basic import ScalarConstant, ScalarVariable from pytensor.scalar.basic import ScalarConstant, ScalarVariable
from pytensor.tensor import ( from pytensor.tensor import (
...@@ -2130,24 +2129,6 @@ class AdvancedSubtensor1(COp): ...@@ -2130,24 +2129,6 @@ class AdvancedSubtensor1(COp):
else: else:
o = None o = None
# If i.dtype is more precise than numpy.intp (int32 on 32-bit machines,
# int64 on 64-bit machines), numpy may raise the following error:
# TypeError: array cannot be safely cast to required type.
# We need to check if values in i can fit in numpy.intp, because
# if they don't, that should be an error (no array can have that
# many elements on a 32-bit arch).
if i.dtype != np.intp:
i_ = np.asarray(i, dtype=np.intp)
if not np.can_cast(i.dtype, np.intp):
# Check if there was actually an incorrect conversion
if np.any(i != i_):
raise IndexError(
"index contains values that are bigger "
"than the maximum array size on this system.",
i,
)
i = i_
out[0] = x.take(i, axis=0, out=o) out[0] = x.take(i, axis=0, out=o)
def connection_pattern(self, node): def connection_pattern(self, node):
...@@ -2187,16 +2168,6 @@ class AdvancedSubtensor1(COp): ...@@ -2187,16 +2168,6 @@ class AdvancedSubtensor1(COp):
x, ilist = ishapes x, ilist = ishapes
return [ilist + x[1:]] return [ilist + x[1:]]
def c_support_code(self, **kwargs):
# In some versions of numpy, NPY_MIN_INTP is defined as MIN_LONG,
# which is not defined. It should be NPY_MIN_LONG instead in that case.
return npy_2_compat_header() + dedent(
"""\
#ifndef MIN_LONG
#define MIN_LONG NPY_MIN_LONG
#endif"""
)
def c_code(self, node, name, input_names, output_names, sub): def c_code(self, node, name, input_names, output_names, sub):
if self.__class__ is not AdvancedSubtensor1: if self.__class__ is not AdvancedSubtensor1:
raise MethodNotDefined( raise MethodNotDefined(
...@@ -2207,61 +2178,16 @@ class AdvancedSubtensor1(COp): ...@@ -2207,61 +2178,16 @@ class AdvancedSubtensor1(COp):
output_name = output_names[0] output_name = output_names[0]
fail = sub["fail"] fail = sub["fail"]
return f""" return f"""
PyArrayObject *indices;
int i_type = PyArray_TYPE({i_name});
if (i_type != NPY_INTP) {{
// Cast {i_name} to NPY_INTP (expected by PyArray_TakeFrom),
// if all values fit.
if (!PyArray_CanCastSafely(i_type, NPY_INTP) &&
PyArray_SIZE({i_name}) > 0) {{
npy_int64 min_val, max_val;
PyObject* py_min_val = PyArray_Min({i_name}, NPY_RAVEL_AXIS,
NULL);
if (py_min_val == NULL) {{
{fail};
}}
min_val = PyLong_AsLongLong(py_min_val);
Py_DECREF(py_min_val);
if (min_val == -1 && PyErr_Occurred()) {{
{fail};
}}
PyObject* py_max_val = PyArray_Max({i_name}, NPY_RAVEL_AXIS,
NULL);
if (py_max_val == NULL) {{
{fail};
}}
max_val = PyLong_AsLongLong(py_max_val);
Py_DECREF(py_max_val);
if (max_val == -1 && PyErr_Occurred()) {{
{fail};
}}
if (min_val < NPY_MIN_INTP || max_val > NPY_MAX_INTP) {{
PyErr_SetString(PyExc_IndexError,
"Index contains values "
"that are bigger than the maximum array "
"size on this system.");
{fail};
}}
}}
indices = (PyArrayObject*) PyArray_Cast({i_name}, NPY_INTP);
if (indices == NULL) {{
{fail};
}}
}}
else {{
indices = {i_name};
Py_INCREF(indices);
}}
if ({output_name} != NULL) {{ if ({output_name} != NULL) {{
npy_intp nd, i, *shape; npy_intp nd, i, *shape;
nd = PyArray_NDIM({a_name}) + PyArray_NDIM(indices) - 1; nd = PyArray_NDIM({a_name}) + PyArray_NDIM({i_name}) - 1;
if (PyArray_NDIM({output_name}) != nd) {{ if (PyArray_NDIM({output_name}) != nd) {{
Py_CLEAR({output_name}); Py_CLEAR({output_name});
}} }}
else {{ else {{
shape = PyArray_DIMS({output_name}); shape = PyArray_DIMS({output_name});
for (i = 0; i < PyArray_NDIM(indices); i++) {{ for (i = 0; i < PyArray_NDIM({i_name}); i++) {{
if (shape[i] != PyArray_DIMS(indices)[i]) {{ if (shape[i] != PyArray_DIMS({i_name})[i]) {{
Py_CLEAR({output_name}); Py_CLEAR({output_name});
break; break;
}} }}
...@@ -2269,7 +2195,7 @@ class AdvancedSubtensor1(COp): ...@@ -2269,7 +2195,7 @@ class AdvancedSubtensor1(COp):
if ({output_name} != NULL) {{ if ({output_name} != NULL) {{
for (; i < nd; i++) {{ for (; i < nd; i++) {{
if (shape[i] != PyArray_DIMS({a_name})[ if (shape[i] != PyArray_DIMS({a_name})[
i-PyArray_NDIM(indices)+1]) {{ i-PyArray_NDIM({i_name})+1]) {{
Py_CLEAR({output_name}); Py_CLEAR({output_name});
break; break;
}} }}
...@@ -2278,13 +2204,12 @@ class AdvancedSubtensor1(COp): ...@@ -2278,13 +2204,12 @@ class AdvancedSubtensor1(COp):
}} }}
}} }}
{output_name} = (PyArrayObject*)PyArray_TakeFrom( {output_name} = (PyArrayObject*)PyArray_TakeFrom(
{a_name}, (PyObject*)indices, 0, {output_name}, NPY_RAISE); {a_name}, (PyObject*){i_name}, 0, {output_name}, NPY_RAISE);
Py_DECREF(indices);
if ({output_name} == NULL) {fail}; if ({output_name} == NULL) {fail};
""" """
def c_code_cache_version(self): def c_code_cache_version(self):
return (0, 1, 2, 3) return (4,)
advanced_subtensor1 = AdvancedSubtensor1() advanced_subtensor1 = AdvancedSubtensor1()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论