提交 27c21cd6 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Test numba slice boxing and fix representation of None stop with negative step

上级 688d6883
...@@ -46,33 +46,43 @@ def enable_slice_boxing(): ...@@ -46,33 +46,43 @@ def enable_slice_boxing():
""" """
start = c.builder.extract_value(val, 0) start = c.builder.extract_value(val, 0)
stop = c.builder.extract_value(val, 1) stop = c.builder.extract_value(val, 1)
step = c.builder.extract_value(val, 2) if typ.has_step else None
# Numba uses sys.maxsize and -sys.maxsize-1 to represent None
# We want to use None in the Python representation
none_val = ir.Constant(ir.IntType(64), sys.maxsize) none_val = ir.Constant(ir.IntType(64), sys.maxsize)
neg_none_val = ir.Constant(ir.IntType(64), -sys.maxsize - 1)
none_obj = c.pyapi.get_null_object()
start_is_none = c.builder.icmp_signed("==", start, none_val)
start = c.builder.select( start = c.builder.select(
start_is_none, c.builder.icmp_signed("==", start, none_val),
c.pyapi.get_null_object(), none_obj,
c.box(types.int64, start), c.box(types.int64, start),
) )
stop_is_none = c.builder.icmp_signed("==", stop, none_val) # None stop is represented as neg_none_val when step is negative
if step is not None:
stop_none_val = c.builder.select(
c.builder.icmp_signed(">", step, ir.Constant(ir.IntType(64), 0)),
none_val,
neg_none_val,
)
else:
stop_none_val = none_val
stop = c.builder.select( stop = c.builder.select(
stop_is_none, c.builder.icmp_signed("==", stop, stop_none_val),
c.pyapi.get_null_object(), none_obj,
c.box(types.int64, stop), c.box(types.int64, stop),
) )
if typ.has_step: if step is not None:
step = c.builder.extract_value(val, 2)
step_is_none = c.builder.icmp_signed("==", step, none_val)
step = c.builder.select( step = c.builder.select(
step_is_none, c.builder.icmp_signed("==", step, none_val),
c.pyapi.get_null_object(), none_obj,
c.box(types.int64, step), c.box(types.int64, step),
) )
else: else:
step = c.pyapi.get_null_object() step = none_obj
slice_val = slice_new(c.pyapi, start, stop, step) slice_val = slice_new(c.pyapi, start, stop, step)
......
...@@ -3,7 +3,9 @@ import contextlib ...@@ -3,7 +3,9 @@ import contextlib
import numpy as np import numpy as np
import pytest import pytest
import pytensor.scalar as ps
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor import Mode, as_symbolic
from pytensor.tensor import as_tensor from pytensor.tensor import as_tensor
from pytensor.tensor.subtensor import ( from pytensor.tensor.subtensor import (
AdvancedIncSubtensor, AdvancedIncSubtensor,
...@@ -24,6 +26,45 @@ from tests.link.numba.test_basic import compare_numba_and_py, numba_mode ...@@ -24,6 +26,45 @@ from tests.link.numba.test_basic import compare_numba_and_py, numba_mode
rng = np.random.default_rng(sum(map(ord, "Numba subtensors"))) rng = np.random.default_rng(sum(map(ord, "Numba subtensors")))
@pytest.mark.parametrize("step", [None, 1, 2, -2, "x"], ids=lambda x: f"step={x}")
@pytest.mark.parametrize("stop", [None, 10, "x"], ids=lambda x: f"stop={x}")
@pytest.mark.parametrize("start", [None, 0, 3, "x"], ids=lambda x: f"start={x}")
def test_slice(start, stop, step):
x = ps.int64("x")
sym_slice = as_symbolic(
slice(
x if start == "x" else start,
x if stop == "x" else stop,
x if step == "x" else step,
)
)
no_opt_mode = Mode(linker="numba", optimizer=None)
evaled_slice = sym_slice.eval({x: -5}, on_unused_input="ignore", mode=no_opt_mode)
assert isinstance(evaled_slice, slice)
if start == "x":
assert evaled_slice.start == -5
elif start is None and (evaled_slice.step is None or evaled_slice.step > 0):
# Numba can convert to 0 (and sometimes does) in this case
assert evaled_slice.start in (None, 0)
else:
assert evaled_slice.start == start
if stop == "x":
assert evaled_slice.stop == -5
else:
assert evaled_slice.stop == stop
if step == "x":
assert evaled_slice.step == -5
elif step is None:
# Numba can convert to 1 (and sometimes does) in this case
assert evaled_slice.step in (None, 1)
else:
assert evaled_slice.step == step
@pytest.mark.parametrize( @pytest.mark.parametrize(
"x, indices", "x, indices",
[ [
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论