提交 7b311b65 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add basic Subtensor support for Numba

`AdvancedSubtensor` calls out to Python (i.e. Numba's object mode), because advanced indexing is not currently supported in Numba's nopython mode. This causes problems with indices that contain slices: Numba has its own internal representation of slices and provides no unboxing for it (i.e. no conversion back to a Python `slice` object).
上级 acd23ab6
import ast
from functools import singledispatch from functools import singledispatch
from tempfile import NamedTemporaryFile
import numba import numba
import numpy as np
from aesara.compile.ops import DeepCopyOp
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.type import Type
from aesara.link.utils import fgraph_to_python from aesara.link.utils import fgraph_to_python
from aesara.scalar.basic import Composite, ScalarOp from aesara.scalar.basic import Composite, ScalarOp
from aesara.tensor.elemwise import Elemwise from aesara.tensor.elemwise import Elemwise
from aesara.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1, Subtensor
from aesara.tensor.type_other import MakeSlice
@singledispatch @singledispatch
...@@ -43,7 +50,6 @@ def numba_funcify_FunctionGraph( ...@@ -43,7 +50,6 @@ def numba_funcify_FunctionGraph(
@numba_funcify.register(ScalarOp) @numba_funcify.register(ScalarOp)
def numba_funcify_ScalarOp(op, **kwargs): def numba_funcify_ScalarOp(op, **kwargs):
import numpy as np
numpy_func = getattr(np, op.nfunc_spec[0]) numpy_func = getattr(np, op.nfunc_spec[0])
...@@ -59,9 +65,8 @@ def numba_funcify_ScalarOp(op, **kwargs): ...@@ -59,9 +65,8 @@ def numba_funcify_ScalarOp(op, **kwargs):
@numba_funcify.register(Elemwise)
def numba_funcify_Elemwise(op, **kwargs):
    """Convert an `Elemwise` `op` by converting its underlying `scalar_op`.

    The extra keyword arguments are accepted but not forwarded.
    """
    # TODO: Vectorize this
    return numba_funcify(op.scalar_op)
...@@ -74,3 +79,114 @@ def numba_funcify_Composite(op, vectorize=True, **kwargs): ...@@ -74,3 +79,114 @@ def numba_funcify_Composite(op, vectorize=True, **kwargs):
return numba_impl(*args)[0] return numba_impl(*args)[0]
return composite return composite
def create_index_func(node, idx_list, objmode=False):
    """Create the source of a Python function that assembles and uses an index on an array.

    The generated function is named ``subtensor``.  It takes one positional
    argument per entry of ``node.inputs`` (each named after that input's
    ``auto_name``), builds an ``indices`` tuple and returns the first argument
    indexed by that tuple.

    Parameters
    ----------
    node
        Object whose ``inputs`` provide the argument names.  The first input is
        the value being indexed; the remaining inputs are the variable index
        values.  When `objmode` is true, ``node.outputs[0]`` must also expose
        ``dtype`` and ``ndim``.
    idx_list
        Index template.  Each entry may be a `Type` instance (replaced by the
        next unused index input), a `slice` whose ``start``/``stop``/``step``
        are themselves such entries, or ``None``.  When empty or ``None``, all
        inputs after the first are used directly as the indices.
    objmode
        When true, the indexing statement is wrapped in a
        ``with objmode(z="<dtype>[:, ...]")`` block; the name ``objmode`` must
        then be available in the namespace the source is executed in.

    Returns
    -------
    str
        The source code defining ``subtensor``.

    Raises
    ------
    ValueError
        If an entry of `idx_list` cannot be converted to source code.
    """

    def convert_indices(indices, entry):
        """Render one index entry as source text, consuming `indices` as needed."""
        # `indices` is checked first so that an exhausted list never matches.
        if indices and isinstance(entry, Type):
            return indices.pop(0).auto_name
        if isinstance(entry, slice):
            # The fields are converted in start/stop/step order because each
            # conversion may consume the next index input.
            slice_args = ", ".join(
                convert_indices(indices, part)
                for part in (entry.start, entry.stop, entry.step)
            )
            return f"slice({slice_args})"
        if entry is None:
            return "None"
        raise ValueError(f"Unsupported index entry: {entry!r}")

    input_names = [v.auto_name for v in node.inputs]
    op_indices = list(node.inputs[1:])

    if idx_list:
        index_srcs = [convert_indices(op_indices, idx) for idx in idx_list]
    else:
        index_srcs = input_names[1:]

    if len(index_srcs) == 1:
        # A one-element tuple needs its trailing comma.
        indices_creation_src = f"indices = ({index_srcs[0]},)"
    else:
        indices_creation_src = f"indices = ({', '.join(index_srcs)})"

    array_name = input_names[0]

    if objmode:
        output_var = node.outputs[0]
        output_sig = f"{output_var.dtype}[{', '.join([':'] * output_var.ndim)}]"
        index_lines = [
            f'    with objmode(z="{output_sig}"):',
            f"        z = {array_name}[indices]",
        ]
    else:
        index_lines = [f"    z = {array_name}[indices]"]

    # The source is assembled line by line with explicit indentation so that
    # its validity does not depend on the whitespace of a multi-line template.
    src_lines = [
        "",
        f"def subtensor({', '.join(input_names)}):",
        f"    {indices_creation_src}",
        *index_lines,
        "    return z",
        "",
    ]
    return "\n".join(src_lines)
@numba_funcify.register(Subtensor)
@numba_funcify.register(AdvancedSubtensor)
@numba_funcify.register(AdvancedSubtensor1)
def numba_funcify_Subtensor(op, node, **kwargs):
    """Return a Numba-jitted indexing function for a subtensor `op`.

    The function's source is generated by `create_index_func`, written to a
    temporary file, compiled against that file name and executed to obtain the
    ``subtensor`` function, which is then wrapped with `numba.njit`.  Only
    `AdvancedSubtensor` instances request the object-mode variant of the
    generated source.
    """
    needs_objmode = isinstance(op, AdvancedSubtensor)
    source = create_index_func(
        node, getattr(op, "idx_list", None), objmode=needs_objmode
    )

    # Parse before touching the file system so that invalid source fails early.
    tree = ast.parse(source)

    # NOTE(review): `delete=False` leaves the file on disk after this call --
    # presumably so the generated source stays readable under `filename`;
    # confirm before adding any cleanup.
    with NamedTemporaryFile(delete=False) as source_file:
        filename = source_file.name
        source_file.write(source.encode())

    namespace = {}
    code = compile(tree, filename, mode="exec")
    # `objmode` must be resolvable by name inside the generated function.
    exec(code, {"objmode": numba.objmode}, namespace)

    return numba.njit(namespace["subtensor"])
@numba_funcify.register(DeepCopyOp)
def numba_funcify_DeepCopyOp(op, node, **kwargs):
    """Return a Numba-jitted implementation of `DeepCopyOp` for `node`.

    When the first output of `node` has ``ndim == 0`` the input is returned
    unchanged; otherwise the result of the input's ``copy()`` is returned.
    """
    # Scalars are apparently returned as actual Python scalar types and not
    # NumPy scalars, so we need two separate Numba functions for each case.
    output_is_scalar = node.outputs[0].type.ndim == 0

    if output_is_scalar:
        # TODO: Do we really need to compile a pass-through function like this?
        def deepcopyop(x):
            return x

    else:

        def deepcopyop(x):
            return x.copy()

    return numba.njit(deepcopyop)
@numba_funcify.register(MakeSlice)
def numba_funcify_MakeSlice(op, **kwargs):
    """Return a Numba-jitted function that builds a `slice` from its arguments."""

    # XXX: This won't work when calling into object mode (e.g. for advanced
    # indexing), because there's no Numba unboxing for its native `slice`
    # objects.
    def makeslice(*slice_args):
        return slice(*slice_args)

    return numba.njit(makeslice)
from functools import partial
import numpy as np import numpy as np
import pytest
import aesara import aesara
import aesara.scalar.basic as aes import aesara.scalar.basic as aes
import aesara.tensor as aet import aesara.tensor as aet
from aesara.compile.function import function
from aesara.compile.mode import Mode from aesara.compile.mode import Mode
from aesara.compile.sharedvalue import SharedVariable
from aesara.graph.fg import FunctionGraph
from aesara.graph.optdb import Query from aesara.graph.optdb import Query
from aesara.link.numba.linker import NumbaLinker from aesara.link.numba.linker import NumbaLinker
from aesara.tensor import subtensor as aet_subtensor
# Module-level compilation settings: `numba_mode` and `py_mode` are both built
# from the same `opts` query and differ only in their first `Mode` argument
# (`NumbaLinker()` versus "py").
# from aesara.graph.fg import FunctionGraph opts = Query(include=[None], exclude=["cxx_only", "BlasOpt"])
opts = Query(include=["fusion"], exclude=["cxx_only", "BlasOpt"])
numba_mode = Mode(NumbaLinker(), opts) numba_mode = Mode(NumbaLinker(), opts)
py_mode = Mode("py", opts) py_mode = Mode("py", opts)
def compare_numba_and_py(
    fgraph,
    inputs,
    assert_fn=None,
):
    """Compile a graph with Numba and with Python, run both and compare the results.

    The outputs of `fgraph` are compiled once with `numba_mode` and once with
    `py_mode`.  Both compiled functions are called with `inputs` and their
    results are passed to `assert_fn`.

    Parameters
    ----------
    fgraph: FunctionGraph
        Aesara function graph object.  Inputs that are `SharedVariable`
        instances are not passed as explicit function inputs.
    inputs: iter
        Input values for the compiled functions.
    assert_fn: func, opt
        Assert function used to check for equality between the Python and
        Numba results.  If not provided, `np.testing.assert_allclose` with
        ``rtol=1e-4`` is used.

    Returns
    -------
    The result produced by the Numba-compiled function.
    """
    check = (
        assert_fn
        if assert_fn is not None
        else partial(np.testing.assert_allclose, rtol=1e-4)
    )

    explicit_inputs = [
        var for var in fgraph.inputs if not isinstance(var, SharedVariable)
    ]

    # Compile and run with Numba first, then with the Python mode.
    results = []
    for mode in (numba_mode, py_mode):
        compiled_fn = function(explicit_inputs, fgraph.outputs, mode=mode)
        results.append(compiled_fn(*inputs))
    numba_res, py_res = results

    if len(fgraph.outputs) <= 1:
        check(numba_res, py_res)
        return numba_res

    # Several outputs: compare them pairwise.
    for numba_out, py_out in zip(numba_res, py_res):
        check(numba_out, py_out)

    return numba_res
def test_Composite(): def test_Composite():
opts = Query(include=["fusion"], exclude=["cxx_only", "BlasOpt"])
numba_mode = Mode(NumbaLinker(), opts)
py_mode = Mode("py", opts)
y = aet.vector("y") y = aet.vector("y")
x = aet.vector("x") x = aet.vector("x")
...@@ -59,3 +108,65 @@ def test_Composite(): ...@@ -59,3 +108,65 @@ def test_Composite():
) # Answer from Numba converted FunctionGraph ) # Answer from Numba converted FunctionGraph
assert np.array_equal(res, numba_res) assert np.array_equal(res, numba_res)
@pytest.mark.parametrize(
    "x, indices",
    [
        (aet.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), (1,)),
        (
            aet.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
            # The trailing comma makes this a one-element tuple like the other
            # cases; `(slice(None))` is just a parenthesized bare slice.
            (slice(None),),
        ),
        (aet.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), (1, 2, 0)),
        (
            aet.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
            (slice(1, 2), 1, slice(None)),
        ),
    ],
)
def test_Subtensors(x, indices):
    """Test NumPy's basic indexing.

    Each case indexes `x` with a tuple of integers and/or slices, checks that
    the resulting graph uses a `Subtensor` op, and compares the Numba and
    Python results.
    """
    out_aet = x[indices]
    assert isinstance(out_aet.owner.op, aet_subtensor.Subtensor)
    out_fg = FunctionGraph([], [out_aet])
    compare_numba_and_py(out_fg, [])
