提交 05ea255f authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Implement Numba slice boxing to enable `AdvancedSubtensor` with slices

上级 adc9ce96
...@@ -4,6 +4,9 @@ import numba ...@@ -4,6 +4,9 @@ import numba
import numpy as np import numpy as np
import scipy import scipy
import scipy.special import scipy.special
from llvmlite.llvmpy.core import Type as llvm_Type
from numba import types
from numba.extending import box
from aesara.compile.ops import DeepCopyOp from aesara.compile.ops import DeepCopyOp
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
...@@ -15,6 +18,32 @@ from aesara.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1, Subte ...@@ -15,6 +18,32 @@ from aesara.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1, Subte
from aesara.tensor.type_other import MakeSlice from aesara.tensor.type_other import MakeSlice
def slice_new(self, start, stop, step):
fnty = llvm_Type.function(self.pyobj, [self.pyobj, self.pyobj, self.pyobj])
fn = self._get_function(fnty, name="PySlice_New")
return self.builder.call(fn, [start, stop, step])
@box(types.SliceType)
def box_slice(typ, val, c):
"""Implement boxing for ``slice`` objects in Numba.
This makes it possible to return an Numba's internal representation of a
``slice`` object as a proper ``slice`` to Python.
"""
start = c.box(types.int64, c.builder.extract_value(val, 0))
stop = c.box(types.int64, c.builder.extract_value(val, 1))
if typ.has_step:
step = c.box(types.int64, c.builder.extract_value(val, 2))
else:
step = c.pyapi.get_null_object()
slice_val = slice_new(c.pyapi, start, stop, step)
return slice_val
@singledispatch @singledispatch
def numba_typify(data, dtype=None, **kwargs): def numba_typify(data, dtype=None, **kwargs):
return data return data
...@@ -199,9 +228,10 @@ def numba_funcify_DeepCopyOp(op, node, **kwargs): ...@@ -199,9 +228,10 @@ def numba_funcify_DeepCopyOp(op, node, **kwargs):
@numba_funcify.register(MakeSlice) @numba_funcify.register(MakeSlice)
def numba_funcify_MakeSlice(op, **kwargs): def numba_funcify_MakeSlice(op, **kwargs):
# XXX: This won't work when calling into object mode (e.g. for advanced """
# indexing), because there's no Numba unboxing for its native `slice` XXX: This requires a ``slice`` boxing implementation to work with Numba's
# objects. object mode.
"""
@numba.njit @numba.njit
def makeslice(*x): def makeslice(*x):
......
...@@ -145,13 +145,10 @@ def test_AdvancedSubtensor1(x, indices): ...@@ -145,13 +145,10 @@ def test_AdvancedSubtensor1(x, indices):
"x, indices", "x, indices",
[ [
(aet.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2], [2, 3])), (aet.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2], [2, 3])),
# XXX TODO: This will fail because advanced indexing calls into object (
# mode (i.e. Python) and there's no unboxing for Numba's internal/native aet.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
# `slice` objects. ([1, 2], slice(None), [3, 4]),
# ( ),
# aet.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
# ([1, 2], slice(None), [3, 4]),
# ),
], ],
) )
def test_AdvancedSubtensor(x, indices): def test_AdvancedSubtensor(x, indices):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论