提交 69eb09ad authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: Adrian Seyboldt

Specialized numba sum impl

上级 35b20b57
import base64
import pickle
from functools import singledispatch from functools import singledispatch
from numbers import Number from numbers import Number
import pickle
from textwrap import indent from textwrap import indent
from typing import Any, Callable, Literal, Optional, Union from typing import Any, Callable, Optional, Union
import base64
import numba import numba
import numpy as np import numpy as np
from llvmlite import ir from numba import TypingError, types
from numba import TypingError, literal_unroll, types, literally
from numba.core import cgutils from numba.core import cgutils
from numba.cpython.unsafe.tuple import tuple_setitem
from numba.np import arrayobj from numba.np import arrayobj
from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
...@@ -18,6 +16,7 @@ from pytensor import config ...@@ -18,6 +16,7 @@ from pytensor import config
from pytensor.graph.basic import Apply from pytensor.graph.basic import Apply
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch import elemwise_codegen
from pytensor.link.numba.dispatch.basic import ( from pytensor.link.numba.dispatch.basic import (
create_numba_signature, create_numba_signature,
create_tuple_creator, create_tuple_creator,
...@@ -25,8 +24,6 @@ from pytensor.link.numba.dispatch.basic import ( ...@@ -25,8 +24,6 @@ from pytensor.link.numba.dispatch.basic import (
numba_njit, numba_njit,
use_optimized_cheap_pass, use_optimized_cheap_pass,
) )
from pytensor.link.numba.dispatch.helpers import check_broadcasting, tuple_mapper
from pytensor.link.numba.dispatch import elemwise_codegen
from pytensor.link.utils import compile_function_src, get_name_for_object from pytensor.link.utils import compile_function_src, get_name_for_object
from pytensor.scalar.basic import ( from pytensor.scalar.basic import (
AND, AND,
...@@ -45,7 +42,7 @@ from pytensor.scalar.basic import ( ...@@ -45,7 +42,7 @@ from pytensor.scalar.basic import (
from pytensor.scalar.basic import add as add_as from pytensor.scalar.basic import add as add_as
from pytensor.scalar.basic import scalar_maximum from pytensor.scalar.basic import scalar_maximum
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.math import MaxAndArgmax, MulWithoutZeros from pytensor.tensor.math import MaxAndArgmax, MulWithoutZeros, Sum
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
from pytensor.tensor.type import scalar from pytensor.tensor.type import scalar
...@@ -376,8 +373,7 @@ def create_multiaxis_reducer( ...@@ -376,8 +373,7 @@ def create_multiaxis_reducer(
careduce_def_src = f""" careduce_def_src = f"""
def {careduce_fn_name}({input_name}): def {careduce_fn_name}({input_name}):
{careduce_assign_lines} {careduce_assign_lines}
#return np.asarray({var_name}) return np.asarray({var_name})
return {var_name}
""" """
careduce_fn = compile_function_src( careduce_fn = compile_function_src(
...@@ -447,6 +443,7 @@ _jit_options = { ...@@ -447,6 +443,7 @@ _jit_options = {
} }
} }
@numba.extending.intrinsic(jit_options=_jit_options, prefer_literal=True) @numba.extending.intrinsic(jit_options=_jit_options, prefer_literal=True)
def _vectorized( def _vectorized(
typingctx, typingctx,
...@@ -490,7 +487,6 @@ def _vectorized( ...@@ -490,7 +487,6 @@ def _vectorized(
inplace_pattern = inplace_pattern.literal_value inplace_pattern = inplace_pattern.literal_value
inplace_pattern = pickle.loads(base64.decodebytes(inplace_pattern.encode())) inplace_pattern = pickle.loads(base64.decodebytes(inplace_pattern.encode()))
n_inputs = len(inputs)
n_outputs = len(output_bc_patterns) n_outputs = len(output_bc_patterns)
if not len(inputs) > 0: if not len(inputs) > 0:
...@@ -531,7 +527,10 @@ def _vectorized( ...@@ -531,7 +527,10 @@ def _vectorized(
[_, _, _, _, _, inputs] = args [_, _, _, _, _, inputs] = args
inputs = cgutils.unpack_tuple(builder, inputs) inputs = cgutils.unpack_tuple(builder, inputs)
inputs = [arrayobj.make_array(ty)(ctx, builder, val) for ty, val in zip(input_types, inputs)] inputs = [
arrayobj.make_array(ty)(ctx, builder, val)
for ty, val in zip(input_types, inputs)
]
in_shapes = [cgutils.unpack_tuple(builder, obj.shape) for obj in inputs] in_shapes = [cgutils.unpack_tuple(builder, obj.shape) for obj in inputs]
iter_shape = elemwise_codegen.compute_itershape( iter_shape = elemwise_codegen.compute_itershape(
...@@ -586,14 +585,22 @@ def _vectorized( ...@@ -586,14 +585,22 @@ def _vectorized(
return outputs[0]._getvalue() return outputs[0]._getvalue()
for inplace_idx in dict(inplace_pattern): for inplace_idx in dict(inplace_pattern):
ctx.nrt.incref(builder, sig.return_type.types[inplace_idx], outputs[inplace_idx]._get_value()) ctx.nrt.incref(
return ctx.make_tuple(builder, sig.return_type, [out._getvalue() for out in outputs]) builder,
sig.return_type.types[inplace_idx],
outputs[inplace_idx]._get_value(),
)
return ctx.make_tuple(
builder, sig.return_type, [out._getvalue() for out in outputs]
)
# TODO check inplace_pattern # TODO check inplace_pattern
ret_type = types.Tuple([ ret_type = types.Tuple(
[
types.Array(numba.from_dtype(np.dtype(dtype)), ndim, "C") types.Array(numba.from_dtype(np.dtype(dtype)), ndim, "C")
for dtype in output_dtypes for dtype in output_dtypes
]) ]
)
if len(output_dtypes) == 1: if len(output_dtypes) == 1:
ret_type = ret_type.types[0] ret_type = ret_type.types[0]
sig = ret_type(*arg_types) sig = ret_type(*arg_types)
...@@ -649,6 +656,40 @@ def numba_funcify_Elemwise(op, node, **kwargs): ...@@ -649,6 +656,40 @@ def numba_funcify_Elemwise(op, node, **kwargs):
return elemwise_wrapper return elemwise_wrapper
@numba_funcify.register(Sum)
def numba_funcify_Sum(op, node, **kwargs):
    """Generate a specialized numba implementation for ``Sum``.

    Complete reductions use numba's fast ``array.sum()``, and reductions
    over a single axis use ``array.sum(axis)`` with an integer axis (the
    only axis form numba's typing supports).  Partial reductions over
    several axes fall back to the generic ``CAReduce`` lowering, since
    numba's ``array.sum`` does not accept a sequence of axes.

    Parameters
    ----------
    op : Sum
        The ``Sum`` `Op` being lowered.
    node : Apply
        The node whose input/output types determine dimensionality and dtype.

    Returns
    -------
    callable
        A numba-jitted function of one array argument.
    """
    ndim_input = node.inputs[0].ndim

    axes = op.axis
    if axes is None:
        axes = range(ndim_input)
    # Use a tuple: numba treats closure-captured tuples as compile-time
    # constants, whereas reflected lists are deprecated in numba closures.
    axes = tuple(axes)

    if hasattr(op, "acc_dtype") and op.acc_dtype is not None:
        acc_dtype = op.acc_dtype
    else:
        acc_dtype = node.outputs[0].type.dtype

    np_acc_dtype = np.dtype(acc_dtype)

    if ndim_input == len(axes):
        # Complete reduction: no axis argument needed.
        @numba_njit(fastmath=True)
        def impl_sum(array):
            # TODO The accumulation itself should happen in acc_dtype...
            return np.asarray(array.sum()).astype(np_acc_dtype)

    elif len(axes) == 1:
        # Single-axis partial reduction: numba supports an integer axis.
        axis = axes[0]

        @numba_njit(fastmath=True)
        def impl_sum(array):
            # TODO The accumulation itself should happen in acc_dtype...
            return array.sum(axis).astype(np_acc_dtype)

    else:
        # Multi-axis partial reduction: numba's ``array.sum(axis)`` only
        # accepts a single integer axis, so delegate to the generic
        # CAReduce implementation (which handled `Sum` before this
        # specialization existed).
        impl_sum = numba_funcify_CAReduce(op, node, **kwargs)

    return impl_sum
@numba_funcify.register(CAReduce) @numba_funcify.register(CAReduce)
def numba_funcify_CAReduce(op, node, **kwargs): def numba_funcify_CAReduce(op, node, **kwargs):
axes = op.axis axes = op.axis
......
import numba
import numpy as np
from llvmlite import ir from llvmlite import ir
from numba import types from numba import types
from numba.np import arrayobj
from numba.core import cgutils from numba.core import cgutils
import numba from numba.np import arrayobj
import numpy as np
def compute_itershape( def compute_itershape(
...@@ -35,7 +35,9 @@ def compute_itershape( ...@@ -35,7 +35,9 @@ def compute_itershape(
return shape return shape
def make_outputs(ctx, builder: ir.IRBuilder, iter_shape, out_bc, dtypes, inplace, inputs, input_types): def make_outputs(
ctx, builder: ir.IRBuilder, iter_shape, out_bc, dtypes, inplace, inputs, input_types
):
arrays = [] arrays = []
ar_types: list[types.Array] = [] ar_types: list[types.Array] = []
one = ir.IntType(64)(1) one = ir.IntType(64)(1)
...@@ -52,8 +54,7 @@ def make_outputs(ctx, builder: ir.IRBuilder, iter_shape, out_bc, dtypes, inplace ...@@ -52,8 +54,7 @@ def make_outputs(ctx, builder: ir.IRBuilder, iter_shape, out_bc, dtypes, inplace
# This is actually an interal numba function, I guess we could # This is actually an interal numba function, I guess we could
# call `numba.nd.unsafe.ndarray` instead? # call `numba.nd.unsafe.ndarray` instead?
shape = [ shape = [
length if not bc_dim else one length if not bc_dim else one for length, bc_dim in zip(iter_shape, bc)
for length, bc_dim in zip(iter_shape, bc)
] ]
array = arrayobj._empty_nd_impl(ctx, builder, arrtype, shape) array = arrayobj._empty_nd_impl(ctx, builder, arrtype, shape)
arrays.append(array) arrays.append(array)
...@@ -84,7 +85,7 @@ def make_loop_call( ...@@ -84,7 +85,7 @@ def make_loop_call(
safe = (False, False) safe = (False, False)
n_outputs = len(outputs) n_outputs = len(outputs)
#context.printf(builder, "iter shape: " + ', '.join(["%i"] * len(iter_shape)) + "\n", *iter_shape) # context.printf(builder, "iter shape: " + ', '.join(["%i"] * len(iter_shape)) + "\n", *iter_shape)
# Lower the code of the scalar function so that we can use it in the inner loop # Lower the code of the scalar function so that we can use it in the inner loop
# Caching is set to false to avoid a numba bug TODO ref? # Caching is set to false to avoid a numba bug TODO ref?
...@@ -155,12 +156,8 @@ def make_loop_call( ...@@ -155,12 +156,8 @@ def make_loop_call(
# Load values from input arrays # Load values from input arrays
input_vals = [] input_vals = []
for array_info, bc in zip(inputs, input_bc, strict=True): for array_info, bc in zip(inputs, input_bc, strict=True):
idxs_bc = [ idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, bc, strict=True)]
zero if bc else idx for idx, bc in zip(idxs, bc, strict=True) ptr = cgutils.get_item_pointer2(context, builder, *array_info, idxs_bc, *safe)
]
ptr = cgutils.get_item_pointer2(
context, builder, *array_info, idxs_bc, *safe
)
val = builder.load(ptr) val = builder.load(ptr)
# val.set_metadata("alias.scope", input_scope_set) # val.set_metadata("alias.scope", input_scope_set)
# val.set_metadata("noalias", output_scope_set) # val.set_metadata("noalias", output_scope_set)
...@@ -193,12 +190,9 @@ def make_loop_call( ...@@ -193,12 +190,9 @@ def make_loop_call(
# store.set_metadata("noalias", input_scope_set) # store.set_metadata("noalias", input_scope_set)
else: else:
idxs_bc = [ idxs_bc = [
zero if bc else idx zero if bc else idx for idx, bc in zip(idxs, output_bc[i], strict=True)
for idx, bc in zip(idxs, output_bc[i], strict=True)
] ]
ptr = cgutils.get_item_pointer2( ptr = cgutils.get_item_pointer2(context, builder, *outputs[i], idxs_bc)
context, builder, *outputs[i], idxs_bc
)
# store = builder.store(value, ptr) # store = builder.store(value, ptr)
arrayobj.store_item(context, builder, output_types[i], value, ptr) arrayobj.store_item(context, builder, output_types[i], value, ptr)
# store.set_metadata("alias.scope", output_scope_set) # store.set_metadata("alias.scope", output_scope_set)
...@@ -210,9 +204,7 @@ def make_loop_call( ...@@ -210,9 +204,7 @@ def make_loop_call(
if accu_depth == depth: if accu_depth == depth:
idxs_bc = [ idxs_bc = [
zero if bc else idx zero if bc else idx
for idx, bc in zip( for idx, bc in zip(idxs, output_bc[output], strict=True)
idxs, output_bc[output], strict=True
)
] ]
ptr = cgutils.get_item_pointer2( ptr = cgutils.get_item_pointer2(
context, builder, *outputs[output], idxs_bc context, builder, *outputs[output], idxs_bc
...@@ -221,9 +213,7 @@ def make_loop_call( ...@@ -221,9 +213,7 @@ def make_loop_call(
# load.set_metadata("alias.scope", output_scope_set) # load.set_metadata("alias.scope", output_scope_set)
# load.set_metadata("noalias", input_scope_set) # load.set_metadata("noalias", input_scope_set)
# store = builder.store(load, ptr) # store = builder.store(load, ptr)
arrayobj.store_item( arrayobj.store_item(context, builder, output_types[output], load, ptr)
context, builder, output_types[output], load, ptr
)
# store.set_metadata("alias.scope", output_scope_set) # store.set_metadata("alias.scope", output_scope_set)
# store.set_metadata("noalias", input_scope_set) # store.set_metadata("noalias", input_scope_set)
loop.__exit__(None, None, None) loop.__exit__(None, None, None)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论