Commit bf8307dd authored by Brandon T. Willard, committed by Brandon T. Willard

Implement *IncSubtensor* translations for Numba

Parent 05ea255f
...@@ -14,10 +14,20 @@ from aesara.graph.type import Type ...@@ -14,10 +14,20 @@ from aesara.graph.type import Type
from aesara.link.utils import compile_function_src, fgraph_to_python from aesara.link.utils import compile_function_src, fgraph_to_python
from aesara.scalar.basic import Composite, ScalarOp from aesara.scalar.basic import Composite, ScalarOp
from aesara.tensor.elemwise import Elemwise from aesara.tensor.elemwise import Elemwise
from aesara.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1, Subtensor from aesara.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
AdvancedSubtensor,
AdvancedSubtensor1,
IncSubtensor,
Subtensor,
)
from aesara.tensor.type_other import MakeSlice from aesara.tensor.type_other import MakeSlice
incsubtensor_ops = (IncSubtensor, AdvancedIncSubtensor1)
def slice_new(self, start, stop, step): def slice_new(self, start, stop, step):
fnty = llvm_Type.function(self.pyobj, [self.pyobj, self.pyobj, self.pyobj]) fnty = llvm_Type.function(self.pyobj, [self.pyobj, self.pyobj, self.pyobj])
fn = self._get_function(fnty, name="PySlice_New") fn = self._get_function(fnty, name="PySlice_New")
...@@ -135,7 +145,7 @@ def numba_funcify_Composite(op, vectorize=True, **kwargs): ...@@ -135,7 +145,7 @@ def numba_funcify_Composite(op, vectorize=True, **kwargs):
return composite return composite
def create_index_func(node, idx_list, objmode=False): def create_index_func(node, objmode=False):
"""Create a Python function that assembles and uses an index on an array.""" """Create a Python function that assembles and uses an index on an array."""
def convert_indices(indices, entry): def convert_indices(indices, entry):
...@@ -153,13 +163,19 @@ def create_index_func(node, idx_list, objmode=False): ...@@ -153,13 +163,19 @@ def create_index_func(node, idx_list, objmode=False):
else: else:
raise ValueError() raise ValueError()
set_or_inc = isinstance(
node.op, (IncSubtensor, AdvancedIncSubtensor1, AdvancedIncSubtensor)
)
index_start_idx = 1 + int(set_or_inc)
input_names = [v.auto_name for v in node.inputs] input_names = [v.auto_name for v in node.inputs]
op_indices = list(node.inputs[1:]) op_indices = list(node.inputs[index_start_idx:])
idx_list = getattr(node.op, "idx_list", None)
indices_creation_src = ( indices_creation_src = (
tuple(convert_indices(op_indices, idx) for idx in idx_list) tuple(convert_indices(op_indices, idx) for idx in idx_list)
if idx_list if idx_list
else tuple(input_names[1:]) else tuple(input_names[index_start_idx:])
) )
if len(indices_creation_src) == 1: if len(indices_creation_src) == 1:
...@@ -168,18 +184,47 @@ def create_index_func(node, idx_list, objmode=False): ...@@ -168,18 +184,47 @@ def create_index_func(node, idx_list, objmode=False):
indices_creation_src = ", ".join(indices_creation_src) indices_creation_src = ", ".join(indices_creation_src)
indices_creation_src = f"indices = ({indices_creation_src})" indices_creation_src = f"indices = ({indices_creation_src})"
if set_or_inc:
fn_name = "incsubtensor"
if node.op.inplace:
index_prologue = f"z = {input_names[0]}"
else:
index_prologue = f"z = np.copy({input_names[0]})"
if node.inputs[1].ndim == 0:
# TODO FIXME: This is a hack to get around a weird Numba typing
# issue. See https://github.com/numba/numba/issues/6000
y_name = f"{input_names[1]}.item()"
else:
y_name = input_names[1]
if node.op.set_instead_of_inc:
index_body = f"z[indices] = {y_name}"
else:
index_body = f"z[indices] += {y_name}"
else:
fn_name = "subtensor"
index_prologue = ""
index_body = f"z = {input_names[0]}[indices]"
if objmode: if objmode:
output_var = node.outputs[0] output_var = node.outputs[0]
output_sig = f"{output_var.dtype}[{', '.join([':'] * output_var.ndim)}]"
if not set_or_inc:
# Since `z` is being "created" while in object mode, it's
# considered an "outgoing" variable and needs to be manually typed
output_sig = f"z='{output_var.dtype}[{', '.join([':'] * output_var.ndim)}]'"
else:
output_sig = ""
index_body = f""" index_body = f"""
with objmode(z="{output_sig}"): with objmode({output_sig}):
z = {input_names[0]}[indices] {index_body}
""" """
else:
index_body = f"z = {input_names[0]}[indices]"
subtensor_def_src = f""" subtensor_def_src = f"""
def subtensor({", ".join(input_names)}): def {fn_name}({", ".join(input_names)}):
{index_prologue}
{indices_creation_src} {indices_creation_src}
{index_body} {index_body}
return z return z
...@@ -193,19 +238,35 @@ def subtensor({", ".join(input_names)}): ...@@ -193,19 +238,35 @@ def subtensor({", ".join(input_names)}):
@numba_funcify.register(AdvancedSubtensor1) @numba_funcify.register(AdvancedSubtensor1)
def numba_funcify_Subtensor(op, node, **kwargs): def numba_funcify_Subtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None)
subtensor_def_src = create_index_func( subtensor_def_src = create_index_func(
node, idx_list, objmode=isinstance(op, AdvancedSubtensor) node, objmode=isinstance(op, AdvancedSubtensor)
) )
global_env = {} global_env = {"np": np, "objmode": numba.objmode}
global_env["objmode"] = numba.objmode
subtensor_fn = compile_function_src(subtensor_def_src, "subtensor", global_env) subtensor_fn = compile_function_src(subtensor_def_src, "subtensor", global_env)
return numba.njit(subtensor_fn) return numba.njit(subtensor_fn)
@numba_funcify.register(IncSubtensor)
@numba_funcify.register(AdvancedIncSubtensor)
@numba_funcify.register(AdvancedIncSubtensor1)
def numba_funcify_IncSubtensor(op, node, **kwargs):
    """Create a Numba-compiled function implementing an ``*IncSubtensor*`` `Op`.

    The indexing/assignment source is generated from the node by
    `create_index_func` and then compiled with `numba.njit`.
    """
    # `AdvancedIncSubtensor` indexing is handled via Numba's object mode;
    # the other ops compile to native code directly.
    use_objmode = isinstance(op, AdvancedIncSubtensor)
    incsubtensor_def_src = create_index_func(node, objmode=use_objmode)

    # Names referenced by the generated source.
    global_env = {"np": np, "objmode": numba.objmode}

    # The generated function is named "incsubtensor"; extract it by that name.
    incsubtensor_fn = compile_function_src(
        incsubtensor_def_src, "incsubtensor", global_env
    )

    return numba.njit(incsubtensor_fn)
