提交 a14cb2bd authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Do not use Numba objmode for supported AdvancedSubtensor operations

Use ScalarTypes in MakeSlice for compatibility with Numba
上级 a9c52dd1
import warnings
import numba
import numpy as np
from pytensor.graph import Type
from pytensor.link.numba.dispatch import numba_funcify
from pytensor.link.numba.dispatch.basic import numba_njit
from pytensor.link.numba.dispatch.basic import generate_fallback_impl, numba_njit
from pytensor.link.utils import compile_function_src, unique_name_generator
from pytensor.tensor import TensorType
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
......@@ -17,7 +15,10 @@ from pytensor.tensor.subtensor import (
)
def create_index_func(node, objmode=False):
@numba_funcify.register(Subtensor)
@numba_funcify.register(IncSubtensor)
@numba_funcify.register(AdvancedSubtensor1)
def numba_funcify_default_subtensor(op, node, **kwargs):
"""Create a Python function that assembles and uses an index on an array."""
unique_names = unique_name_generator(
......@@ -40,13 +41,13 @@ def create_index_func(node, objmode=False):
raise ValueError()
set_or_inc = isinstance(
node.op, IncSubtensor | AdvancedIncSubtensor1 | AdvancedIncSubtensor
op, IncSubtensor | AdvancedIncSubtensor1 | AdvancedIncSubtensor
)
index_start_idx = 1 + int(set_or_inc)
input_names = [unique_names(v, force_unique=True) for v in node.inputs]
op_indices = list(node.inputs[index_start_idx:])
idx_list = getattr(node.op, "idx_list", None)
idx_list = getattr(op, "idx_list", None)
indices_creation_src = (
tuple(convert_indices(op_indices, idx) for idx in idx_list)
......@@ -61,8 +62,7 @@ def create_index_func(node, objmode=False):
indices_creation_src = f"indices = ({indices_creation_src})"
if set_or_inc:
fn_name = "incsubtensor"
if node.op.inplace:
if op.inplace:
index_prologue = f"z = {input_names[0]}"
else:
index_prologue = f"z = np.copy({input_names[0]})"
......@@ -74,84 +74,57 @@ def create_index_func(node, objmode=False):
else:
y_name = input_names[1]
if node.op.set_instead_of_inc:
if op.set_instead_of_inc:
function_name = "setsubtensor"
index_body = f"z[indices] = {y_name}"
else:
function_name = "incsubtensor"
index_body = f"z[indices] += {y_name}"
else:
fn_name = "subtensor"
function_name = "subtensor"
index_prologue = ""
index_body = f"z = {input_names[0]}[indices]"
if objmode:
output_var = node.outputs[0]
if not set_or_inc:
# Since `z` is being "created" while in object mode, it's
# considered an "outgoing" variable and needs to be manually typed
output_sig = f"z='{output_var.dtype}[{', '.join([':'] * output_var.ndim)}]'"
else:
output_sig = ""
index_body = f"""
with objmode({output_sig}):
{index_body}
"""
subtensor_def_src = f"""
def {fn_name}({", ".join(input_names)}):
def {function_name}({", ".join(input_names)}):
{index_prologue}
{indices_creation_src}
{index_body}
return np.asarray(z)
"""
return subtensor_def_src
@numba_funcify.register(Subtensor)
@numba_funcify.register(AdvancedSubtensor1)
def numba_funcify_Subtensor(op, node, **kwargs):
objmode = isinstance(op, AdvancedSubtensor)
if objmode:
warnings.warn(
("Numba will use object mode to allow run " "AdvancedSubtensor."),
UserWarning,
)
subtensor_def_src = create_index_func(node, objmode=objmode)
global_env = {"np": np}
if objmode:
global_env["objmode"] = numba.objmode
subtensor_fn = compile_function_src(
subtensor_def_src, "subtensor", {**globals(), **global_env}
func = compile_function_src(
subtensor_def_src,
function_name=function_name,
global_env=globals() | {"np": np},
)
return numba_njit(subtensor_fn, boundscheck=True)
@numba_funcify.register(IncSubtensor)
def numba_funcify_IncSubtensor(op, node, **kwargs):
objmode = isinstance(op, AdvancedIncSubtensor)
if objmode:
warnings.warn(
("Numba will use object mode to allow run " "AdvancedIncSubtensor."),
UserWarning,
return numba_njit(func, boundscheck=True)
@numba_funcify.register(AdvancedSubtensor)
@numba_funcify.register(AdvancedIncSubtensor)
def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
idxs = node.inputs[1:] if isinstance(op, AdvancedSubtensor) else node.inputs[2:]
adv_idxs_dims = [
idx.type.ndim
for idx in idxs
if (isinstance(idx.type, TensorType) and idx.type.ndim > 0)
]
if (
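# Numba does not support index arrays with more than one dimension,
# nor more than one vector index.
# NOTE(review): `adv_idxs_dims[0]` below raises IndexError if no index input is a TensorType with ndim > 0 — confirm this cannot happen for these Ops.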
(len(adv_idxs_dims) > 1 or adv_idxs_dims[0] > 1)
# The default index implementation does not handle duplicate indices correctly
or (
isinstance(op, AdvancedIncSubtensor)
and not op.set_instead_of_inc
and not op.ignore_duplicates
)
):
return generate_fallback_impl(op, node, **kwargs)
incsubtensor_def_src = create_index_func(node, objmode=objmode)
global_env = {"np": np}
if objmode:
global_env["objmode"] = numba.objmode
incsubtensor_fn = compile_function_src(
incsubtensor_def_src, "incsubtensor", {**globals(), **global_env}
)
return numba_njit(incsubtensor_fn, boundscheck=True)
return numba_funcify_default_subtensor(op, node, **kwargs)
@numba_funcify.register(AdvancedIncSubtensor1)
......
......@@ -21,7 +21,12 @@ from pytensor.misc.safe_asarray import _asarray
from pytensor.printing import Printer, pprint, set_precedence
from pytensor.scalar.basic import ScalarConstant
from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length
from pytensor.tensor.basic import alloc, get_underlying_scalar_constant_value, nonzero
from pytensor.tensor.basic import (
ScalarFromTensor,
alloc,
get_underlying_scalar_constant_value,
nonzero,
)
from pytensor.tensor.blockwise import vectorize_node_fallback
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.exceptions import AdvancedIndexingError, NotScalarConstantError
......@@ -168,8 +173,16 @@ def as_index_literal(
if isinstance(idx, Constant):
return idx.data.item() if isinstance(idx, np.ndarray) else idx.data
if isinstance(getattr(idx, "type", None), SliceType):
idx = slice(*idx.owner.inputs)
if isinstance(idx, Variable):
if (
isinstance(idx.type, ps.ScalarType)
and idx.owner
and isinstance(idx.owner.op, ScalarFromTensor)
):
return as_index_literal(idx.owner.inputs[0])
if isinstance(idx.type, SliceType):
idx = slice(*idx.owner.inputs)
if isinstance(idx, slice):
return slice(
......
......@@ -18,7 +18,7 @@ def as_int_none_variable(x):
return NoneConst
elif NoneConst.equals(x):
return x
x = pytensor.tensor.as_tensor_variable(x, ndim=0)
x = pytensor.scalar.as_scalar(x)
if x.type.dtype not in integer_dtypes:
raise TypeError("index must be integers")
return x
......
import contextlib
import numpy as np
import pytest
import pytensor.tensor as pt
from pytensor.graph import FunctionGraph
from pytensor.tensor import as_tensor
from pytensor.tensor.subtensor import (
......@@ -48,8 +51,8 @@ def test_Subtensor(x, indices):
@pytest.mark.parametrize(
"x, indices",
[
(as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2],)),
(as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1],)),
(pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2],)),
(pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1],)),
],
)
def test_AdvancedSubtensor1(x, indices):
......@@ -69,21 +72,46 @@ def test_AdvancedSubtensor1_out_of_bounds():
@pytest.mark.parametrize(
"x, indices",
"x, indices, objmode_needed",
[
(as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2], [2, 3])),
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
(0, [1, 2, 2, 3]),
False,
),
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
(np.array([True, False, False])),
False,
),
(
as_tensor(np.arange(3 * 3).reshape((3, 3))),
(np.eye(3).astype(bool)),
True,
),
(pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2], [2, 3]), True),
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
([1, 2], slice(None), [3, 4]),
True,
),
],
)
def test_AdvancedSubtensor(x, indices):
@pytest.mark.filterwarnings("error")
def test_AdvancedSubtensor(x, indices, objmode_needed):
"""Test NumPy's advanced indexing in more than one dimension."""
out_pt = x[indices]
assert isinstance(out_pt.owner.op, AdvancedSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_numba_and_py(out_fg, [])
with (
pytest.warns(
UserWarning,
match="Numba will use object mode to run AdvancedSubtensor's perform method",
)
if objmode_needed
else contextlib.nullcontext()
):
compare_numba_and_py(out_fg, [])
@pytest.mark.parametrize(
......@@ -194,35 +222,120 @@ def test_AdvancedIncSubtensor1(x, y, indices):
@pytest.mark.parametrize(
"x, y, indices",
"x, y, indices, duplicate_indices, set_requires_objmode, inc_requires_objmode",
[
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
-np.arange(3 * 5).reshape(3, 5),
(slice(None, None, 2), [1, 2, 3]),
False,
False,
False,
),
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
-99,
(slice(None, None, 2), [1, 2, 3], -1),
False,
False,
False,
),
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
-99, # Broadcasted value
(slice(None, None, 2), [1, 2, 3]),
False,
False,
False,
),
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
-np.arange(4 * 5).reshape(4, 5),
(0, [1, 2, 2, 3]),
True,
False,
True,
),
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
[-99], # Broadcasted value
(0, [1, 2, 2, 3]),
True,
False,
True,
),
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
-np.arange(1 * 4 * 5).reshape(1, 4, 5),
(np.array([True, False, False])),
False,
False,
False,
),
(
as_tensor(np.arange(3 * 3).reshape((3, 3))),
-np.arange(3),
(np.eye(3).astype(bool)),
False,
True,
True,
),
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
as_tensor(rng.poisson(size=(2, 5))),
([1, 2], [2, 3]),
False,
True,
True,
),
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
as_tensor(rng.poisson(size=(2, 4))),
([1, 2], slice(None), [3, 4]),
False,
True,
True,
),
pytest.param(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
as_tensor(rng.poisson(size=(2, 5))),
([1, 1], [2, 2]),
False,
True,
True,
),
],
)
def test_AdvancedIncSubtensor(x, y, indices):
@pytest.mark.filterwarnings("error")
def test_AdvancedIncSubtensor(
x, y, indices, duplicate_indices, set_requires_objmode, inc_requires_objmode
):
out_pt = set_subtensor(x[indices], y)
assert isinstance(out_pt.owner.op, AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_numba_and_py(out_fg, [])
out_pt = inc_subtensor(x[indices], y)
with (
pytest.warns(
UserWarning,
match="Numba will use object mode to run AdvancedSetSubtensor's perform method",
)
if set_requires_objmode
else contextlib.nullcontext()
):
compare_numba_and_py(out_fg, [])
out_pt = inc_subtensor(x[indices], y, ignore_duplicates=not duplicate_indices)
assert isinstance(out_pt.owner.op, AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_numba_and_py(out_fg, [])
with (
pytest.warns(
UserWarning,
match="Numba will use object mode to run AdvancedIncSubtensor's perform method",
)
if inc_requires_objmode
else contextlib.nullcontext()
):
compare_numba_and_py(out_fg, [])
x_pt = x.type()
out_pt = set_subtensor(x_pt[indices], y)
......@@ -231,4 +344,12 @@ def test_AdvancedIncSubtensor(x, y, indices):
out_pt.owner.op.inplace = True
assert isinstance(out_pt.owner.op, AdvancedIncSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_numba_and_py(out_fg, [x.data])
with (
pytest.warns(
UserWarning,
match="Numba will use object mode to run AdvancedSetSubtensor's perform method",
)
if set_requires_objmode
else contextlib.nullcontext()
):
compare_numba_and_py(out_fg, [x.data])
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论