提交 f454836b authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: Adrian Seyboldt

Add typing for some numba elemwise

上级 9a55cfe6
from typing import Any, List, Optional, Tuple
import numba import numba
import numpy as np import numpy as np
from llvmlite import ir from llvmlite import ir
...@@ -10,8 +12,8 @@ from numba.np import arrayobj ...@@ -10,8 +12,8 @@ from numba.np import arrayobj
def compute_itershape( def compute_itershape(
ctx: BaseContext, ctx: BaseContext,
builder: ir.IRBuilder, builder: ir.IRBuilder,
in_shapes, in_shapes: Tuple[ir.Instruction, ...],
broadcast_pattern, broadcast_pattern: Tuple[Tuple[bool, ...], ...],
): ):
one = ir.IntType(64)(1) one = ir.IntType(64)(1)
ndim = len(in_shapes[0]) ndim = len(in_shapes[0])
...@@ -59,16 +61,23 @@ def compute_itershape( ...@@ -59,16 +61,23 @@ def compute_itershape(
def make_outputs( def make_outputs(
ctx, builder: ir.IRBuilder, iter_shape, out_bc, dtypes, inplace, inputs, input_types ctx: numba.core.base.BaseContext,
builder: ir.IRBuilder,
iter_shape: Tuple[ir.Instruction, ...],
out_bc: Tuple[Tuple[bool, ...], ...],
dtypes: Tuple[Any, ...],
inplace: Tuple[Tuple[int, int], ...],
inputs: Tuple[Any, ...],
input_types: Tuple[Any, ...],
): ):
arrays = [] arrays = []
ar_types: list[types.Array] = [] ar_types: list[types.Array] = []
one = ir.IntType(64)(1) one = ir.IntType(64)(1)
inplace = dict(inplace) inplace_dict = dict(inplace)
for i, (bc, dtype) in enumerate(zip(out_bc, dtypes)): for i, (bc, dtype) in enumerate(zip(out_bc, dtypes)):
if i in inplace: if i in inplace_dict:
arrays.append(inputs[inplace[i]]) arrays.append(inputs[inplace_dict[i]])
ar_types.append(input_types[inplace[i]]) ar_types.append(input_types[inplace_dict[i]])
# We need to incref once we return the inplace objects # We need to incref once we return the inplace objects
continue continue
dtype = numba.from_dtype(np.dtype(dtype)) dtype = numba.from_dtype(np.dtype(dtype))
...@@ -95,15 +104,15 @@ def make_loop_call( ...@@ -95,15 +104,15 @@ def make_loop_call(
typingctx, typingctx,
context: numba.core.base.BaseContext, context: numba.core.base.BaseContext,
builder: ir.IRBuilder, builder: ir.IRBuilder,
scalar_func, scalar_func: Any,
scalar_signature, scalar_signature: types.FunctionType,
iter_shape, iter_shape: Tuple[ir.Instruction, ...],
inputs, inputs: Tuple[ir.Instruction, ...],
outputs, outputs: Tuple[ir.Instruction, ...],
input_bc, input_bc: Tuple[Tuple[bool, ...], ...],
output_bc, output_bc: Tuple[Tuple[bool, ...], ...],
input_types, input_types: Tuple[Any, ...],
output_types, output_types: Tuple[Any, ...],
): ):
safe = (False, False) safe = (False, False)
n_outputs = len(outputs) n_outputs = len(outputs)
...@@ -142,15 +151,15 @@ def make_loop_call( ...@@ -142,15 +151,15 @@ def make_loop_call(
# input_scope_set = mod.add_metadata([input_scope, output_scope]) # input_scope_set = mod.add_metadata([input_scope, output_scope])
# output_scope_set = mod.add_metadata([input_scope, output_scope]) # output_scope_set = mod.add_metadata([input_scope, output_scope])
inputs = [ inputs = tuple(
extract_array(aryty, ary) extract_array(aryty, ary)
for aryty, ary in zip(input_types, inputs, strict=True) for aryty, ary in zip(input_types, inputs, strict=True)
] )
outputs = [ outputs = tuple(
extract_array(aryty, ary) extract_array(aryty, ary)
for aryty, ary in zip(output_types, outputs, strict=True) for aryty, ary in zip(output_types, outputs, strict=True)
] )
zero = ir.Constant(ir.IntType(64), 0) zero = ir.Constant(ir.IntType(64), 0)
...@@ -158,7 +167,9 @@ def make_loop_call( ...@@ -158,7 +167,9 @@ def make_loop_call(
# This part corresponds to opening the loops # This part corresponds to opening the loops
loop_stack = [] loop_stack = []
loops = [] loops = []
output_accumulator = [(None, None)] * n_outputs output_accumulator: List[Tuple[Optional[Any], Optional[int]]] = [
(None, None)
] * n_outputs
for dim, length in enumerate(iter_shape): for dim, length in enumerate(iter_shape):
# Find outputs that only have accumulations left # Find outputs that only have accumulations left
for output in range(n_outputs): for output in range(n_outputs):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论