提交 9da643e8 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Move slice dispatcher functionality to subtensor.py

上级 131982de
import operator
import sys
import warnings
from copy import copy
from functools import singledispatch
......@@ -8,11 +6,10 @@ from textwrap import dedent
import numba
import numba.np.unsafe.ndarray as numba_ndarray
import numpy as np
from llvmlite import ir
from numba import types
from numba.core.errors import NumbaWarning, TypingError
from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401
from numba.extending import box, overload
from numba.extending import overload
from pytensor import In, config
from pytensor.compile import NUMBA
......@@ -36,7 +33,7 @@ from pytensor.tensor.math import Dot
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
from pytensor.tensor.sort import ArgSortOp, SortOp
from pytensor.tensor.type import TensorType
from pytensor.tensor.type_other import MakeSlice, NoneConst
from pytensor.tensor.type_other import NoneConst
def numba_njit(*args, fastmath=None, **kwargs):
......@@ -149,69 +146,6 @@ def create_numba_signature(
return numba.types.void(*input_types)
def slice_new(self, start, stop, step):
fnty = ir.FunctionType(self.pyobj, [self.pyobj, self.pyobj, self.pyobj])
fn = self._get_function(fnty, name="PySlice_New")
return self.builder.call(fn, [start, stop, step])
def enable_slice_boxing():
"""Enable boxing for Numba's native ``slice``s.
TODO: this can be removed when https://github.com/numba/numba/pull/6939 is
merged and a release is made.
"""
@box(types.SliceType)
def box_slice(typ, val, c):
"""Implement boxing for ``slice`` objects in Numba.
This makes it possible to return an Numba's internal representation of a
``slice`` object as a proper ``slice`` to Python.
"""
start = c.builder.extract_value(val, 0)
stop = c.builder.extract_value(val, 1)
none_val = ir.Constant(ir.IntType(64), sys.maxsize)
start_is_none = c.builder.icmp_signed("==", start, none_val)
start = c.builder.select(
start_is_none,
c.pyapi.get_null_object(),
c.box(types.int64, start),
)
stop_is_none = c.builder.icmp_signed("==", stop, none_val)
stop = c.builder.select(
stop_is_none,
c.pyapi.get_null_object(),
c.box(types.int64, stop),
)
if typ.has_step:
step = c.builder.extract_value(val, 2)
step_is_none = c.builder.icmp_signed("==", step, none_val)
step = c.builder.select(
step_is_none,
c.pyapi.get_null_object(),
c.box(types.int64, step),
)
else:
step = c.pyapi.get_null_object()
slice_val = slice_new(c.pyapi, start, stop, step)
return slice_val
@numba.extending.overload(operator.contains)
def in_seq_empty_tuple(x, y):
if isinstance(x, types.Tuple) and not x.types:
return lambda x, y: False
enable_slice_boxing()
def to_scalar(x):
return np.asarray(x).item()
......@@ -388,15 +322,6 @@ def numba_funcify_DeepCopyOp(op, node, **kwargs):
return deepcopyop
@numba_funcify.register(MakeSlice)
def numba_funcify_MakeSlice(op, **kwargs):
@numba_njit
def makeslice(*x):
return slice(*x)
return makeslice
@numba_funcify.register(Shape)
def numba_funcify_Shape(op, **kwargs):
@numba_njit
......
import operator
import sys
import numba
import numpy as np
from llvmlite import ir
from numba import types
from numba.core.pythonapi import box
from pytensor.graph import Type
from pytensor.link.numba.dispatch import numba_funcify
......@@ -14,7 +21,79 @@ from pytensor.tensor.subtensor import (
IncSubtensor,
Subtensor,
)
from pytensor.tensor.type_other import NoneTypeT, SliceType
from pytensor.tensor.type_other import MakeSlice, NoneTypeT, SliceType
def slice_new(self, start, stop, step):
fnty = ir.FunctionType(self.pyobj, [self.pyobj, self.pyobj, self.pyobj])
fn = self._get_function(fnty, name="PySlice_New")
return self.builder.call(fn, [start, stop, step])
def enable_slice_boxing():
"""Enable boxing for Numba's native ``slice``s.
TODO: this can be removed when https://github.com/numba/numba/pull/6939 is
merged and a release is made.
"""
@box(types.SliceType)
def box_slice(typ, val, c):
"""Implement boxing for ``slice`` objects in Numba.
This makes it possible to return an Numba's internal representation of a
``slice`` object as a proper ``slice`` to Python.
"""
start = c.builder.extract_value(val, 0)
stop = c.builder.extract_value(val, 1)
none_val = ir.Constant(ir.IntType(64), sys.maxsize)
start_is_none = c.builder.icmp_signed("==", start, none_val)
start = c.builder.select(
start_is_none,
c.pyapi.get_null_object(),
c.box(types.int64, start),
)
stop_is_none = c.builder.icmp_signed("==", stop, none_val)
stop = c.builder.select(
stop_is_none,
c.pyapi.get_null_object(),
c.box(types.int64, stop),
)
if typ.has_step:
step = c.builder.extract_value(val, 2)
step_is_none = c.builder.icmp_signed("==", step, none_val)
step = c.builder.select(
step_is_none,
c.pyapi.get_null_object(),
c.box(types.int64, step),
)
else:
step = c.pyapi.get_null_object()
slice_val = slice_new(c.pyapi, start, stop, step)
return slice_val
@numba.extending.overload(operator.contains)
def in_seq_empty_tuple(x, y):
if isinstance(x, types.Tuple) and not x.types:
return lambda x, y: False
enable_slice_boxing()
@numba_funcify.register(MakeSlice)
def numba_funcify_MakeSlice(op, **kwargs):
@numba_njit
def makeslice(*x):
return slice(*x)
return makeslice
@numba_funcify.register(Subtensor)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论