@numba_funcify.register(DeepCopyOp) @numba_funcify.register(DeepCopyOp)
def numba_funcify_DeepCopyOp(op, node, **kwargs): def numba_funcify_DeepCopyOp(op, node, **kwargs):
......
...@@ -53,10 +53,13 @@ def compare_numba_and_py( ...@@ -53,10 +53,13 @@ def compare_numba_and_py(
fn_inputs, fn_inputs,
fgraph.outputs, fgraph.outputs,
mode=numba_mode, mode=numba_mode,
accept_inplace=True,
) )
numba_res = aesara_numba_fn(*inputs) numba_res = aesara_numba_fn(*inputs)
aesara_py_fn = function(fn_inputs, fgraph.outputs, mode=py_mode) aesara_py_fn = function(
fn_inputs, fgraph.outputs, mode=py_mode, accept_inplace=True
)
py_res = aesara_py_fn(*inputs) py_res = aesara_py_fn(*inputs)
if len(fgraph.outputs) > 1: if len(fgraph.outputs) > 1:
...@@ -115,7 +118,7 @@ def test_numba_Composite(inputs, input_values): ...@@ -115,7 +118,7 @@ def test_numba_Composite(inputs, input_values):
), ),
], ],
) )
def test_Subtensors(x, indices): def test_Subtensor(x, indices):
"""Test NumPy's basic indexing.""" """Test NumPy's basic indexing."""
out_aet = x[indices] out_aet = x[indices]
assert isinstance(out_aet.owner.op, aet_subtensor.Subtensor) assert isinstance(out_aet.owner.op, aet_subtensor.Subtensor)
...@@ -157,3 +160,115 @@ def test_AdvancedSubtensor(x, indices): ...@@ -157,3 +160,115 @@ def test_AdvancedSubtensor(x, indices):
assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedSubtensor) assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedSubtensor)
out_fg = FunctionGraph([], [out_aet]) out_fg = FunctionGraph([], [out_aet])
compare_numba_and_py(out_fg, []) compare_numba_and_py(out_fg, [])
@pytest.mark.parametrize(
    "x, y, indices",
    [
        (
            aet.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
            aet.as_tensor(np.array(10)),
            (1,),
        ),
        (
            aet.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
            aet.as_tensor(np.random.poisson(size=(4, 5))),
            (slice(None)),
        ),
        (
            aet.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
            aet.as_tensor(np.array(10)),
            (1, 2, 0),
        ),
        (
            aet.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
            aet.as_tensor(np.random.poisson(size=(1, 5))),
            (slice(1, 2), 1, slice(None)),
        ),
    ],
)
def test_IncSubtensor(x, y, indices):
    """Check the Numba translation of `IncSubtensor` (set, increment, in-place set)."""
    # Both `set_subtensor` and `inc_subtensor` produce `IncSubtensor` nodes;
    # compare the Numba result against the Python-mode result for each.
    for build_output in (aet.set_subtensor, aet.inc_subtensor):
        out = build_output(x[indices], y)
        assert isinstance(out.owner.op, aet_subtensor.IncSubtensor)
        compare_numba_and_py(FunctionGraph([], [out]), [])

    # In-place variant: use a symbolic input so the op may mutate it.
    x_at = x.type()
    out = aet.set_subtensor(x_at[indices], y, inplace=True)
    assert isinstance(out.owner.op, aet_subtensor.IncSubtensor)
    compare_numba_and_py(FunctionGraph([x_at], [out]), [x.data])
