提交 47874eb9 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Adapt Numba vectorize iterator for RandomVariables

上级 38c04c96
...@@ -62,10 +62,16 @@ def numba_njit(*args, **kwargs): ...@@ -62,10 +62,16 @@ def numba_njit(*args, **kwargs):
kwargs.setdefault("no_cpython_wrapper", True) kwargs.setdefault("no_cpython_wrapper", True)
kwargs.setdefault("no_cfunc_wrapper", True) kwargs.setdefault("no_cfunc_wrapper", True)
# Supress caching warnings # Suppress cache warning for internal functions
# We have to add an ansi escape code for optional bold text by numba
warnings.filterwarnings( warnings.filterwarnings(
"ignore", "ignore",
message='Cannot cache compiled function "numba_funcified_fgraph" as it uses dynamic globals', message=(
"(\x1b\\[1m)*" # ansi escape code for bold text
"Cannot cache compiled function "
'"(numba_funcified_fgraph|store_core_outputs)" '
"as it uses dynamic globals"
),
category=NumbaWarning, category=NumbaWarning,
) )
......
...@@ -24,6 +24,7 @@ from pytensor.link.numba.dispatch.vectorize_codegen import ( ...@@ -24,6 +24,7 @@ from pytensor.link.numba.dispatch.vectorize_codegen import (
_jit_options, _jit_options,
_vectorized, _vectorized,
encode_literals, encode_literals,
store_core_outputs,
) )
from pytensor.link.utils import compile_function_src, get_name_for_object from pytensor.link.utils import compile_function_src, get_name_for_object
from pytensor.scalar.basic import ( from pytensor.scalar.basic import (
...@@ -480,10 +481,15 @@ def numba_funcify_Elemwise(op, node, **kwargs): ...@@ -480,10 +481,15 @@ def numba_funcify_Elemwise(op, node, **kwargs):
**kwargs, **kwargs,
) )
nin = len(node.inputs)
nout = len(node.outputs)
core_op_fn = store_core_outputs(scalar_op_fn, nin=nin, nout=nout)
input_bc_patterns = tuple([inp.type.broadcastable for inp in node.inputs]) input_bc_patterns = tuple([inp.type.broadcastable for inp in node.inputs])
output_bc_patterns = tuple([out.type.broadcastable for out in node.outputs]) output_bc_patterns = tuple([out.type.broadcastable for out in node.outputs])
output_dtypes = tuple(out.type.dtype for out in node.outputs) output_dtypes = tuple(out.type.dtype for out in node.outputs)
inplace_pattern = tuple(op.inplace_pattern.items()) inplace_pattern = tuple(op.inplace_pattern.items())
core_output_shapes = tuple(() for _ in range(nout))
# numba doesn't support nested literals right now... # numba doesn't support nested literals right now...
input_bc_patterns_enc = encode_literals(input_bc_patterns) input_bc_patterns_enc = encode_literals(input_bc_patterns)
...@@ -493,12 +499,15 @@ def numba_funcify_Elemwise(op, node, **kwargs): ...@@ -493,12 +499,15 @@ def numba_funcify_Elemwise(op, node, **kwargs):
def elemwise_wrapper(*inputs): def elemwise_wrapper(*inputs):
return _vectorized( return _vectorized(
scalar_op_fn, core_op_fn,
input_bc_patterns_enc, input_bc_patterns_enc,
output_bc_patterns_enc, output_bc_patterns_enc,
output_dtypes_enc, output_dtypes_enc,
inplace_pattern_enc, inplace_pattern_enc,
(), # constant_inputs
inputs, inputs,
core_output_shapes, # core_shapes
None, # size
) )
# Pure python implementation, that will be used in tests # Pure python implementation, that will be used in tests
......
...@@ -2,8 +2,9 @@ from __future__ import annotations ...@@ -2,8 +2,9 @@ from __future__ import annotations
import base64 import base64
import pickle import pickle
from collections.abc import Sequence from collections.abc import Callable, Sequence
from typing import Any from textwrap import indent
from typing import Any, cast
import numba import numba
import numpy as np import numpy as np
...@@ -11,13 +12,54 @@ from llvmlite import ir ...@@ -11,13 +12,54 @@ from llvmlite import ir
from numba import TypingError, types from numba import TypingError, types
from numba.core import cgutils from numba.core import cgutils
from numba.core.base import BaseContext from numba.core.base import BaseContext
from numba.core.types.misc import NoneType
from numba.np import arrayobj from numba.np import arrayobj
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.utils import compile_function_src
def encode_literals(literals: Sequence) -> str: def encode_literals(literals: Sequence) -> str:
return base64.encodebytes(pickle.dumps(literals)).decode() return base64.encodebytes(pickle.dumps(literals)).decode()
def store_core_outputs(core_op_fn: Callable, nin: int, nout: int) -> Callable:
    """Wrap a core function so that it writes its results into output arguments.

    The generated function takes ``nin`` input arguments followed by ``nout``
    output arguments, calls ``core_op_fn`` on the inputs and assigns every
    result into the matching output with ``o[...] = value``. Its source has
    the following shape::

        def store_core_outputs(i0, i1, ..., o0, o1, ...):
            to0, to1, ... = core_op_fn(i0, i1, ...)
            o0[...] = to0
            o1[...] = to1
            ...

    With a single output the result of ``core_op_fn`` is bound directly (no
    tuple unpacking takes place).

    Parameters
    ----------
    core_op_fn
        Callable invoked with the ``nin`` inputs. It is made available to the
        generated function under the global name ``core_op_fn``.
    nin
        Number of input arguments of the generated function.
    nout
        Number of output arguments of the generated function.

    Returns
    -------
    Callable
        The result of ``numba_basic.numba_njit`` applied to the function
        compiled from the generated source.
    """
    inputs = [f"i{i}" for i in range(nin)]
    outputs = [f"o{i}" for i in range(nout)]
    inner_outputs = [f"t{output}" for output in outputs]

    # Build the signature from one flat argument list. Joining the input and
    # output signatures separately with ", " produced invalid source such as
    # ``def store_core_outputs(, o0):`` whenever there were no inputs.
    signature = ", ".join([*inputs, *outputs])
    core_call = f"core_op_fn({', '.join(inputs)})"

    if inner_outputs:
        body_lines = [f"{', '.join(inner_outputs)} = {core_call}"]
        body_lines.extend(
            f"{output}[...] = {inner_output}"
            for output, inner_output in zip(outputs, inner_outputs)
        )
    else:
        # Nothing to store: emitting `` = core_op_fn(...)`` would be a syntax
        # error, so only call the core function.
        body_lines = [core_call]

    # For nin >= 1 and nout >= 1 this is exactly the source text that the
    # previous template produced (leading and trailing newline included).
    func_src = "\n".join(
        [
            "",
            f"def store_core_outputs({signature}):",
            indent("\n".join(body_lines), " " * 4),
            "",
        ]
    )

    global_env = {"core_op_fn": core_op_fn}
    func = compile_function_src(
        func_src, "store_core_outputs", {**globals(), **global_env}
    )
    return cast(Callable, numba_basic.numba_njit(func))
