提交 f05a185b authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add Numba SliceLiteral lowering support

上级 ea146df8
......@@ -146,6 +146,93 @@ def to_scalar(x):
raise TypingError(f"{x} must be a scalar compatible type.")
def enable_slice_literals():
"""Enable lowering for ``SliceLiteral``s.
TODO: This can be removed once https://github.com/numba/numba/pull/6996 is merged
and a release is made.
"""
from numba.core import types
from numba.core.datamodel.models import SliceModel
from numba.core.datamodel.registry import register_default
from numba.core.imputils import lower_cast, lower_constant
from numba.core.types.misc import SliceLiteral
from numba.cpython.slicing import get_defaults
register_default(numba.types.misc.SliceLiteral)(SliceModel)
@property
def key(self):
return self.name
SliceLiteral.key = key
def make_slice_from_constant(context, builder, ty, pyval):
sli = context.make_helper(builder, ty)
lty = context.get_value_type(types.intp)
(
default_start_pos,
default_start_neg,
default_stop_pos,
default_stop_neg,
default_step,
) = [context.get_constant(types.intp, x) for x in get_defaults(context)]
step = pyval.step
if step is None:
step_is_neg = False
step = default_step
else:
step_is_neg = step < 0
step = lty(step)
start = pyval.start
if start is None:
if step_is_neg:
start = default_start_neg
else:
start = default_start_pos
else:
start = lty(start)
stop = pyval.stop
if stop is None:
if step_is_neg:
stop = default_stop_neg
else:
stop = default_stop_pos
else:
stop = lty(stop)
sli.start = start
sli.stop = stop
sli.step = step
return sli._getvalue()
@lower_constant(numba.types.SliceType)
def constant_slice(context, builder, ty, pyval):
if isinstance(ty, types.Literal):
typ = ty.literal_type
else:
typ = ty
return make_slice_from_constant(context, builder, typ, pyval)
@lower_cast(numba.types.misc.SliceLiteral, numba.types.SliceType)
def cast_from_literal(context, builder, fromty, toty, val):
return make_slice_from_constant(
context,
builder,
toty,
fromty.literal_value,
)
enable_slice_literals()
def create_tuple_creator(f, n):
"""Construct a compile-time ``tuple``-comprehension-like loop.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论