@pytest.mark.parametrize(
    "x, y, indices",
    [
        (
            aet.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
            aet.as_tensor(np.random.poisson(size=(2, 4, 5))),
            ([1, 2],),
        ),
        (
            aet.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
            aet.as_tensor(np.random.poisson(size=(2, 4, 5))),
            ([1, 2], slice(None)),
        ),
    ],
)
def test_AdvancedIncSubtensor1(x, y, indices):
    """Check the Numba translation of `AdvancedIncSubtensor1` (set, increment, in-place set)."""
    # Both builders produce `AdvancedIncSubtensor1` nodes for these indices;
    # verify each against the Python-mode implementation.
    for build_output in (aet.set_subtensor, aet.inc_subtensor):
        out = build_output(x[indices], y)
        assert isinstance(out.owner.op, aet_subtensor.AdvancedIncSubtensor1)
        compare_numba_and_py(FunctionGraph([], [out]), [])

    # In-place variant: use a symbolic input so the op may mutate it.
    x_at = x.type()
    out = aet.set_subtensor(x_at[indices], y, inplace=True)
    assert isinstance(out.owner.op, aet_subtensor.AdvancedIncSubtensor1)
    compare_numba_and_py(FunctionGraph([x_at], [out]), [x.data])
@pytest.mark.parametrize(
    "x, y, indices",
    [
        (
            aet.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
            aet.as_tensor(np.random.poisson(size=(2, 5))),
            ([1, 2], [2, 3]),
        ),
        (
            aet.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
            aet.as_tensor(np.random.poisson(size=(2, 4))),
            ([1, 2], slice(None), [3, 4]),
        ),
    ],
)
def test_AdvancedIncSubtensor(x, y, indices):
    """Check the Numba translation of `AdvancedIncSubtensor` (set, increment, in-place set)."""
    # Both builders produce `AdvancedIncSubtensor` nodes for these indices;
    # verify each against the Python-mode implementation.
    for build_output in (aet.set_subtensor, aet.inc_subtensor):
        out = build_output(x[indices], y)
        assert isinstance(out.owner.op, aet_subtensor.AdvancedIncSubtensor)
        compare_numba_and_py(FunctionGraph([], [out]), [])

    x_at = x.type()
    out = aet.set_subtensor(x_at[indices], y)
    # Inplace isn't really implemented for `AdvancedIncSubtensor`, so we just
    # hack it on here
    out.owner.op.inplace = True
    assert isinstance(out.owner.op, aet_subtensor.AdvancedIncSubtensor)
    compare_numba_and_py(FunctionGraph([x_at], [out]), [x.data])
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment