提交 a9c52dd1 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Move numba subtensor functionality to its own module

上级 0353abeb
......@@ -11,5 +11,6 @@ import pytensor.link.numba.dispatch.elemwise
import pytensor.link.numba.dispatch.scan
import pytensor.link.numba.dispatch.sparse
import pytensor.link.numba.dispatch.slinalg
import pytensor.link.numba.dispatch.subtensor
# isort: on
......@@ -29,7 +29,6 @@ from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType
from pytensor.link.utils import (
compile_function_src,
fgraph_to_python,
unique_name_generator,
)
from pytensor.scalar.basic import ScalarType
from pytensor.scalar.math import Softplus
......@@ -38,14 +37,6 @@ from pytensor.tensor.blas import BatchedDot
from pytensor.tensor.math import Dot
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
from pytensor.tensor.slinalg import Solve
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
AdvancedSubtensor,
AdvancedSubtensor1,
IncSubtensor,
Subtensor,
)
from pytensor.tensor.type import TensorType
from pytensor.tensor.type_other import MakeSlice, NoneConst
......@@ -479,217 +470,6 @@ def numba_funcify_FunctionGraph(
)
def create_index_func(node, objmode=False):
"""Create a Python function that assembles and uses an index on an array."""
unique_names = unique_name_generator(
["subtensor", "incsubtensor", "z"], suffix_sep="_"
)
def convert_indices(indices, entry):
if indices and isinstance(entry, Type):
rval = indices.pop(0)
return unique_names(rval)
elif isinstance(entry, slice):
return (
f"slice({convert_indices(indices, entry.start)}, "
f"{convert_indices(indices, entry.stop)}, "
f"{convert_indices(indices, entry.step)})"
)
elif isinstance(entry, type(None)):
return "None"
else:
raise ValueError()
set_or_inc = isinstance(
node.op, IncSubtensor | AdvancedIncSubtensor1 | AdvancedIncSubtensor
)
index_start_idx = 1 + int(set_or_inc)
input_names = [unique_names(v, force_unique=True) for v in node.inputs]
op_indices = list(node.inputs[index_start_idx:])
idx_list = getattr(node.op, "idx_list", None)
indices_creation_src = (
tuple(convert_indices(op_indices, idx) for idx in idx_list)
if idx_list
else tuple(input_names[index_start_idx:])
)
if len(indices_creation_src) == 1:
indices_creation_src = f"indices = ({indices_creation_src[0]},)"
else:
indices_creation_src = ", ".join(indices_creation_src)
indices_creation_src = f"indices = ({indices_creation_src})"
if set_or_inc:
fn_name = "incsubtensor"
if node.op.inplace:
index_prologue = f"z = {input_names[0]}"
else:
index_prologue = f"z = np.copy({input_names[0]})"
if node.inputs[1].ndim == 0:
# TODO FIXME: This is a hack to get around a weird Numba typing
# issue. See https://github.com/numba/numba/issues/6000
y_name = f"{input_names[1]}.item()"
else:
y_name = input_names[1]
if node.op.set_instead_of_inc:
index_body = f"z[indices] = {y_name}"
else:
index_body = f"z[indices] += {y_name}"
else:
fn_name = "subtensor"
index_prologue = ""
index_body = f"z = {input_names[0]}[indices]"
if objmode:
output_var = node.outputs[0]
if not set_or_inc:
# Since `z` is being "created" while in object mode, it's
# considered an "outgoing" variable and needs to be manually typed
output_sig = f"z='{output_var.dtype}[{', '.join([':'] * output_var.ndim)}]'"
else:
output_sig = ""
index_body = f"""
with objmode({output_sig}):
{index_body}
"""
subtensor_def_src = f"""
def {fn_name}({", ".join(input_names)}):
{index_prologue}
{indices_creation_src}
{index_body}
return np.asarray(z)
"""
return subtensor_def_src
@numba_funcify.register(Subtensor)
@numba_funcify.register(AdvancedSubtensor1)
def numba_funcify_Subtensor(op, node, **kwargs):
objmode = isinstance(op, AdvancedSubtensor)
if objmode:
warnings.warn(
("Numba will use object mode to allow run " "AdvancedSubtensor."),
UserWarning,
)
subtensor_def_src = create_index_func(node, objmode=objmode)
global_env = {"np": np}
if objmode:
global_env["objmode"] = numba.objmode
subtensor_fn = compile_function_src(
subtensor_def_src, "subtensor", {**globals(), **global_env}
)
return numba_njit(subtensor_fn, boundscheck=True)
@numba_funcify.register(IncSubtensor)
def numba_funcify_IncSubtensor(op, node, **kwargs):
objmode = isinstance(op, AdvancedIncSubtensor)
if objmode:
warnings.warn(
("Numba will use object mode to allow run " "AdvancedIncSubtensor."),
UserWarning,
)
incsubtensor_def_src = create_index_func(node, objmode=objmode)
global_env = {"np": np}
if objmode:
global_env["objmode"] = numba.objmode
incsubtensor_fn = compile_function_src(
incsubtensor_def_src, "incsubtensor", {**globals(), **global_env}
)
return numba_njit(incsubtensor_fn, boundscheck=True)
@numba_funcify.register(AdvancedIncSubtensor1)
def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
inplace = op.inplace
set_instead_of_inc = op.set_instead_of_inc
x, vals, idxs = node.inputs
# TODO: Add explicit expand_dims in make_node so we don't need to worry about this here
broadcast = vals.type.ndim < x.type.ndim or vals.type.broadcastable[0]
if set_instead_of_inc:
if broadcast:
@numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace(x, val, idxs):
if val.ndim == x.ndim:
core_val = val[0]
elif val.ndim == 0:
# Workaround for https://github.com/numba/numba/issues/9573
core_val = val.item()
else:
core_val = val
for idx in idxs:
x[idx] = core_val
return x
else:
@numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace(x, vals, idxs):
if not len(idxs) == len(vals):
raise ValueError("The number of indices and values must match.")
for idx, val in zip(idxs, vals):
x[idx] = val
return x
else:
if broadcast:
@numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace(x, val, idxs):
if val.ndim == x.ndim:
core_val = val[0]
elif val.ndim == 0:
# Workaround for https://github.com/numba/numba/issues/9573
core_val = val.item()
else:
core_val = val
for idx in idxs:
x[idx] += core_val
return x
else:
@numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace(x, vals, idxs):
if not len(idxs) == len(vals):
raise ValueError("The number of indices and values must match.")
for idx, val in zip(idxs, vals):
x[idx] += val
return x
if inplace:
return advancedincsubtensor1_inplace
else:
@numba_njit
def advancedincsubtensor1(x, vals, idxs):
x = x.copy()
return advancedincsubtensor1_inplace(x, vals, idxs)
return advancedincsubtensor1
def deepcopyop(x):
return copy(x)
......
import warnings
import numba
import numpy as np
from pytensor.graph import Type
from pytensor.link.numba.dispatch import numba_funcify
from pytensor.link.numba.dispatch.basic import numba_njit
from pytensor.link.utils import compile_function_src, unique_name_generator
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
AdvancedSubtensor,
AdvancedSubtensor1,
IncSubtensor,
Subtensor,
)
def create_index_func(node, objmode=False):
"""Create a Python function that assembles and uses an index on an array."""
unique_names = unique_name_generator(
["subtensor", "incsubtensor", "z"], suffix_sep="_"
)
def convert_indices(indices, entry):
if indices and isinstance(entry, Type):
rval = indices.pop(0)
return unique_names(rval)
elif isinstance(entry, slice):
return (
f"slice({convert_indices(indices, entry.start)}, "
f"{convert_indices(indices, entry.stop)}, "
f"{convert_indices(indices, entry.step)})"
)
elif isinstance(entry, type(None)):
return "None"
else:
raise ValueError()
set_or_inc = isinstance(
node.op, IncSubtensor | AdvancedIncSubtensor1 | AdvancedIncSubtensor
)
index_start_idx = 1 + int(set_or_inc)
input_names = [unique_names(v, force_unique=True) for v in node.inputs]
op_indices = list(node.inputs[index_start_idx:])
idx_list = getattr(node.op, "idx_list", None)
indices_creation_src = (
tuple(convert_indices(op_indices, idx) for idx in idx_list)
if idx_list
else tuple(input_names[index_start_idx:])
)
if len(indices_creation_src) == 1:
indices_creation_src = f"indices = ({indices_creation_src[0]},)"
else:
indices_creation_src = ", ".join(indices_creation_src)
indices_creation_src = f"indices = ({indices_creation_src})"
if set_or_inc:
fn_name = "incsubtensor"
if node.op.inplace:
index_prologue = f"z = {input_names[0]}"
else:
index_prologue = f"z = np.copy({input_names[0]})"
if node.inputs[1].ndim == 0:
# TODO FIXME: This is a hack to get around a weird Numba typing
# issue. See https://github.com/numba/numba/issues/6000
y_name = f"{input_names[1]}.item()"
else:
y_name = input_names[1]
if node.op.set_instead_of_inc:
index_body = f"z[indices] = {y_name}"
else:
index_body = f"z[indices] += {y_name}"
else:
fn_name = "subtensor"
index_prologue = ""
index_body = f"z = {input_names[0]}[indices]"
if objmode:
output_var = node.outputs[0]
if not set_or_inc:
# Since `z` is being "created" while in object mode, it's
# considered an "outgoing" variable and needs to be manually typed
output_sig = f"z='{output_var.dtype}[{', '.join([':'] * output_var.ndim)}]'"
else:
output_sig = ""
index_body = f"""
with objmode({output_sig}):
{index_body}
"""
subtensor_def_src = f"""
def {fn_name}({", ".join(input_names)}):
{index_prologue}
{indices_creation_src}
{index_body}
return np.asarray(z)
"""
return subtensor_def_src
@numba_funcify.register(Subtensor)
@numba_funcify.register(AdvancedSubtensor1)
def numba_funcify_Subtensor(op, node, **kwargs):
objmode = isinstance(op, AdvancedSubtensor)
if objmode:
warnings.warn(
("Numba will use object mode to allow run " "AdvancedSubtensor."),
UserWarning,
)
subtensor_def_src = create_index_func(node, objmode=objmode)
global_env = {"np": np}
if objmode:
global_env["objmode"] = numba.objmode
subtensor_fn = compile_function_src(
subtensor_def_src, "subtensor", {**globals(), **global_env}
)
return numba_njit(subtensor_fn, boundscheck=True)
@numba_funcify.register(IncSubtensor)
def numba_funcify_IncSubtensor(op, node, **kwargs):
objmode = isinstance(op, AdvancedIncSubtensor)
if objmode:
warnings.warn(
("Numba will use object mode to allow run " "AdvancedIncSubtensor."),
UserWarning,
)
incsubtensor_def_src = create_index_func(node, objmode=objmode)
global_env = {"np": np}
if objmode:
global_env["objmode"] = numba.objmode
incsubtensor_fn = compile_function_src(
incsubtensor_def_src, "incsubtensor", {**globals(), **global_env}
)
return numba_njit(incsubtensor_fn, boundscheck=True)
@numba_funcify.register(AdvancedIncSubtensor1)
def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
inplace = op.inplace
set_instead_of_inc = op.set_instead_of_inc
x, vals, idxs = node.inputs
# TODO: Add explicit expand_dims in make_node so we don't need to worry about this here
broadcast = vals.type.ndim < x.type.ndim or vals.type.broadcastable[0]
if set_instead_of_inc:
if broadcast:
@numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace(x, val, idxs):
if val.ndim == x.ndim:
core_val = val[0]
elif val.ndim == 0:
# Workaround for https://github.com/numba/numba/issues/9573
core_val = val.item()
else:
core_val = val
for idx in idxs:
x[idx] = core_val
return x
else:
@numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace(x, vals, idxs):
if not len(idxs) == len(vals):
raise ValueError("The number of indices and values must match.")
for idx, val in zip(idxs, vals):
x[idx] = val
return x
else:
if broadcast:
@numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace(x, val, idxs):
if val.ndim == x.ndim:
core_val = val[0]
elif val.ndim == 0:
# Workaround for https://github.com/numba/numba/issues/9573
core_val = val.item()
else:
core_val = val
for idx in idxs:
x[idx] += core_val
return x
else:
@numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace(x, vals, idxs):
if not len(idxs) == len(vals):
raise ValueError("The number of indices and values must match.")
for idx, val in zip(idxs, vals):
x[idx] += val
return x
if inplace:
return advancedincsubtensor1_inplace
else:
@numba_njit
def advancedincsubtensor1(x, vals, idxs):
x = x.copy()
return advancedincsubtensor1_inplace(x, vals, idxs)
return advancedincsubtensor1
......@@ -33,7 +33,6 @@ from pytensor.link.numba.linker import NumbaLinker
from pytensor.raise_op import assert_op
from pytensor.scalar.basic import ScalarOp, as_scalar
from pytensor.tensor import blas
from pytensor.tensor import subtensor as pt_subtensor
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
......@@ -366,218 +365,6 @@ def test_create_numba_signature(v, expected, force_scalar):
assert res == expected
@pytest.mark.parametrize(
"x, indices",
[
(pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), (1,)),
(
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
(slice(None)),
),
(pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), (1, 2, 0)),
(
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
(slice(1, 2), 1, slice(None)),
),
],
)
def test_Subtensor(x, indices):
"""Test NumPy's basic indexing."""
out_pt = x[indices]
assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
out_fg = FunctionGraph([], [out_pt])
compare_numba_and_py(out_fg, [])
@pytest.mark.parametrize(
"x, indices",
[
(pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2],)),
(pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1],)),
],
)
def test_AdvancedSubtensor1(x, indices):
"""Test NumPy's advanced indexing in one dimension."""
out_pt = pt_subtensor.advanced_subtensor1(x, *indices)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor1)
out_fg = FunctionGraph([], [out_pt])
compare_numba_and_py(out_fg, [])
def test_AdvancedSubtensor1_out_of_bounds():
out_pt = pt_subtensor.advanced_subtensor1(np.arange(3), [4])
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor1)
out_fg = FunctionGraph([], [out_pt])
with pytest.raises(IndexError):
compare_numba_and_py(out_fg, [])
@pytest.mark.parametrize(
"x, indices",
[
(pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2], [2, 3])),
(
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
([1, 2], slice(None), [3, 4]),
),
],
)
def test_AdvancedSubtensor(x, indices):
"""Test NumPy's advanced indexing in more than one dimension."""
out_pt = x[indices]
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_numba_and_py(out_fg, [])
@pytest.mark.parametrize(
"x, y, indices",
[
(
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
pt.as_tensor(np.array(10)),
(1,),
),
(
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
pt.as_tensor(rng.poisson(size=(4, 5))),
(slice(None)),
),
(
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
pt.as_tensor(np.array(10)),
(1, 2, 0),
),
(
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
pt.as_tensor(rng.poisson(size=(1, 5))),
(slice(1, 2), 1, slice(None)),
),
],
)
def test_IncSubtensor(x, y, indices):
out_pt = pt.set_subtensor(x[indices], y)
assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_numba_and_py(out_fg, [])
out_pt = pt.inc_subtensor(x[indices], y)
assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_numba_and_py(out_fg, [])
x_pt = x.type()
out_pt = pt.set_subtensor(x_pt[indices], y, inplace=True)
assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_numba_and_py(out_fg, [x.data])
@pytest.mark.parametrize(
"x, y, indices",
[
(
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
pt.as_tensor(rng.poisson(size=(2, 4, 5))),
([1, 2],),
),
(
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
pt.as_tensor(rng.poisson(size=(2, 4, 5))),
([1, 1],),
),
# Broadcasting values
(
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
pt.as_tensor(rng.poisson(size=(1, 4, 5))),
([0, 2, 0],),
),
(
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
pt.as_tensor(rng.poisson(size=(5,))),
([0, 2],),
),
(
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
pt.as_tensor(rng.poisson(size=())),
([2, 0],),
),
(
pt.as_tensor(np.arange(5)),
pt.as_tensor(rng.poisson(size=())),
([2, 0],),
),
],
)
def test_AdvancedIncSubtensor1(x, y, indices):
out_pt = pt_subtensor.advanced_set_subtensor1(x, y, *indices)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor1)
out_fg = FunctionGraph([], [out_pt])
compare_numba_and_py(out_fg, [])
out_pt = pt_subtensor.advanced_inc_subtensor1(x, y, *indices)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor1)
out_fg = FunctionGraph([], [out_pt])
compare_numba_and_py(out_fg, [])
# With symbolic inputs
x_pt = x.type()
y_pt = y.type()
out_pt = pt_subtensor.AdvancedIncSubtensor1(inplace=True)(x_pt, y_pt, *indices)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor1)
out_fg = FunctionGraph([x_pt, y_pt], [out_pt])
compare_numba_and_py(out_fg, [x.data, y.data])
out_pt = pt_subtensor.AdvancedIncSubtensor1(set_instead_of_inc=True, inplace=True)(
x_pt, y_pt, *indices
)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor1)
out_fg = FunctionGraph([x_pt, y_pt], [out_pt])
compare_numba_and_py(out_fg, [x.data, y.data])
@pytest.mark.parametrize(
"x, y, indices",
[
(
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
pt.as_tensor(rng.poisson(size=(2, 5))),
([1, 2], [2, 3]),
),
(
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
pt.as_tensor(rng.poisson(size=(2, 4))),
([1, 2], slice(None), [3, 4]),
),
pytest.param(
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
pt.as_tensor(rng.poisson(size=(2, 5))),
([1, 1], [2, 2]),
),
],
)
def test_AdvancedIncSubtensor(x, y, indices):
out_pt = pt.set_subtensor(x[indices], y)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_numba_and_py(out_fg, [])
out_pt = pt.inc_subtensor(x[indices], y)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_numba_and_py(out_fg, [])
x_pt = x.type()
out_pt = pt.set_subtensor(x_pt[indices], y)
# Inplace isn't really implemented for `AdvancedIncSubtensor`, so we just
# hack it on here
out_pt.owner.op.inplace = True
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_numba_and_py(out_fg, [x.data])
@pytest.mark.parametrize(
"x, i",
[
......
import numpy as np
import pytest
from pytensor.graph import FunctionGraph
from pytensor.tensor import as_tensor
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
AdvancedSubtensor,
AdvancedSubtensor1,
IncSubtensor,
Subtensor,
advanced_inc_subtensor1,
advanced_set_subtensor1,
advanced_subtensor1,
inc_subtensor,
set_subtensor,
)
from tests.link.numba.test_basic import compare_numba_and_py
rng = np.random.default_rng(sum(map(ord, "Numba subtensors")))
@pytest.mark.parametrize(
"x, indices",
[
(as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), (1,)),
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
(slice(None)),
),
(as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), (1, 2, 0)),
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
(slice(1, 2), 1, slice(None)),
),
],
)
def test_Subtensor(x, indices):
"""Test NumPy's basic indexing."""
out_pt = x[indices]
assert isinstance(out_pt.owner.op, Subtensor)
out_fg = FunctionGraph([], [out_pt])
compare_numba_and_py(out_fg, [])
@pytest.mark.parametrize(
"x, indices",
[
(as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2],)),
(as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1],)),
],
)
def test_AdvancedSubtensor1(x, indices):
"""Test NumPy's advanced indexing in one dimension."""
out_pt = advanced_subtensor1(x, *indices)
assert isinstance(out_pt.owner.op, AdvancedSubtensor1)
out_fg = FunctionGraph([], [out_pt])
compare_numba_and_py(out_fg, [])
def test_AdvancedSubtensor1_out_of_bounds():
out_pt = advanced_subtensor1(np.arange(3), [4])
assert isinstance(out_pt.owner.op, AdvancedSubtensor1)
out_fg = FunctionGraph([], [out_pt])
with pytest.raises(IndexError):
compare_numba_and_py(out_fg, [])
@pytest.mark.parametrize(
"x, indices",
[
(as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2], [2, 3])),
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
([1, 2], slice(None), [3, 4]),
),
],
)
def test_AdvancedSubtensor(x, indices):
"""Test NumPy's advanced indexing in more than one dimension."""
out_pt = x[indices]
assert isinstance(out_pt.owner.op, AdvancedSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_numba_and_py(out_fg, [])
@pytest.mark.parametrize(
"x, y, indices",
[
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
as_tensor(np.array(10)),
(1,),
),
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
as_tensor(rng.poisson(size=(4, 5))),
(slice(None)),
),
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
as_tensor(np.array(10)),
(1, 2, 0),
),
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
as_tensor(rng.poisson(size=(1, 5))),
(slice(1, 2), 1, slice(None)),
),
],
)
def test_IncSubtensor(x, y, indices):
out_pt = set_subtensor(x[indices], y)
assert isinstance(out_pt.owner.op, IncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_numba_and_py(out_fg, [])
out_pt = inc_subtensor(x[indices], y)
assert isinstance(out_pt.owner.op, IncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_numba_and_py(out_fg, [])
x_pt = x.type()
out_pt = set_subtensor(x_pt[indices], y, inplace=True)
assert isinstance(out_pt.owner.op, IncSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_numba_and_py(out_fg, [x.data])
@pytest.mark.parametrize(
"x, y, indices",
[
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
as_tensor(rng.poisson(size=(2, 4, 5))),
([1, 2],),
),
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
as_tensor(rng.poisson(size=(2, 4, 5))),
([1, 1],),
),
# Broadcasting values
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
as_tensor(rng.poisson(size=(1, 4, 5))),
([0, 2, 0],),
),
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
as_tensor(rng.poisson(size=(5,))),
([0, 2],),
),
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
as_tensor(rng.poisson(size=())),
([2, 0],),
),
(
as_tensor(np.arange(5)),
as_tensor(rng.poisson(size=())),
([2, 0],),
),
],
)
def test_AdvancedIncSubtensor1(x, y, indices):
out_pt = advanced_set_subtensor1(x, y, *indices)
assert isinstance(out_pt.owner.op, AdvancedIncSubtensor1)
out_fg = FunctionGraph([], [out_pt])
compare_numba_and_py(out_fg, [])
out_pt = advanced_inc_subtensor1(x, y, *indices)
assert isinstance(out_pt.owner.op, AdvancedIncSubtensor1)
out_fg = FunctionGraph([], [out_pt])
compare_numba_and_py(out_fg, [])
# With symbolic inputs
x_pt = x.type()
y_pt = y.type()
out_pt = AdvancedIncSubtensor1(inplace=True)(x_pt, y_pt, *indices)
assert isinstance(out_pt.owner.op, AdvancedIncSubtensor1)
out_fg = FunctionGraph([x_pt, y_pt], [out_pt])
compare_numba_and_py(out_fg, [x.data, y.data])
out_pt = AdvancedIncSubtensor1(set_instead_of_inc=True, inplace=True)(
x_pt, y_pt, *indices
)
assert isinstance(out_pt.owner.op, AdvancedIncSubtensor1)
out_fg = FunctionGraph([x_pt, y_pt], [out_pt])
compare_numba_and_py(out_fg, [x.data, y.data])
@pytest.mark.parametrize(
"x, y, indices",
[
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
as_tensor(rng.poisson(size=(2, 5))),
([1, 2], [2, 3]),
),
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
as_tensor(rng.poisson(size=(2, 4))),
([1, 2], slice(None), [3, 4]),
),
pytest.param(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
as_tensor(rng.poisson(size=(2, 5))),
([1, 1], [2, 2]),
),
],
)
def test_AdvancedIncSubtensor(x, y, indices):
out_pt = set_subtensor(x[indices], y)
assert isinstance(out_pt.owner.op, AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_numba_and_py(out_fg, [])
out_pt = inc_subtensor(x[indices], y)
assert isinstance(out_pt.owner.op, AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_numba_and_py(out_fg, [])
x_pt = x.type()
out_pt = set_subtensor(x_pt[indices], y)
# Inplace isn't really implemented for `AdvancedIncSubtensor`, so we just
# hack it on here
out_pt.owner.op.inplace = True
assert isinstance(out_pt.owner.op, AdvancedIncSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_numba_and_py(out_fg, [x.data])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论