提交 69eb09ad authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: Adrian Seyboldt

Specialized numba sum impl

上级 35b20b57
import base64
import pickle
from functools import singledispatch
from numbers import Number
import pickle
from textwrap import indent
from typing import Any, Callable, Literal, Optional, Union
import base64
from typing import Any, Callable, Optional, Union
import numba
import numpy as np
from llvmlite import ir
from numba import TypingError, literal_unroll, types, literally
from numba import TypingError, types
from numba.core import cgutils
from numba.cpython.unsafe.tuple import tuple_setitem
from numba.np import arrayobj
from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
......@@ -18,6 +16,7 @@ from pytensor import config
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch import elemwise_codegen
from pytensor.link.numba.dispatch.basic import (
create_numba_signature,
create_tuple_creator,
......@@ -25,8 +24,6 @@ from pytensor.link.numba.dispatch.basic import (
numba_njit,
use_optimized_cheap_pass,
)
from pytensor.link.numba.dispatch.helpers import check_broadcasting, tuple_mapper
from pytensor.link.numba.dispatch import elemwise_codegen
from pytensor.link.utils import compile_function_src, get_name_for_object
from pytensor.scalar.basic import (
AND,
......@@ -45,7 +42,7 @@ from pytensor.scalar.basic import (
from pytensor.scalar.basic import add as add_as
from pytensor.scalar.basic import scalar_maximum
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.math import MaxAndArgmax, MulWithoutZeros
from pytensor.tensor.math import MaxAndArgmax, MulWithoutZeros, Sum
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
from pytensor.tensor.type import scalar
......@@ -376,8 +373,7 @@ def create_multiaxis_reducer(
careduce_def_src = f"""
def {careduce_fn_name}({input_name}):
{careduce_assign_lines}
#return np.asarray({var_name})
return {var_name}
return np.asarray({var_name})
"""
careduce_fn = compile_function_src(
......@@ -447,6 +443,7 @@ _jit_options = {
}
}
@numba.extending.intrinsic(jit_options=_jit_options, prefer_literal=True)
def _vectorized(
typingctx,
......@@ -490,7 +487,6 @@ def _vectorized(
inplace_pattern = inplace_pattern.literal_value
inplace_pattern = pickle.loads(base64.decodebytes(inplace_pattern.encode()))
n_inputs = len(inputs)
n_outputs = len(output_bc_patterns)
if not len(inputs) > 0:
......@@ -531,7 +527,10 @@ def _vectorized(
[_, _, _, _, _, inputs] = args
inputs = cgutils.unpack_tuple(builder, inputs)
inputs = [arrayobj.make_array(ty)(ctx, builder, val) for ty, val in zip(input_types, inputs)]
inputs = [
arrayobj.make_array(ty)(ctx, builder, val)
for ty, val in zip(input_types, inputs)
]
in_shapes = [cgutils.unpack_tuple(builder, obj.shape) for obj in inputs]
iter_shape = elemwise_codegen.compute_itershape(
......@@ -586,14 +585,22 @@ def _vectorized(
return outputs[0]._getvalue()
for inplace_idx in dict(inplace_pattern):
ctx.nrt.incref(builder, sig.return_type.types[inplace_idx], outputs[inplace_idx]._get_value())
return ctx.make_tuple(builder, sig.return_type, [out._getvalue() for out in outputs])
ctx.nrt.incref(
builder,
sig.return_type.types[inplace_idx],
outputs[inplace_idx]._get_value(),
)
return ctx.make_tuple(
builder, sig.return_type, [out._getvalue() for out in outputs]
)
# TODO check inplace_pattern
ret_type = types.Tuple([
types.Array(numba.from_dtype(np.dtype(dtype)), ndim, "C")
for dtype in output_dtypes
])
ret_type = types.Tuple(
[
types.Array(numba.from_dtype(np.dtype(dtype)), ndim, "C")
for dtype in output_dtypes
]
)
if len(output_dtypes) == 1:
ret_type = ret_type.types[0]
sig = ret_type(*arg_types)
......@@ -649,6 +656,40 @@ def numba_funcify_Elemwise(op, node, **kwargs):
return elemwise_wrapper
@numba_funcify.register(Sum)
def numba_funcify_Sum(op, node, **kwargs):
    """Build a specialized numba implementation of the ``Sum`` op.

    Parameters
    ----------
    op
        The ``Sum`` op. ``op.axis`` selects the axes to reduce (``None``
        means every axis of the input). If the op has a non-``None``
        ``acc_dtype`` attribute it is used as the accumulator dtype,
        otherwise the dtype of the node's output is used.
    node
        The apply node. Supplies the number of dimensions of the input and
        the dtype of the output.
    **kwargs
        Accepted for compatibility with the dispatcher; not used here.

    Returns
    -------
    callable
        A jitted function that takes the input array and returns its sum
        as an array cast to the dtype of the node's output.
    """
    ndim_input = node.inputs[0].ndim

    axes = op.axis
    if axes is None:
        axes = range(ndim_input)
    axes = list(axes)

    # The result must have the dtype of the node's output, which is not
    # necessarily the accumulator dtype.
    out_dtype = np.dtype(node.outputs[0].type.dtype)

    acc_dtype = getattr(op, "acc_dtype", None)
    np_acc_dtype = out_dtype if acc_dtype is None else np.dtype(acc_dtype)

    if ndim_input == len(axes):
        # Every axis is reduced: use the plain full reduction and wrap the
        # resulting scalar in a 0-d array.
        @numba_njit(fastmath=True)
        def impl_sum(array):
            # TODO The accumulation itself should happen in acc_dtype...
            total = np.asarray(array.sum())
            return total.astype(np_acc_dtype).astype(out_dtype)

    elif not axes:
        # Nothing to reduce: summing over no axes leaves the values
        # unchanged, so only the dtype conversions remain.
        @numba_njit(fastmath=True)
        def impl_sum(array):
            return array.astype(np_acc_dtype).astype(out_dtype)

    else:
        # NOTE(review): ``axes`` is a Python list captured by the jitted
        # closure; confirm that numba accepts it as the ``axis`` argument
        # for every reduction pattern that can reach this branch.
        @numba_njit(fastmath=True)
        def impl_sum(array):
            # TODO The accumulation itself should happen in acc_dtype...
            return array.sum(axes).astype(np_acc_dtype).astype(out_dtype)

    return impl_sum
@numba_funcify.register(CAReduce)
def numba_funcify_CAReduce(op, node, **kwargs):
axes = op.axis
......
import numba
import numpy as np
from llvmlite import ir
from numba import types
from numba.np import arrayobj
from numba.core import cgutils
import numba
import numpy as np
from numba.np import arrayobj
def compute_itershape(
......@@ -35,7 +35,9 @@ def compute_itershape(
return shape
def make_outputs(ctx, builder: ir.IRBuilder, iter_shape, out_bc, dtypes, inplace, inputs, input_types):
def make_outputs(
ctx, builder: ir.IRBuilder, iter_shape, out_bc, dtypes, inplace, inputs, input_types
):
arrays = []
ar_types: list[types.Array] = []
one = ir.IntType(64)(1)
......@@ -52,8 +54,7 @@ def make_outputs(ctx, builder: ir.IRBuilder, iter_shape, out_bc, dtypes, inplace
# This is actually an internal numba function, I guess we could
# call `numba.np.unsafe.ndarray` instead?
shape = [
length if not bc_dim else one
for length, bc_dim in zip(iter_shape, bc)
length if not bc_dim else one for length, bc_dim in zip(iter_shape, bc)
]
array = arrayobj._empty_nd_impl(ctx, builder, arrtype, shape)
arrays.append(array)
......@@ -84,7 +85,7 @@ def make_loop_call(
safe = (False, False)
n_outputs = len(outputs)
#context.printf(builder, "iter shape: " + ', '.join(["%i"] * len(iter_shape)) + "\n", *iter_shape)
# context.printf(builder, "iter shape: " + ', '.join(["%i"] * len(iter_shape)) + "\n", *iter_shape)
# Lower the code of the scalar function so that we can use it in the inner loop
# Caching is set to false to avoid a numba bug TODO ref?
......@@ -155,12 +156,8 @@ def make_loop_call(
# Load values from input arrays
input_vals = []
for array_info, bc in zip(inputs, input_bc, strict=True):
idxs_bc = [
zero if bc else idx for idx, bc in zip(idxs, bc, strict=True)
]
ptr = cgutils.get_item_pointer2(
context, builder, *array_info, idxs_bc, *safe
)
idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, bc, strict=True)]
ptr = cgutils.get_item_pointer2(context, builder, *array_info, idxs_bc, *safe)
val = builder.load(ptr)
# val.set_metadata("alias.scope", input_scope_set)
# val.set_metadata("noalias", output_scope_set)
......@@ -193,12 +190,9 @@ def make_loop_call(
# store.set_metadata("noalias", input_scope_set)
else:
idxs_bc = [
zero if bc else idx
for idx, bc in zip(idxs, output_bc[i], strict=True)
zero if bc else idx for idx, bc in zip(idxs, output_bc[i], strict=True)
]
ptr = cgutils.get_item_pointer2(
context, builder, *outputs[i], idxs_bc
)
ptr = cgutils.get_item_pointer2(context, builder, *outputs[i], idxs_bc)
# store = builder.store(value, ptr)
arrayobj.store_item(context, builder, output_types[i], value, ptr)
# store.set_metadata("alias.scope", output_scope_set)
......@@ -210,9 +204,7 @@ def make_loop_call(
if accu_depth == depth:
idxs_bc = [
zero if bc else idx
for idx, bc in zip(
idxs, output_bc[output], strict=True
)
for idx, bc in zip(idxs, output_bc[output], strict=True)
]
ptr = cgutils.get_item_pointer2(
context, builder, *outputs[output], idxs_bc
......@@ -221,9 +213,7 @@ def make_loop_call(
# load.set_metadata("alias.scope", output_scope_set)
# load.set_metadata("noalias", input_scope_set)
# store = builder.store(load, ptr)
arrayobj.store_item(
context, builder, output_types[output], load, ptr
)
arrayobj.store_item(context, builder, output_types[output], load, ptr)
# store.set_metadata("alias.scope", output_scope_set)
# store.set_metadata("noalias", input_scope_set)
loop.__exit__(None, None, None)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论