Commit 9e79f3a0 authored by Adrian Seyboldt, committed by Adrian Seyboldt

Initial version of llvm elemwise impl

Parent 38dc6c9f
import inspect
from functools import singledispatch
from numbers import Number
import pickle
from textwrap import indent
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Literal, Optional, Union
import base64
import numba
import numpy as np
from llvmlite import ir
from numba import TypingError, literal_unroll, types, literally
from numba.core import cgutils
from numba.cpython.unsafe.tuple import tuple_setitem
from numba.np import arrayobj
from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
from pytensor import config
......@@ -16,13 +22,12 @@ from pytensor.link.numba.dispatch.basic import (
create_numba_signature,
create_tuple_creator,
numba_funcify,
numba_njit,
use_optimized_cheap_pass,
)
from pytensor.link.utils import (
compile_function_src,
get_name_for_object,
unique_name_generator,
)
from pytensor.link.numba.dispatch.helpers import check_broadcasting, tuple_mapper
from pytensor.link.numba.dispatch import elemwise_codegen
from pytensor.link.utils import compile_function_src, get_name_for_object
from pytensor.scalar.basic import (
AND,
OR,
......@@ -431,6 +436,170 @@ def create_axis_apply_fn(fn, axis, ndim, dtype):
return axis_apply_fn
# LLVM fastmath flags applied to the generated elemwise kernels.
# Only a conservative subset is enabled instead of a blanket
# `fastmath=True`, so NaN/Inf handling stays intact.
_jit_options = {
    "fastmath": {
        "arcp",  # Allow Reciprocal
        "contract",  # Allow floating-point contraction
        "afn",  # Approximate functions
        "reassoc",  # Allow reassociation of floating-point expressions
        "nsz",  # TODO Do we want this one?
    }
}
@numba.extending.intrinsic(jit_options=_jit_options, prefer_literal=True)
def _vectorized(
    typingctx,
    scalar_func,
    input_bc_patterns,
    output_bc_patterns,
    output_dtypes,
    inplace_pattern,
    inputs,
):
    """Numba intrinsic that applies ``scalar_func`` elementwise over ``inputs``.

    The broadcasting patterns, output dtypes and inplace pattern must be
    passed as literal strings containing base64-encoded pickles, because
    numba currently does not support nested literal tuples.  The payloads
    are produced by ``numba_funcify_Elemwise`` in this same module, so
    unpickling them here is safe (they are never untrusted input).

    Returns a ``(signature, codegen)`` pair as required by
    ``numba.extending.intrinsic``.
    """
    arg_types = [
        scalar_func,
        input_bc_patterns,
        output_bc_patterns,
        output_dtypes,
        inplace_pattern,
        inputs,
    ]

    def _decode_literal(value, name):
        # Every metadata argument must be a compile-time literal string.
        # Raise a numba TypingError (not a builtin TypeError) so the error
        # surfaces through numba's typing machinery consistently.
        if not isinstance(value, types.Literal):
            raise TypingError(f"{name} must be literal.")
        return pickle.loads(base64.decodebytes(value.literal_value.encode()))

    input_bc_patterns = _decode_literal(input_bc_patterns, "input_bc_patterns")
    output_bc_patterns = _decode_literal(output_bc_patterns, "output_bc_patterns")
    output_dtypes = _decode_literal(output_dtypes, "output_dtypes")
    inplace_pattern = _decode_literal(inplace_pattern, "inplace_pattern")

    n_outputs = len(output_bc_patterns)

    # Validate the runtime array arguments.
    if not len(inputs) > 0:
        raise TypingError("Empty argument list to elemwise op.")
    if not n_outputs > 0:
        raise TypingError("Empty list of outputs for elemwise op.")
    if not all(isinstance(input, types.Array) for input in inputs):
        raise TypingError("Inputs to elemwise must be arrays.")
    ndim = inputs[0].ndim
    if not all(input.ndim == ndim for input in inputs):
        raise TypingError("Inputs to elemwise must have the same rank.")
    if not all(len(pattern) == ndim for pattern in output_bc_patterns):
        raise TypingError("Invalid output broadcasting pattern.")

    scalar_signature = typingctx.resolve_function_type(
        scalar_func, [in_type.dtype for in_type in inputs], {}
    )

    # Re-bind the decoded constants so that `codegen` (a closure) can
    # access them at lowering time.
    input_bc_patterns_val = input_bc_patterns
    output_bc_patterns_val = output_bc_patterns
    output_dtypes_val = output_dtypes
    inplace_pattern_val = inplace_pattern
    input_types = inputs

    def codegen(
        ctx,
        builder,
        sig,
        args,
    ):
        [_, _, _, _, _, inputs] = args
        inputs = cgutils.unpack_tuple(builder, inputs)
        inputs = [
            arrayobj.make_array(ty)(ctx, builder, val)
            for ty, val in zip(input_types, inputs)
        ]
        in_shapes = [cgutils.unpack_tuple(builder, obj.shape) for obj in inputs]

        iter_shape = elemwise_codegen.compute_itershape(
            ctx,
            builder,
            in_shapes,
            input_bc_patterns_val,
        )

        outputs, output_types = elemwise_codegen.make_outputs(
            ctx,
            builder,
            iter_shape,
            output_bc_patterns_val,
            output_dtypes_val,
            inplace_pattern_val,
            inputs,
            input_types,
        )

        # TODO: Emit runtime checks that every input shape is consistent
        # with `iter_shape` (mismatches are currently silently accepted).

        elemwise_codegen.make_loop_call(
            typingctx,
            ctx,
            builder,
            scalar_func,
            scalar_signature,
            iter_shape,
            inputs,
            outputs,
            input_bc_patterns_val,
            output_bc_patterns_val,
            input_types,
            output_types,
        )

        if len(outputs) == 1:
            if inplace_pattern:
                # A single inplace output must alias the first input, and we
                # need to own a reference to it before returning.
                assert inplace_pattern[0][0] == 0
                ctx.nrt.incref(builder, sig.return_type, outputs[0]._getvalue())
            return outputs[0]._getvalue()

        # Inplace outputs borrow their input's buffer, so incref each one
        # before handing them back.  (Was `._get_value()`, which is not a
        # method of numba array proxies and raised AttributeError.)
        for inplace_idx in dict(inplace_pattern):
            ctx.nrt.incref(
                builder,
                sig.return_type.types[inplace_idx],
                outputs[inplace_idx]._getvalue(),
            )
        return ctx.make_tuple(
            builder, sig.return_type, [out._getvalue() for out in outputs]
        )

    # TODO check inplace_pattern
    ret_type = types.Tuple(
        [
            types.Array(numba.from_dtype(np.dtype(dtype)), ndim, "C")
            for dtype in output_dtypes
        ]
    )
    if len(output_dtypes) == 1:
        ret_type = ret_type.types[0]
    sig = ret_type(*arg_types)

    return sig, codegen
@numba_funcify.register(Elemwise)
def numba_funcify_Elemwise(op, node, **kwargs):
# Creating a new scalar node is more involved and unnecessary
......@@ -441,55 +610,42 @@ def numba_funcify_Elemwise(op, node, **kwargs):
scalar_inputs = [scalar(dtype=input.dtype) for input in node.inputs]
scalar_node = op.scalar_op.make_node(*scalar_inputs)
flags = {
"arcp", # Allow Reciprocal
"contract", # Allow floating-point contraction
"afn", # Approximate functions
"reassoc",
"nsz", # TODO Do we want this one?
}
scalar_op_fn = numba_funcify(
op.scalar_op, node=scalar_node, parent_node=node, inline="always", **kwargs
op.scalar_op, node=scalar_node, parent_node=node, fastmath=flags, **kwargs
)
elemwise_fn = create_vectorize_func(scalar_op_fn, node, use_signature=False)
elemwise_fn_name = elemwise_fn.__name__
if op.inplace_pattern:
input_idx = op.inplace_pattern[0]
sign_obj = inspect.signature(elemwise_fn.py_scalar_func)
input_names = list(sign_obj.parameters.keys())
unique_names = unique_name_generator([elemwise_fn_name, "np"], suffix_sep="_")
input_names = [unique_names(i, force_unique=True) for i in input_names]
updated_input_name = input_names[input_idx]
inplace_global_env = {elemwise_fn_name: elemwise_fn, "np": np}
inplace_elemwise_fn_name = f"{elemwise_fn_name}_inplace"
input_signature_str = ", ".join(input_names)
if node.inputs[input_idx].ndim > 0:
inplace_elemwise_src = f"""
def {inplace_elemwise_fn_name}({input_signature_str}):
return {elemwise_fn_name}({input_signature_str + ", " + updated_input_name})
"""
else:
# We can't perform in-place updates on Numba scalars, so we need to
# convert them to NumPy scalars.
# TODO: We should really prevent the rewrites from creating
# in-place updates on scalars when the Numba mode is selected (or
# in general?).
inplace_elemwise_src = f"""
def {inplace_elemwise_fn_name}({input_signature_str}):
{updated_input_name}_scalar = np.asarray({updated_input_name})
return {elemwise_fn_name}({input_signature_str + ", " + updated_input_name}_scalar).item()
"""
inplace_elemwise_fn = compile_function_src(
inplace_elemwise_src,
inplace_elemwise_fn_name,
{**globals(), **inplace_global_env},
)
return numba_basic.numba_njit(inline="always", fastmath=config.numba__fastmath)(
inplace_elemwise_fn
ndim = node.outputs[0].ndim
output_bc_patterns = tuple([(False,) * ndim for _ in node.outputs])
input_bc_patterns = tuple([input_var.broadcastable for input_var in node.inputs])
output_dtypes = tuple(variable.dtype for variable in node.outputs)
inplace_pattern = tuple(op.inplace_pattern.items())
# numba doesn't support nested literals right now...
input_bc_patterns = base64.encodebytes(pickle.dumps(input_bc_patterns)).decode()
output_bc_patterns = base64.encodebytes(pickle.dumps(output_bc_patterns)).decode()
output_dtypes = base64.encodebytes(pickle.dumps(output_dtypes)).decode()
inplace_pattern = base64.encodebytes(pickle.dumps(inplace_pattern)).decode()
@numba_njit
def elemwise_wrapper(*inputs):
return _vectorized(
scalar_op_fn,
input_bc_patterns,
output_bc_patterns,
output_dtypes,
inplace_pattern,
inputs,
)
return elemwise_fn
return elemwise_wrapper
@numba_funcify.register(CAReduce)
......
from llvmlite import ir
from numba import types
from numba.np import arrayobj
from numba.core import cgutils
import numba
import numpy as np
def compute_itershape(
    ctx,
    builder: ir.IRBuilder,
    in_shapes,
    broadcast_pattern,
):
    """Determine the shape of the elemwise iteration space.

    For each dimension the extent is taken from an input that is not
    broadcast along that dimension (the last such input wins); dimensions
    that are broadcast in every input get a constant extent of 1.

    TODO: no runtime error checking is emitted yet — mismatched extents and
    broadcast dimensions with length != 1 are silently accepted.
    """
    one = ir.IntType(64)(1)
    ndim = len(in_shapes[0])
    extents = [None] * ndim
    for dim in range(ndim):
        # TODO Error checking...
        # What if all shapes are 0?
        for bc, in_shape in zip(broadcast_pattern, in_shapes):
            if not bc[dim]:
                # TODO raise an error if two non-broadcast extents differ
                extents[dim] = in_shape[dim]
            # TODO raise an error if a broadcast dim has length != 1
    # Dimensions broadcast in every input collapse to extent 1.
    return [one if extent is None else extent for extent in extents]
def make_outputs(ctx, builder: ir.IRBuilder, iter_shape, out_bc, dtypes, inplace, inputs, input_types):
    """Create the output arrays for an elemwise kernel.

    Outputs listed in ``inplace`` alias the corresponding input array
    instead of allocating fresh storage; every other output is a freshly
    allocated C-contiguous array with extent 1 along its broadcast
    dimensions.

    Returns a pair ``(arrays, array_types)``.
    """
    one = ir.IntType(64)(1)
    inplace = dict(inplace)

    arrays = []
    ar_types: list[types.Array] = []
    for out_idx, (bc_pattern, dtype) in enumerate(zip(out_bc, dtypes)):
        in_idx = inplace.get(out_idx)
        if in_idx is not None:
            # Destructive update: reuse the input's buffer directly.
            # We need to incref once we return the inplace objects
            arrays.append(inputs[in_idx])
            ar_types.append(input_types[in_idx])
            continue
        numba_dtype = numba.from_dtype(np.dtype(dtype))
        arrtype = types.Array(numba_dtype, len(iter_shape), "C")
        ar_types.append(arrtype)
        # Broadcast dimensions are stored with extent 1.
        alloc_shape = []
        for extent, is_bc in zip(iter_shape, bc_pattern):
            alloc_shape.append(one if is_bc else extent)
        # This is actually an internal numba function, I guess we could
        # call `numba.nd.unsafe.ndarray` instead?
        arrays.append(arrayobj._empty_nd_impl(ctx, builder, arrtype, alloc_shape))

    # If there is no inplace operation, we know that all output arrays
    # don't alias. Informing llvm can make it easier to vectorize.
    if not inplace:
        # The first argument is the output pointer
        arg = builder.function.args[0]
        arg.add_attribute("noalias")
    return arrays, ar_types
def make_loop_call(
    typingctx,
    context: numba.core.base.BaseContext,
    builder: ir.IRBuilder,
    scalar_func,
    scalar_signature,
    iter_shape,
    inputs,
    outputs,
    input_bc,
    output_bc,
    input_types,
    output_types,
):
    """Emit the llvm loop nest that applies ``scalar_func`` elementwise.

    One loop is opened per dimension of ``iter_shape``.  Inputs are read at
    index 0 along their broadcast dimensions.  An output whose remaining
    dimensions are all broadcast is reduced into a stack-allocated
    accumulator (via ``fadd``) and written back to the array when the
    corresponding loop closes.
    """
    # Flags forwarded to get_item_pointer2 (wraparound / boundcheck).
    safe = (False, False)
    n_outputs = len(outputs)

    # context.printf(builder, "iter shape: " + ', '.join(["%i"] * len(iter_shape)) + "\n", *iter_shape)

    # Lower the code of the scalar function so that we can use it in the inner loop
    # Caching is set to false to avoid a numba bug TODO ref?
    inner_func = context.compile_subroutine(
        builder,
        # I don't quite understand why we need to access `dispatcher` here.
        # The object does seem to be a dispatcher already? But it is missing
        # attributes...
        scalar_func.dispatcher,
        scalar_signature,
        caching=False,
    )
    inner = inner_func.fndesc

    # Extract shape and stride information from the array.
    # For later use in the loop body to do the indexing
    def extract_array(aryty, obj):
        shape = cgutils.unpack_tuple(builder, obj.shape)
        strides = cgutils.unpack_tuple(builder, obj.strides)
        data = obj.data
        layout = aryty.layout
        return (data, shape, strides, layout)

    # TODO I think this is better than the noalias attribute
    # for the input, but self_ref isn't supported in a released
    # llvmlite version yet
    # mod = builder.module
    # domain = mod.add_metadata([], self_ref=True)
    # input_scope = mod.add_metadata([domain], self_ref=True)
    # output_scope = mod.add_metadata([domain], self_ref=True)
    # input_scope_set = mod.add_metadata([input_scope, output_scope])
    # output_scope_set = mod.add_metadata([input_scope, output_scope])

    inputs = [
        extract_array(aryty, ary)
        for aryty, ary in zip(input_types, inputs, strict=True)
    ]
    outputs = [
        extract_array(aryty, ary)
        for aryty, ary in zip(output_types, outputs, strict=True)
    ]

    zero = ir.Constant(ir.IntType(64), 0)

    # Setup loops and initialize accumulators for outputs
    # This part corresponds to opening the loops
    loop_stack = []
    loops = []
    output_accumulator = [(None, None)] * n_outputs
    for dim, length in enumerate(iter_shape):
        # Find outputs that only have accumulations left
        for output in range(n_outputs):
            if output_accumulator[output][0] is not None:
                continue
            if all(output_bc[output][dim:]):
                # All remaining dims of this output are broadcast: the rest
                # of the loop nest reduces into a scalar accumulator,
                # initialized to zero of the output's element type.
                value = outputs[output][0].type.pointee(0)
                accu = cgutils.alloca_once_value(builder, value)
                output_accumulator[output] = (accu, dim)

        loop = cgutils.for_range(builder, length)
        loop_stack.append(loop)
        loops.append(loop.__enter__())

    # Code in the inner most loop...
    idxs = [loopval.index for loopval in loops]

    # Load values from input arrays
    input_vals = []
    for array_info, bc in zip(inputs, input_bc, strict=True):
        # Broadcast dimensions are always read at index 0.
        idxs_bc = [
            zero if bc else idx for idx, bc in zip(idxs, bc, strict=True)
        ]
        ptr = cgutils.get_item_pointer2(
            context, builder, *array_info, idxs_bc, *safe
        )
        val = builder.load(ptr)
        # val.set_metadata("alias.scope", input_scope_set)
        # val.set_metadata("noalias", output_scope_set)
        input_vals.append(val)

    # Call scalar function
    output_values = context.call_internal(
        builder,
        inner,
        scalar_signature,
        input_vals,
    )
    # Normalize single-output scalar functions to a one-element list.
    if isinstance(scalar_signature.return_type, types.Tuple):
        output_values = cgutils.unpack_tuple(builder, output_values)
    else:
        output_values = [output_values]

    # Update output value or accumulators respectively
    for i, ((accu, _), value) in enumerate(
        zip(output_accumulator, output_values, strict=True)
    ):
        if accu is not None:
            # NOTE(review): accumulation uses a plain `fadd`, which assumes
            # the reduced output has a floating point dtype — TODO confirm
            # integer outputs can't reach this path.
            load = builder.load(accu)
            # load.set_metadata("alias.scope", output_scope_set)
            # load.set_metadata("noalias", input_scope_set)
            new_value = builder.fadd(load, value)
            builder.store(new_value, accu)
            # TODO belongs to noalias scope
            # store.set_metadata("alias.scope", output_scope_set)
            # store.set_metadata("noalias", input_scope_set)
        else:
            idxs_bc = [
                zero if bc else idx
                for idx, bc in zip(idxs, output_bc[i], strict=True)
            ]
            ptr = cgutils.get_item_pointer2(
                context, builder, *outputs[i], idxs_bc
            )
            # store = builder.store(value, ptr)
            arrayobj.store_item(context, builder, output_types[i], value, ptr)
            # store.set_metadata("alias.scope", output_scope_set)
            # store.set_metadata("noalias", input_scope_set)

    # Close the loops and write accumulator values to the output arrays
    # NOTE(review): `accu_depth` was recorded as a dim index counted from
    # the outermost loop, while `depth` here counts from the innermost
    # (reversed stack) — verify these agree for ndim > 1.
    for depth, loop in enumerate(loop_stack[::-1]):
        for output, (accu, accu_depth) in enumerate(output_accumulator):
            if accu_depth == depth:
                idxs_bc = [
                    zero if bc else idx
                    for idx, bc in zip(
                        idxs, output_bc[output], strict=True
                    )
                ]
                ptr = cgutils.get_item_pointer2(
                    context, builder, *outputs[output], idxs_bc
                )
                load = builder.load(accu)
                # load.set_metadata("alias.scope", output_scope_set)
                # load.set_metadata("noalias", input_scope_set)
                # store = builder.store(load, ptr)
                arrayobj.store_item(
                    context, builder, output_types[output], load, ptr
                )
                # store.set_metadata("alias.scope", output_scope_set)
                # store.set_metadata("noalias", input_scope_set)
        loop.__exit__(None, None, None)
    return
from numba import njit, types
from numba.core import cgutils
from numba.extending import intrinsic
def tuple_mapper(item_map_func):
    """Build an intrinsic that maps ``item_map_func`` over tuples element-wise.

    The returned callable takes several tuples of equal length and returns
    a new tuple whose entry ``i`` is ``item_map_func(t1[i], t2[i], ...)``.
    """
    @intrinsic
    def map_tuple(typingctx, *input_tuples):
        # Resolve one signature per tuple position — element types may
        # differ from position to position.
        signatures = [
            typingctx.resolve_function_type(item_map_func, args, {})
            for args in zip(*[in_type.types for in_type in input_tuples], strict=True)
        ]
        output_type = types.Tuple([sig.return_type for sig in signatures])
        signature = output_type(types.StarArgTuple(input_tuples))

        def codegen(context, builder, signature, args):
            (input_tuples,) = args
            # Unpack the star-args tuple, then each individual tuple.
            input_values = []
            for val in cgutils.unpack_tuple(builder, input_tuples):
                input_values.append(cgutils.unpack_tuple(builder, val))

            # Lower and call the mapped function once per tuple position.
            mapped_values = []
            for values, sig in zip(zip(*input_values), signatures, strict=True):
                func = context.compile_subroutine(builder, item_map_func, sig)
                output = context.call_internal(builder, func.fndesc, sig, values)
                mapped_values.append(output)
            return context.make_tuple(builder, output_type, mapped_values)

        return signature, codegen

    return map_tuple
@njit
def check_broadcasting(array, bcs, shape):
    """Assert that ``array`` is compatible with ``shape`` under ``bcs``.

    Dimensions flagged as broadcast in ``bcs`` must have length 1; all
    other dimensions must match the corresponding entry of ``shape``.
    """
    assert array.ndim == len(shape)
    for is_bc, actual, expected in zip(bcs, array.shape, shape):
        required = 1 if is_bc else expected
        assert actual == required
Markdown formatting is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment