提交 f454836b authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: Adrian Seyboldt

Add typing for some numba elemwise

上级 9a55cfe6
from typing import Any, List, Optional, Tuple
import numba
import numpy as np
from llvmlite import ir
......@@ -10,8 +12,8 @@ from numba.np import arrayobj
def compute_itershape(
ctx: BaseContext,
builder: ir.IRBuilder,
in_shapes,
broadcast_pattern,
in_shapes: Tuple[ir.Instruction, ...],
broadcast_pattern: Tuple[Tuple[bool, ...], ...],
):
one = ir.IntType(64)(1)
ndim = len(in_shapes[0])
......@@ -59,16 +61,23 @@ def compute_itershape(
def make_outputs(
ctx, builder: ir.IRBuilder, iter_shape, out_bc, dtypes, inplace, inputs, input_types
ctx: numba.core.base.BaseContext,
builder: ir.IRBuilder,
iter_shape: Tuple[ir.Instruction, ...],
out_bc: Tuple[Tuple[bool, ...], ...],
dtypes: Tuple[Any, ...],
inplace: Tuple[Tuple[int, int], ...],
inputs: Tuple[Any, ...],
input_types: Tuple[Any, ...],
):
arrays = []
ar_types: list[types.Array] = []
one = ir.IntType(64)(1)
inplace = dict(inplace)
inplace_dict = dict(inplace)
for i, (bc, dtype) in enumerate(zip(out_bc, dtypes)):
if i in inplace:
arrays.append(inputs[inplace[i]])
ar_types.append(input_types[inplace[i]])
if i in inplace_dict:
arrays.append(inputs[inplace_dict[i]])
ar_types.append(input_types[inplace_dict[i]])
# We need to incref once we return the inplace objects
continue
dtype = numba.from_dtype(np.dtype(dtype))
......@@ -95,15 +104,15 @@ def make_loop_call(
typingctx,
context: numba.core.base.BaseContext,
builder: ir.IRBuilder,
scalar_func,
scalar_signature,
iter_shape,
inputs,
outputs,
input_bc,
output_bc,
input_types,
output_types,
scalar_func: Any,
scalar_signature: types.FunctionType,
iter_shape: Tuple[ir.Instruction, ...],
inputs: Tuple[ir.Instruction, ...],
outputs: Tuple[ir.Instruction, ...],
input_bc: Tuple[Tuple[bool, ...], ...],
output_bc: Tuple[Tuple[bool, ...], ...],
input_types: Tuple[Any, ...],
output_types: Tuple[Any, ...],
):
safe = (False, False)
n_outputs = len(outputs)
......@@ -142,15 +151,15 @@ def make_loop_call(
# input_scope_set = mod.add_metadata([input_scope, output_scope])
# output_scope_set = mod.add_metadata([input_scope, output_scope])
inputs = [
inputs = tuple(
extract_array(aryty, ary)
for aryty, ary in zip(input_types, inputs, strict=True)
]
)
outputs = [
outputs = tuple(
extract_array(aryty, ary)
for aryty, ary in zip(output_types, outputs, strict=True)
]
)
zero = ir.Constant(ir.IntType(64), 0)
......@@ -158,7 +167,9 @@ def make_loop_call(
# This part corresponds to opening the loops
loop_stack = []
loops = []
output_accumulator = [(None, None)] * n_outputs
output_accumulator: List[Tuple[Optional[Any], Optional[int]]] = [
(None, None)
] * n_outputs
for dim, length in enumerate(iter_shape):
# Find outputs that only have accumulations left
for output in range(n_outputs):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论