@pytest.mark.parametrize(
    "x, indices",
    [
        (aet.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2],)),
        (
            aet.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
            ([1, 2], slice(None)),
        ),
    ],
)
def test_AdvancedSubtensor1(x, indices):
    """Test NumPy's advanced indexing in one dimension.

    Each case indexes `x` with its own parametrized `indices`, checks that the
    resulting graph uses an `AdvancedSubtensor1` op, and compares the Numba and
    Python results.
    """
    # Use the parametrized indices; a hard-coded index here would make every
    # parametrized case exercise the same graph.
    out_aet = x[indices]
    assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedSubtensor1)
    out_fg = FunctionGraph([], [out_aet])
    compare_numba_and_py(out_fg, [])
@pytest.mark.parametrize(
    "x, indices",
    [
        (aet.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2], [2, 3])),
        # XXX TODO: This will fail because advanced indexing calls into object
        # mode (i.e. Python) and there's no unboxing for Numba's internal/native
        # `slice` objects.
        # (
        #     aet.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
        #     ([1, 2], slice(None), [3, 4]),
        # ),
    ],
)
def test_AdvancedSubtensor(x, indices):
    """Test NumPy's advanced indexing in more than one dimension.

    The indexed graph must use an `AdvancedSubtensor` op; its Numba and Python
    results are then compared.
    """
    indexed = x[indices]
    indexing_op = indexed.owner.op
    assert isinstance(indexing_op, aet_subtensor.AdvancedSubtensor)
    compare_numba_and_py(FunctionGraph([], [indexed]), [])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论