_jit_options = { _jit_options = {
"fastmath": { "fastmath": {
"arcp", # Allow Reciprocal "arcp", # Allow Reciprocal
...@@ -39,7 +81,10 @@ def _vectorized( ...@@ -39,7 +81,10 @@ def _vectorized(
output_bc_patterns, output_bc_patterns,
output_dtypes, output_dtypes,
inplace_pattern, inplace_pattern,
inputs, constant_inputs_types,
input_types,
output_core_shape_types,
size_type,
): ):
arg_types = [ arg_types = [
scalar_func, scalar_func,
...@@ -47,7 +92,10 @@ def _vectorized( ...@@ -47,7 +92,10 @@ def _vectorized(
output_bc_patterns, output_bc_patterns,
output_dtypes, output_dtypes,
inplace_pattern, inplace_pattern,
inputs, constant_inputs_types,
input_types,
output_core_shape_types,
size_type,
] ]
if not isinstance(input_bc_patterns, types.Literal): if not isinstance(input_bc_patterns, types.Literal):
...@@ -70,34 +118,82 @@ def _vectorized( ...@@ -70,34 +118,82 @@ def _vectorized(
inplace_pattern = inplace_pattern.literal_value inplace_pattern = inplace_pattern.literal_value
inplace_pattern = pickle.loads(base64.decodebytes(inplace_pattern.encode())) inplace_pattern = pickle.loads(base64.decodebytes(inplace_pattern.encode()))
n_outputs = len(output_bc_patterns) batch_ndim = len(input_bc_patterns[0])
nin = len(constant_inputs_types) + len(input_types)
nout = len(output_bc_patterns)
if nin == 0:
raise TypingError("Empty argument list to vectorized op.")
if nout == 0:
raise TypingError("Empty list of outputs for vectorized op.")
if not len(inputs) > 0: if not all(isinstance(input, types.Array) for input in input_types):
raise TypingError("Empty argument list to elemwise op.") raise TypingError("Vectorized inputs must be arrays.")
if not n_outputs > 0: if not all(
raise TypingError("Empty list of outputs for elemwise op.") len(pattern) == batch_ndim for pattern in input_bc_patterns + output_bc_patterns
):
raise TypingError(
"Vectorized broadcastable patterns must have the same length."
)
core_input_types = []
for input_type, bc_pattern in zip(input_types, input_bc_patterns):
core_ndim = input_type.ndim - len(bc_pattern)
# TODO: Reconsider this
if core_ndim == 0:
core_input_type = input_type.dtype
else:
core_input_type = types.Array(
dtype=input_type.dtype, ndim=core_ndim, layout=input_type.layout
)
core_input_types.append(core_input_type)
if not all(isinstance(input, types.Array) for input in inputs): core_out_types = [
raise TypingError("Inputs to elemwise must be arrays.") types.Array(numba.from_dtype(np.dtype(dtype)), len(output_core_shape), "C")
ndim = inputs[0].ndim for dtype, output_core_shape in zip(output_dtypes, output_core_shape_types)
]
if not all(input.ndim == ndim for input in inputs): out_types = [
raise TypingError("Inputs to elemwise must have the same rank.") types.Array(
numba.from_dtype(np.dtype(dtype)), batch_ndim + len(output_core_shape), "C"
)
for dtype, output_core_shape in zip(output_dtypes, output_core_shape_types)
]
if not all(len(pattern) == ndim for pattern in output_bc_patterns): for output_idx, input_idx in inplace_pattern:
raise TypingError("Invalid output broadcasting pattern.") output_type = input_types[input_idx]
core_out_types[output_idx] = types.Array(
dtype=output_type.dtype,
ndim=output_type.ndim - batch_ndim,
layout=input_type.layout,
)
out_types[output_idx] = output_type
scalar_signature = typingctx.resolve_function_type( core_signature = typingctx.resolve_function_type(
scalar_func, [in_type.dtype for in_type in inputs], {} scalar_func,
[
*constant_inputs_types,
*core_input_types,
*core_out_types,
],
{},
) )
ret_type = types.Tuple(out_types)
if len(output_dtypes) == 1:
ret_type = ret_type.types[0]
sig = ret_type(*arg_types)
# So we can access the constant values in codegen... # So we can access the constant values in codegen...
input_bc_patterns_val = input_bc_patterns input_bc_patterns_val = input_bc_patterns
output_bc_patterns_val = output_bc_patterns output_bc_patterns_val = output_bc_patterns
output_dtypes_val = output_dtypes output_dtypes_val = output_dtypes
inplace_pattern_val = inplace_pattern inplace_pattern_val = inplace_pattern
input_types = inputs input_types = input_types
size_is_none = isinstance(size_type, NoneType)
def codegen( def codegen(
ctx, ctx,
...@@ -105,8 +201,16 @@ def _vectorized( ...@@ -105,8 +201,16 @@ def _vectorized(
sig, sig,
args, args,
): ):
[_, _, _, _, _, inputs] = args [_, _, _, _, _, constant_inputs, inputs, output_core_shapes, size] = args
constant_inputs = cgutils.unpack_tuple(builder, constant_inputs)
inputs = cgutils.unpack_tuple(builder, inputs) inputs = cgutils.unpack_tuple(builder, inputs)
output_core_shapes = [
cgutils.unpack_tuple(builder, shape)
for shape in cgutils.unpack_tuple(builder, output_core_shapes)
]
size = None if size_is_none else cgutils.unpack_tuple(builder, size)
inputs = [ inputs = [
arrayobj.make_array(ty)(ctx, builder, val) arrayobj.make_array(ty)(ctx, builder, val)
for ty, val in zip(input_types, inputs) for ty, val in zip(input_types, inputs)
...@@ -118,6 +222,7 @@ def _vectorized( ...@@ -118,6 +222,7 @@ def _vectorized(
builder, builder,
in_shapes, in_shapes,
input_bc_patterns_val, input_bc_patterns_val,
size,
) )
outputs, output_types = make_outputs( outputs, output_types = make_outputs(
...@@ -129,6 +234,7 @@ def _vectorized( ...@@ -129,6 +234,7 @@ def _vectorized(
inplace_pattern_val, inplace_pattern_val,
inputs, inputs,
input_types, input_types,
output_core_shapes,
) )
make_loop_call( make_loop_call(
...@@ -136,8 +242,9 @@ def _vectorized( ...@@ -136,8 +242,9 @@ def _vectorized(
ctx, ctx,
builder, builder,
scalar_func, scalar_func,
scalar_signature, core_signature,
iter_shape, iter_shape,
constant_inputs,
inputs, inputs,
outputs, outputs,
input_bc_patterns_val, input_bc_patterns_val,
...@@ -162,69 +269,94 @@ def _vectorized( ...@@ -162,69 +269,94 @@ def _vectorized(
builder, sig.return_type, [out._getvalue() for out in outputs] builder, sig.return_type, [out._getvalue() for out in outputs]
) )
ret_types = [
types.Array(numba.from_dtype(np.dtype(dtype)), ndim, "C")
for dtype in output_dtypes
]
for output_idx, input_idx in inplace_pattern:
ret_types[output_idx] = input_types[input_idx]
ret_type = types.Tuple(ret_types)
if len(output_dtypes) == 1:
ret_type = ret_type.types[0]
sig = ret_type(*arg_types)
return sig, codegen return sig, codegen
def compute_itershape( def compute_itershape(
ctx: BaseContext, ctx: BaseContext,
builder: ir.IRBuilder, builder: ir.IRBuilder,
in_shapes: tuple[ir.Instruction, ...], in_shapes: list[list[ir.Instruction]],
broadcast_pattern: tuple[tuple[bool, ...], ...], broadcast_pattern: tuple[tuple[bool, ...], ...],
size: list[ir.Instruction] | None,
): ):
one = ir.IntType(64)(1) one = ir.IntType(64)(1)
ndim = len(in_shapes[0]) batch_ndim = len(broadcast_pattern[0])
shape = [None] * ndim shape = [None] * batch_ndim
for i in range(ndim): if size is not None:
for j, (bc, in_shape) in enumerate(zip(broadcast_pattern, in_shapes)): shape = size
length = in_shape[i] for i in range(batch_ndim):
if bc[i]: for j, (bc, in_shape) in enumerate(zip(broadcast_pattern, in_shapes)):
with builder.if_then( length = in_shape[i]
builder.icmp_unsigned("!=", length, one), likely=False if bc[i]:
): with builder.if_then(
msg = ( builder.icmp_unsigned("!=", length, one), likely=False
f"Input {j} to elemwise is expected to have shape 1 in axis {i}" ):
) msg = f"Vectorized input {j} is expected to have shape 1 in axis {i}"
ctx.call_conv.return_user_exc(builder, ValueError, (msg,)) ctx.call_conv.return_user_exc(builder, ValueError, (msg,))
elif shape[i] is not None: else:
with builder.if_then( with builder.if_then(
builder.icmp_unsigned("!=", length, shape[i]), likely=False builder.icmp_unsigned("!=", length, shape[i]), likely=False
): ):
with builder.if_else(builder.icmp_unsigned("==", length, one)) as ( with builder.if_else(
then, builder.icmp_unsigned("==", length, one)
otherwise, ) as (
then,
otherwise,
):
with then:
msg = (
f"Incompatible vectorized shapes for input {j} and axis {i}. "
f"Input {j} has shape 1, but is not statically "
"known to have shape 1, and thus not broadcastable."
)
ctx.call_conv.return_user_exc(
builder, ValueError, (msg,)
)
with otherwise:
msg = f"Vectorized input {j} has an incompatible shape in axis {i}."
ctx.call_conv.return_user_exc(
builder, ValueError, (msg,)
)
else:
# Size is implied by the broadcast pattern
for i in range(batch_ndim):
for j, (bc, in_shape) in enumerate(zip(broadcast_pattern, in_shapes)):
length = in_shape[i]
if bc[i]:
with builder.if_then(
builder.icmp_unsigned("!=", length, one), likely=False
):
msg = f"Vectorized input {j} is expected to have shape 1 in axis {i}"
ctx.call_conv.return_user_exc(builder, ValueError, (msg,))
elif shape[i] is not None:
with builder.if_then(
builder.icmp_unsigned("!=", length, shape[i]), likely=False
): ):
with then: with builder.if_else(
msg = ( builder.icmp_unsigned("==", length, one)
f"Incompatible shapes for input {j} and axis {i} of " ) as (
f"elemwise. Input {j} has shape 1, but is not statically " then,
"known to have shape 1, and thus not broadcastable." otherwise,
) ):
ctx.call_conv.return_user_exc(builder, ValueError, (msg,)) with then:
with otherwise: msg = (
msg = ( f"Incompatible vectorized shapes for input {j} and axis {i}. "
f"Input {j} to elemwise has an incompatible " f"Input {j} has shape 1, but is not statically "
f"shape in axis {i}." "known to have shape 1, and thus not broadcastable."
) )
ctx.call_conv.return_user_exc(builder, ValueError, (msg,)) ctx.call_conv.return_user_exc(
else: builder, ValueError, (msg,)
shape[i] = length )
for i in range(ndim): with otherwise:
if shape[i] is None: msg = f"Vectorized input {j} has an incompatible shape in axis {i}."
shape[i] = one ctx.call_conv.return_user_exc(
builder, ValueError, (msg,)
)
else:
shape[i] = length
for i in range(batch_ndim):
if shape[i] is None:
shape[i] = one
return shape return shape
...@@ -237,27 +369,32 @@ def make_outputs( ...@@ -237,27 +369,32 @@ def make_outputs(
inplace: tuple[tuple[int, int], ...], inplace: tuple[tuple[int, int], ...],
inputs: tuple[Any, ...], inputs: tuple[Any, ...],
input_types: tuple[Any, ...], input_types: tuple[Any, ...],
): output_core_shapes: tuple,
arrays = [] ) -> tuple[list[ir.Value], list[types.Array]]:
ar_types: list[types.Array] = [] output_arrays = []
output_arry_types = []
one = ir.IntType(64)(1) one = ir.IntType(64)(1)
inplace_dict = dict(inplace) inplace_dict = dict(inplace)
for i, (bc, dtype) in enumerate(zip(out_bc, dtypes)): for i, (core_shape, bc, dtype) in enumerate(
zip(output_core_shapes, out_bc, dtypes)
):
if i in inplace_dict: if i in inplace_dict:
arrays.append(inputs[inplace_dict[i]]) output_arrays.append(inputs[inplace_dict[i]])
ar_types.append(input_types[inplace_dict[i]]) output_arry_types.append(input_types[inplace_dict[i]])
# We need to incref once we return the inplace objects # We need to incref once we return the inplace objects
continue continue
dtype = numba.from_dtype(np.dtype(dtype)) dtype = numba.from_dtype(np.dtype(dtype))
arrtype = types.Array(dtype, len(iter_shape), "C") output_ndim = len(iter_shape) + len(core_shape)
ar_types.append(arrtype) arrtype = types.Array(dtype, output_ndim, "C")
output_arry_types.append(arrtype)
# This is actually an internal numba function, I guess we could # This is actually an internal numba function, I guess we could
# call `numba.nd.unsafe.ndarray` instead? # call `numba.nd.unsafe.ndarray` instead?
shape = [ batch_shape = [
length if not bc_dim else one for length, bc_dim in zip(iter_shape, bc) length if not bc_dim else one for length, bc_dim in zip(iter_shape, bc)
] ]
shape = batch_shape + core_shape
array = arrayobj._empty_nd_impl(ctx, builder, arrtype, shape) array = arrayobj._empty_nd_impl(ctx, builder, arrtype, shape)
arrays.append(array) output_arrays.append(array)
# If there is no inplace operation, we know that all output arrays # If there is no inplace operation, we know that all output arrays
# don't alias. Informing llvm can make it easier to vectorize. # don't alias. Informing llvm can make it easier to vectorize.
...@@ -265,7 +402,7 @@ def make_outputs( ...@@ -265,7 +402,7 @@ def make_outputs(
# The first argument is the output pointer # The first argument is the output pointer
arg = builder.function.args[0] arg = builder.function.args[0]
arg.add_attribute("noalias") arg.add_attribute("noalias")
return arrays, ar_types return output_arrays, output_arry_types
def make_loop_call( def make_loop_call(
...@@ -275,6 +412,7 @@ def make_loop_call( ...@@ -275,6 +412,7 @@ def make_loop_call(
scalar_func: Any, scalar_func: Any,
scalar_signature: types.FunctionType, scalar_signature: types.FunctionType,
iter_shape: tuple[ir.Instruction, ...], iter_shape: tuple[ir.Instruction, ...],
constant_inputs: tuple[ir.Instruction, ...],
inputs: tuple[ir.Instruction, ...], inputs: tuple[ir.Instruction, ...],
outputs: tuple[ir.Instruction, ...], outputs: tuple[ir.Instruction, ...],
input_bc: tuple[tuple[bool, ...], ...], input_bc: tuple[tuple[bool, ...], ...],
...@@ -283,18 +421,8 @@ def make_loop_call( ...@@ -283,18 +421,8 @@ def make_loop_call(
output_types: tuple[Any, ...], output_types: tuple[Any, ...],
): ):
safe = (False, False) safe = (False, False)
n_outputs = len(outputs)
# context.printf(builder, "iter shape: " + ', '.join(["%i"] * len(iter_shape)) + "\n", *iter_shape)
# Extract shape and stride information from the array. n_outputs = len(outputs)
# For later use in the loop body to do the indexing
def extract_array(aryty, obj):
shape = cgutils.unpack_tuple(builder, obj.shape)
strides = cgutils.unpack_tuple(builder, obj.strides)
data = obj.data
layout = aryty.layout
return (data, shape, strides, layout)
# TODO I think this is better than the noalias attribute # TODO I think this is better than the noalias attribute
# for the input, but self_ref isn't supported in a released # for the input, but self_ref isn't supported in a released
...@@ -306,12 +434,6 @@ def make_loop_call( ...@@ -306,12 +434,6 @@ def make_loop_call(
# input_scope_set = mod.add_metadata([input_scope, output_scope]) # input_scope_set = mod.add_metadata([input_scope, output_scope])
# output_scope_set = mod.add_metadata([input_scope, output_scope]) # output_scope_set = mod.add_metadata([input_scope, output_scope])
inputs = tuple(extract_array(aryty, ary) for aryty, ary in zip(input_types, inputs))
outputs = tuple(
extract_array(aryty, ary) for aryty, ary in zip(output_types, outputs)
)
zero = ir.Constant(ir.IntType(64), 0) zero = ir.Constant(ir.IntType(64), 0)
# Setup loops and initialize accumulators for outputs # Setup loops and initialize accumulators for outputs
...@@ -338,69 +460,105 @@ def make_loop_call( ...@@ -338,69 +460,105 @@ def make_loop_call(
# Load values from input arrays # Load values from input arrays
input_vals = [] input_vals = []
for array_info, bc in zip(inputs, input_bc): for input, input_type, bc in zip(inputs, input_types, input_bc):
idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, bc)] core_ndim = input_type.ndim - len(bc)
ptr = cgutils.get_item_pointer2(context, builder, *array_info, idxs_bc, *safe)
val = builder.load(ptr) idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, bc)] + [
# val.set_metadata("alias.scope", input_scope_set) zero
# val.set_metadata("noalias", output_scope_set) ] * core_ndim
ptr = cgutils.get_item_pointer2(
context,
builder,
input.data,
cgutils.unpack_tuple(builder, input.shape),
cgutils.unpack_tuple(builder, input.strides),
input_type.layout,
idxs_bc,
*safe,
)
if core_ndim == 0:
# Retrieve scalar item at index
val = builder.load(ptr)
# val.set_metadata("alias.scope", input_scope_set)
# val.set_metadata("noalias", output_scope_set)
else:
# Retrieve array item at index
# This is a streamlined version of Numba's `GUArrayArg.load`
# TODO check layout arg!
core_arry_type = types.Array(
dtype=input_type.dtype, ndim=core_ndim, layout=input_type.layout
)
core_array = context.make_array(core_arry_type)(context, builder)
core_shape = cgutils.unpack_tuple(builder, input.shape)[-core_ndim:]
core_strides = cgutils.unpack_tuple(builder, input.strides)[-core_ndim:]
itemsize = context.get_abi_sizeof(context.get_data_type(input_type.dtype))
context.populate_array(
core_array,
# TODO why do we need to bitcast?
data=builder.bitcast(ptr, core_array.data.type),
shape=cgutils.pack_array(builder, core_shape),
strides=cgutils.pack_array(builder, core_strides),
itemsize=context.get_constant(types.intp, itemsize),
# TODO what is meminfo about?
meminfo=None,
)
val = core_array._getvalue()
input_vals.append(val) input_vals.append(val)
# Create output slices to pass to inner func
output_slices = []
for output, output_type, bc in zip(outputs, output_types, output_bc):
core_ndim = output_type.ndim - len(bc)
size_type = output.shape.type.element # type: ignore
output_shape = cgutils.unpack_tuple(builder, output.shape) # type: ignore
output_strides = cgutils.unpack_tuple(builder, output.strides) # type: ignore
idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, bc)] + [
zero
] * core_ndim
ptr = cgutils.get_item_pointer2(
context,
builder,
output.data, # type:ignore
output_shape,
output_strides,
output_type.layout,
idxs_bc,
*safe,
)
# Retrieve array item at index
# This is a streamlined version of Numba's `GUArrayArg.load`
core_arry_type = types.Array(
dtype=output_type.dtype, ndim=core_ndim, layout=output_type.layout
)
core_array = context.make_array(core_arry_type)(context, builder)
core_shape = output_shape[-core_ndim:] if core_ndim > 0 else []
core_strides = output_strides[-core_ndim:] if core_ndim > 0 else []
itemsize = context.get_abi_sizeof(context.get_data_type(output_type.dtype))
context.populate_array(
core_array,
# TODO why do we need to bitcast?
data=builder.bitcast(ptr, core_array.data.type),
shape=cgutils.pack_array(builder, core_shape, ty=size_type),
strides=cgutils.pack_array(builder, core_strides, ty=size_type),
itemsize=context.get_constant(types.intp, itemsize),
# TODO what is meminfo about?
meminfo=None,
)
val = core_array._getvalue()
output_slices.append(val)
inner_codegen = context.get_function(scalar_func, scalar_signature) inner_codegen = context.get_function(scalar_func, scalar_signature)
if isinstance(scalar_signature.args[0], types.StarArgTuple | types.StarArgUniTuple): if isinstance(scalar_signature.args[0], types.StarArgTuple | types.StarArgUniTuple):
input_vals = [context.make_tuple(builder, scalar_signature.args[0], input_vals)] input_vals = [context.make_tuple(builder, scalar_signature.args[0], input_vals)]
output_values = inner_codegen(builder, input_vals)
if isinstance(scalar_signature.return_type, types.Tuple | types.UniTuple): inner_codegen(builder, [*constant_inputs, *input_vals, *output_slices])
output_values = cgutils.unpack_tuple(builder, output_values)
func_output_types = scalar_signature.return_type.types
else:
output_values = [output_values]
func_output_types = [scalar_signature.return_type]
# Update output value or accumulators respectively
for i, ((accu, _), value) in enumerate(zip(output_accumulator, output_values)):
if accu is not None:
load = builder.load(accu)
# load.set_metadata("alias.scope", output_scope_set)
# load.set_metadata("noalias", input_scope_set)
new_value = builder.fadd(load, value)
builder.store(new_value, accu)
# TODO belongs to noalias scope
# store.set_metadata("alias.scope", output_scope_set)
# store.set_metadata("noalias", input_scope_set)
else:
idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, output_bc[i])]
ptr = cgutils.get_item_pointer2(context, builder, *outputs[i], idxs_bc)
# store = builder.store(value, ptr)
value = context.cast(
builder, value, func_output_types[i], output_types[i].dtype
)
arrayobj.store_item(context, builder, output_types[i], value, ptr)
# store.set_metadata("alias.scope", output_scope_set)
# store.set_metadata("noalias", input_scope_set)
# Close the loops and write accumulator values to the output arrays # Close the loops
for depth, loop in enumerate(loop_stack[::-1]): for depth, loop in enumerate(loop_stack[::-1]):
for output, (accu, accu_depth) in enumerate(output_accumulator):
if accu_depth == depth:
idxs_bc = [
zero if bc else idx for idx, bc in zip(idxs, output_bc[output])
]
ptr = cgutils.get_item_pointer2(
context, builder, *outputs[output], idxs_bc
)
load = builder.load(accu)
# load.set_metadata("alias.scope", output_scope_set)
# load.set_metadata("noalias", input_scope_set)
# store = builder.store(load, ptr)
load = context.cast(
builder, load, func_output_types[output], output_types[output].dtype
)
arrayobj.store_item(context, builder, output_types[output], load, ptr)
# store.set_metadata("alias.scope", output_scope_set)
# store.set_metadata("noalias", input_scope_set)
loop.__exit__(None, None, None) loop.__exit__(None, None, None)
return return
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论