提交 a14cb2bd authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Do not use Numba objmode for supported AdvancedSubtensor operations

Use ScalarTypes in MakeSlice for compatibility with Numba
上级 a9c52dd1
import warnings
import numba
import numpy as np import numpy as np
from pytensor.graph import Type from pytensor.graph import Type
from pytensor.link.numba.dispatch import numba_funcify from pytensor.link.numba.dispatch import numba_funcify
from pytensor.link.numba.dispatch.basic import numba_njit from pytensor.link.numba.dispatch.basic import generate_fallback_impl, numba_njit
from pytensor.link.utils import compile_function_src, unique_name_generator from pytensor.link.utils import compile_function_src, unique_name_generator
from pytensor.tensor import TensorType
from pytensor.tensor.subtensor import ( from pytensor.tensor.subtensor import (
AdvancedIncSubtensor, AdvancedIncSubtensor,
AdvancedIncSubtensor1, AdvancedIncSubtensor1,
...@@ -17,7 +15,10 @@ from pytensor.tensor.subtensor import ( ...@@ -17,7 +15,10 @@ from pytensor.tensor.subtensor import (
) )
def create_index_func(node, objmode=False): @numba_funcify.register(Subtensor)
@numba_funcify.register(IncSubtensor)
@numba_funcify.register(AdvancedSubtensor1)
def numba_funcify_default_subtensor(op, node, **kwargs):
"""Create a Python function that assembles and uses an index on an array.""" """Create a Python function that assembles and uses an index on an array."""
unique_names = unique_name_generator( unique_names = unique_name_generator(
...@@ -40,13 +41,13 @@ def create_index_func(node, objmode=False): ...@@ -40,13 +41,13 @@ def create_index_func(node, objmode=False):
raise ValueError() raise ValueError()
set_or_inc = isinstance( set_or_inc = isinstance(
node.op, IncSubtensor | AdvancedIncSubtensor1 | AdvancedIncSubtensor op, IncSubtensor | AdvancedIncSubtensor1 | AdvancedIncSubtensor
) )
index_start_idx = 1 + int(set_or_inc) index_start_idx = 1 + int(set_or_inc)
input_names = [unique_names(v, force_unique=True) for v in node.inputs] input_names = [unique_names(v, force_unique=True) for v in node.inputs]
op_indices = list(node.inputs[index_start_idx:]) op_indices = list(node.inputs[index_start_idx:])
idx_list = getattr(node.op, "idx_list", None) idx_list = getattr(op, "idx_list", None)
indices_creation_src = ( indices_creation_src = (
tuple(convert_indices(op_indices, idx) for idx in idx_list) tuple(convert_indices(op_indices, idx) for idx in idx_list)
...@@ -61,8 +62,7 @@ def create_index_func(node, objmode=False): ...@@ -61,8 +62,7 @@ def create_index_func(node, objmode=False):
indices_creation_src = f"indices = ({indices_creation_src})" indices_creation_src = f"indices = ({indices_creation_src})"
if set_or_inc: if set_or_inc:
fn_name = "incsubtensor" if op.inplace:
if node.op.inplace:
index_prologue = f"z = {input_names[0]}" index_prologue = f"z = {input_names[0]}"
else: else:
index_prologue = f"z = np.copy({input_names[0]})" index_prologue = f"z = np.copy({input_names[0]})"
...@@ -74,84 +74,57 @@ def create_index_func(node, objmode=False): ...@@ -74,84 +74,57 @@ def create_index_func(node, objmode=False):
else: else:
y_name = input_names[1] y_name = input_names[1]
if node.op.set_instead_of_inc: if op.set_instead_of_inc:
function_name = "setsubtensor"
index_body = f"z[indices] = {y_name}" index_body = f"z[indices] = {y_name}"
else: else:
function_name = "incsubtensor"
index_body = f"z[indices] += {y_name}" index_body = f"z[indices] += {y_name}"
else: else:
fn_name = "subtensor" function_name = "subtensor"
index_prologue = "" index_prologue = ""
index_body = f"z = {input_names[0]}[indices]" index_body = f"z = {input_names[0]}[indices]"
if objmode:
output_var = node.outputs[0]
if not set_or_inc:
# Since `z` is being "created" while in object mode, it's
# considered an "outgoing" variable and needs to be manually typed
output_sig = f"z='{output_var.dtype}[{', '.join([':'] * output_var.ndim)}]'"
else:
output_sig = ""
index_body = f"""
with objmode({output_sig}):
{index_body}
"""
subtensor_def_src = f""" subtensor_def_src = f"""
def {fn_name}({", ".join(input_names)}): def {function_name}({", ".join(input_names)}):
{index_prologue} {index_prologue}
{indices_creation_src} {indices_creation_src}
{index_body} {index_body}
return np.asarray(z) return np.asarray(z)
""" """
return subtensor_def_src func = compile_function_src(
subtensor_def_src,
function_name=function_name,
@numba_funcify.register(Subtensor) global_env=globals() | {"np": np},
@numba_funcify.register(AdvancedSubtensor1)
def numba_funcify_Subtensor(op, node, **kwargs):
objmode = isinstance(op, AdvancedSubtensor)
if objmode:
warnings.warn(
("Numba will use object mode to allow run " "AdvancedSubtensor."),
UserWarning,
)
subtensor_def_src = create_index_func(node, objmode=objmode)
global_env = {"np": np}
if objmode:
global_env["objmode"] = numba.objmode
subtensor_fn = compile_function_src(
subtensor_def_src, "subtensor", {**globals(), **global_env}
) )
return numba_njit(func, boundscheck=True)
return numba_njit(subtensor_fn, boundscheck=True)
@numba_funcify.register(AdvancedSubtensor)
@numba_funcify.register(IncSubtensor) @numba_funcify.register(AdvancedIncSubtensor)
def numba_funcify_IncSubtensor(op, node, **kwargs): def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
objmode = isinstance(op, AdvancedIncSubtensor) idxs = node.inputs[1:] if isinstance(op, AdvancedSubtensor) else node.inputs[2:]
if objmode: adv_idxs_dims = [
warnings.warn( idx.type.ndim
("Numba will use object mode to allow run " "AdvancedIncSubtensor."), for idx in idxs
UserWarning, if (isinstance(idx.type, TensorType) and idx.type.ndim > 0)
) ]
incsubtensor_def_src = create_index_func(node, objmode=objmode) if (
# Numba does not support indexes with more than one dimension
global_env = {"np": np} # Nor multiple vector indexes
if objmode: (len(adv_idxs_dims) > 1 or adv_idxs_dims[0] > 1)
global_env["objmode"] = numba.objmode # The default index implementation does not handle duplicate indices correctly
or (
incsubtensor_fn = compile_function_src( isinstance(op, AdvancedIncSubtensor)
incsubtensor_def_src, "incsubtensor", {**globals(), **global_env} and not op.set_instead_of_inc
and not op.ignore_duplicates
) )
):
return generate_fallback_impl(op, node, **kwargs)
return numba_njit(incsubtensor_fn, boundscheck=True) return numba_funcify_default_subtensor(op, node, **kwargs)
@numba_funcify.register(AdvancedIncSubtensor1) @numba_funcify.register(AdvancedIncSubtensor1)
......
...@@ -21,7 +21,12 @@ from pytensor.misc.safe_asarray import _asarray ...@@ -21,7 +21,12 @@ from pytensor.misc.safe_asarray import _asarray
from pytensor.printing import Printer, pprint, set_precedence from pytensor.printing import Printer, pprint, set_precedence
from pytensor.scalar.basic import ScalarConstant from pytensor.scalar.basic import ScalarConstant
from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length
from pytensor.tensor.basic import alloc, get_underlying_scalar_constant_value, nonzero from pytensor.tensor.basic import (
ScalarFromTensor,
alloc,
get_underlying_scalar_constant_value,
nonzero,
)
from pytensor.tensor.blockwise import vectorize_node_fallback from pytensor.tensor.blockwise import vectorize_node_fallback
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.exceptions import AdvancedIndexingError, NotScalarConstantError from pytensor.tensor.exceptions import AdvancedIndexingError, NotScalarConstantError
...@@ -168,7 +173,15 @@ def as_index_literal( ...@@ -168,7 +173,15 @@ def as_index_literal(
if isinstance(idx, Constant): if isinstance(idx, Constant):
return idx.data.item() if isinstance(idx, np.ndarray) else idx.data return idx.data.item() if isinstance(idx, np.ndarray) else idx.data
if isinstance(getattr(idx, "type", None), SliceType): if isinstance(idx, Variable):
if (
isinstance(idx.type, ps.ScalarType)
and idx.owner
and isinstance(idx.owner.op, ScalarFromTensor)
):
return as_index_literal(idx.owner.inputs[0])
if isinstance(idx.type, SliceType):
idx = slice(*idx.owner.inputs) idx = slice(*idx.owner.inputs)
if isinstance(idx, slice): if isinstance(idx, slice):
......
...@@ -18,7 +18,7 @@ def as_int_none_variable(x): ...@@ -18,7 +18,7 @@ def as_int_none_variable(x):
return NoneConst return NoneConst
elif NoneConst.equals(x): elif NoneConst.equals(x):
return x return x
x = pytensor.tensor.as_tensor_variable(x, ndim=0) x = pytensor.scalar.as_scalar(x)
if x.type.dtype not in integer_dtypes: if x.type.dtype not in integer_dtypes:
raise TypeError("index must be integers") raise TypeError("index must be integers")
return x return x
......
import contextlib
import numpy as np import numpy as np
import pytest import pytest
import pytensor.tensor as pt
from pytensor.graph import FunctionGraph from pytensor.graph import FunctionGraph
from pytensor.tensor import as_tensor from pytensor.tensor import as_tensor
from pytensor.tensor.subtensor import ( from pytensor.tensor.subtensor import (
...@@ -48,8 +51,8 @@ def test_Subtensor(x, indices): ...@@ -48,8 +51,8 @@ def test_Subtensor(x, indices):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"x, indices", "x, indices",
[ [
(as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2],)), (pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2],)),
(as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1],)), (pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1],)),
], ],
) )
def test_AdvancedSubtensor1(x, indices): def test_AdvancedSubtensor1(x, indices):
...@@ -69,20 +72,45 @@ def test_AdvancedSubtensor1_out_of_bounds(): ...@@ -69,20 +72,45 @@ def test_AdvancedSubtensor1_out_of_bounds():
@pytest.mark.parametrize( @pytest.mark.parametrize(
"x, indices", "x, indices, objmode_needed",
[ [
(as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2], [2, 3])), (
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
(0, [1, 2, 2, 3]),
False,
),
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
(np.array([True, False, False])),
False,
),
(
as_tensor(np.arange(3 * 3).reshape((3, 3))),
(np.eye(3).astype(bool)),
True,
),
(pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2], [2, 3]), True),
( (
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
([1, 2], slice(None), [3, 4]), ([1, 2], slice(None), [3, 4]),
True,
), ),
], ],
) )
def test_AdvancedSubtensor(x, indices): @pytest.mark.filterwarnings("error")
def test_AdvancedSubtensor(x, indices, objmode_needed):
"""Test NumPy's advanced indexing in more than one dimension.""" """Test NumPy's advanced indexing in more than one dimension."""
out_pt = x[indices] out_pt = x[indices]
assert isinstance(out_pt.owner.op, AdvancedSubtensor) assert isinstance(out_pt.owner.op, AdvancedSubtensor)
out_fg = FunctionGraph([], [out_pt]) out_fg = FunctionGraph([], [out_pt])
with (
pytest.warns(
UserWarning,
match="Numba will use object mode to run AdvancedSubtensor's perform method",
)
if objmode_needed
else contextlib.nullcontext()
):
compare_numba_and_py(out_fg, []) compare_numba_and_py(out_fg, [])
...@@ -194,34 +222,119 @@ def test_AdvancedIncSubtensor1(x, y, indices): ...@@ -194,34 +222,119 @@ def test_AdvancedIncSubtensor1(x, y, indices):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"x, y, indices", "x, y, indices, duplicate_indices, set_requires_objmode, inc_requires_objmode",
[ [
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
-np.arange(3 * 5).reshape(3, 5),
(slice(None, None, 2), [1, 2, 3]),
False,
False,
False,
),
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
-99,
(slice(None, None, 2), [1, 2, 3], -1),
False,
False,
False,
),
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
-99, # Broadcasted value
(slice(None, None, 2), [1, 2, 3]),
False,
False,
False,
),
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
-np.arange(4 * 5).reshape(4, 5),
(0, [1, 2, 2, 3]),
True,
False,
True,
),
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
[-99], # Broadcsasted value
(0, [1, 2, 2, 3]),
True,
False,
True,
),
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
-np.arange(1 * 4 * 5).reshape(1, 4, 5),
(np.array([True, False, False])),
False,
False,
False,
),
(
as_tensor(np.arange(3 * 3).reshape((3, 3))),
-np.arange(3),
(np.eye(3).astype(bool)),
False,
True,
True,
),
( (
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
as_tensor(rng.poisson(size=(2, 5))), as_tensor(rng.poisson(size=(2, 5))),
([1, 2], [2, 3]), ([1, 2], [2, 3]),
False,
True,
True,
), ),
( (
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
as_tensor(rng.poisson(size=(2, 4))), as_tensor(rng.poisson(size=(2, 4))),
([1, 2], slice(None), [3, 4]), ([1, 2], slice(None), [3, 4]),
False,
True,
True,
), ),
pytest.param( pytest.param(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
as_tensor(rng.poisson(size=(2, 5))), as_tensor(rng.poisson(size=(2, 5))),
([1, 1], [2, 2]), ([1, 1], [2, 2]),
False,
True,
True,
), ),
], ],
) )
def test_AdvancedIncSubtensor(x, y, indices): @pytest.mark.filterwarnings("error")
def test_AdvancedIncSubtensor(
x, y, indices, duplicate_indices, set_requires_objmode, inc_requires_objmode
):
out_pt = set_subtensor(x[indices], y) out_pt = set_subtensor(x[indices], y)
assert isinstance(out_pt.owner.op, AdvancedIncSubtensor) assert isinstance(out_pt.owner.op, AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_pt]) out_fg = FunctionGraph([], [out_pt])
with (
pytest.warns(
UserWarning,
match="Numba will use object mode to run AdvancedSetSubtensor's perform method",
)
if set_requires_objmode
else contextlib.nullcontext()
):
compare_numba_and_py(out_fg, []) compare_numba_and_py(out_fg, [])
out_pt = inc_subtensor(x[indices], y) out_pt = inc_subtensor(x[indices], y, ignore_duplicates=not duplicate_indices)
assert isinstance(out_pt.owner.op, AdvancedIncSubtensor) assert isinstance(out_pt.owner.op, AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_pt]) out_fg = FunctionGraph([], [out_pt])
with (
pytest.warns(
UserWarning,
match="Numba will use object mode to run AdvancedIncSubtensor's perform method",
)
if inc_requires_objmode
else contextlib.nullcontext()
):
compare_numba_and_py(out_fg, []) compare_numba_and_py(out_fg, [])
x_pt = x.type() x_pt = x.type()
...@@ -231,4 +344,12 @@ def test_AdvancedIncSubtensor(x, y, indices): ...@@ -231,4 +344,12 @@ def test_AdvancedIncSubtensor(x, y, indices):
out_pt.owner.op.inplace = True out_pt.owner.op.inplace = True
assert isinstance(out_pt.owner.op, AdvancedIncSubtensor) assert isinstance(out_pt.owner.op, AdvancedIncSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt]) out_fg = FunctionGraph([x_pt], [out_pt])
with (
pytest.warns(
UserWarning,
match="Numba will use object mode to run AdvancedSetSubtensor's perform method",
)
if set_requires_objmode
else contextlib.nullcontext()
):
compare_numba_and_py(out_fg, [x.data]) compare_numba_and_py(out_fg, [x.data])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论