提交 05ea255f authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Implement Numba slice boxing to enable `AdvancedSubtensor` with slices

上级 adc9ce96
......@@ -4,6 +4,9 @@ import numba
import numpy as np
import scipy
import scipy.special
from llvmlite.llvmpy.core import Type as llvm_Type
from numba import types
from numba.extending import box
from aesara.compile.ops import DeepCopyOp
from aesara.graph.fg import FunctionGraph
......@@ -15,6 +18,32 @@ from aesara.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1, Subte
from aesara.tensor.type_other import MakeSlice
def slice_new(self, start, stop, step):
fnty = llvm_Type.function(self.pyobj, [self.pyobj, self.pyobj, self.pyobj])
fn = self._get_function(fnty, name="PySlice_New")
return self.builder.call(fn, [start, stop, step])
@box(types.SliceType)
def box_slice(typ, val, c):
"""Implement boxing for ``slice`` objects in Numba.
This makes it possible to return an Numba's internal representation of a
``slice`` object as a proper ``slice`` to Python.
"""
start = c.box(types.int64, c.builder.extract_value(val, 0))
stop = c.box(types.int64, c.builder.extract_value(val, 1))
if typ.has_step:
step = c.box(types.int64, c.builder.extract_value(val, 2))
else:
step = c.pyapi.get_null_object()
slice_val = slice_new(c.pyapi, start, stop, step)
return slice_val
@singledispatch
def numba_typify(data, dtype=None, **kwargs):
return data
......@@ -199,9 +228,10 @@ def numba_funcify_DeepCopyOp(op, node, **kwargs):
@numba_funcify.register(MakeSlice)
def numba_funcify_MakeSlice(op, **kwargs):
# XXX: This won't work when calling into object mode (e.g. for advanced
# indexing), because there's no Numba unboxing for its native `slice`
# objects.
"""
XXX: This requires a ``slice`` boxing implementation to work with Numba's
object mode.
"""
@numba.njit
def makeslice(*x):
......
......@@ -145,13 +145,10 @@ def test_AdvancedSubtensor1(x, indices):
"x, indices",
[
(aet.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2], [2, 3])),
# XXX TODO: This will fail because advanced indexing calls into object
# mode (i.e. Python) and there's no unboxing for Numba's internal/native
# `slice` objects.
# (
# aet.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
# ([1, 2], slice(None), [3, 4]),
# ),
(
aet.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
([1, 2], slice(None), [3, 4]),
),
],
)
def test_AdvancedSubtensor(x, indices):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论