提交 c0a4276d authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Move vectorize wrapper to vectorize_codegen

上级 89e9bd6b
...@@ -8,17 +8,13 @@ from typing import Any ...@@ -8,17 +8,13 @@ from typing import Any
import numba import numba
import numpy as np import numpy as np
from numba import TypingError, types
from numba.core import cgutils
from numba.core.extending import overload from numba.core.extending import overload
from numba.np import arrayobj
from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
from pytensor import config from pytensor import config
from pytensor.graph.basic import Apply from pytensor.graph.basic import Apply
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch import vectorize_codegen
from pytensor.link.numba.dispatch.basic import ( from pytensor.link.numba.dispatch.basic import (
create_numba_signature, create_numba_signature,
create_tuple_creator, create_tuple_creator,
...@@ -26,6 +22,7 @@ from pytensor.link.numba.dispatch.basic import ( ...@@ -26,6 +22,7 @@ from pytensor.link.numba.dispatch.basic import (
numba_njit, numba_njit,
use_optimized_cheap_pass, use_optimized_cheap_pass,
) )
from pytensor.link.numba.dispatch.vectorize_codegen import _jit_options, _vectorized
from pytensor.link.utils import compile_function_src, get_name_for_object from pytensor.link.utils import compile_function_src, get_name_for_object
from pytensor.scalar.basic import ( from pytensor.scalar.basic import (
AND, AND,
...@@ -463,167 +460,6 @@ def create_axis_apply_fn(fn, axis, ndim, dtype): ...@@ -463,167 +460,6 @@ def create_axis_apply_fn(fn, axis, ndim, dtype):
return axis_apply_fn return axis_apply_fn
_jit_options = {
"fastmath": {
"arcp", # Allow Reciprocal
"contract", # Allow floating-point contraction
"afn", # Approximate functions
"reassoc",
"nsz", # TODO Do we want this one?
},
"no_cpython_wrapper": True,
"no_cfunc_wrapper": True,
}
def _decode_literal_arg(arg_type, name):
    """Decode one compile-time constant passed to ``_vectorized``.

    The argument must be a numba literal type whose ``literal_value`` is a
    base64-encoded pickle string.

    Raises
    ------
    TypingError
        If ``arg_type`` is not a literal type.
    """
    if not isinstance(arg_type, types.Literal):
        raise TypingError(f"{name} must be literal.")
    # NOTE(review): pickle.loads must only ever see payloads produced by
    # trusted code. Presumably the Elemwise dispatch encodes these itself;
    # confirm no externally supplied string can reach this point.
    return pickle.loads(base64.decodebytes(arg_type.literal_value.encode()))


@numba.extending.intrinsic(jit_options=_jit_options, prefer_literal=True)
def _vectorized(
    typingctx,
    scalar_func,
    input_bc_patterns,
    output_bc_patterns,
    output_dtypes,
    inplace_pattern,
    inputs,
):
    """Numba intrinsic that applies ``scalar_func`` elementwise over ``inputs``.

    Typing phase: decodes the literal pattern/dtype arguments, validates the
    input array types and resolves the scalar function signature.

    Parameters
    ----------
    typingctx
        Numba typing context.
    scalar_func
        Numba type of the scalar function called for each element.
    input_bc_patterns, output_bc_patterns, output_dtypes, inplace_pattern
        Literal types carrying base64-encoded pickled constants.
        ``inplace_pattern`` decodes to ``(output_idx, input_idx)`` pairs.
    inputs
        Tuple type of the input arrays; all must be arrays of equal rank.

    Returns
    -------
    (sig, codegen)
        The call signature and the code generation callback. The return type
        is a single array type for one output, otherwise a tuple of them.

    Raises
    ------
    TypingError
        If a constant argument is not literal or the inputs/outputs are
        invalid.
    """
    # Keep the original numba types: they form the signature's argument list.
    arg_types = [
        scalar_func,
        input_bc_patterns,
        output_bc_patterns,
        output_dtypes,
        inplace_pattern,
        inputs,
    ]

    # Decoded constants, captured by the codegen closure below.
    input_bc_patterns_val = _decode_literal_arg(
        input_bc_patterns, "input_bc_patterns"
    )
    output_bc_patterns_val = _decode_literal_arg(
        output_bc_patterns, "output_bc_patterns"
    )
    output_dtypes_val = _decode_literal_arg(output_dtypes, "output_dtypes")
    inplace_pattern_val = _decode_literal_arg(inplace_pattern, "inplace_pattern")

    if len(inputs) == 0:
        raise TypingError("Empty argument list to elemwise op.")

    if len(output_bc_patterns_val) == 0:
        raise TypingError("Empty list of outputs for elemwise op.")

    if not all(isinstance(inp, types.Array) for inp in inputs):
        raise TypingError("Inputs to elemwise must be arrays.")
    ndim = inputs[0].ndim

    if not all(inp.ndim == ndim for inp in inputs):
        raise TypingError("Inputs to elemwise must have the same rank.")

    if not all(len(pattern) == ndim for pattern in output_bc_patterns_val):
        raise TypingError("Invalid output broadcasting pattern.")

    scalar_signature = typingctx.resolve_function_type(
        scalar_func, [in_type.dtype for in_type in inputs], {}
    )

    input_types = inputs

    def codegen(
        ctx,
        builder,
        sig,
        args,
    ):
        # Only the last argument carries runtime data; the others were
        # consumed as constants during typing.
        [_, _, _, _, _, inputs] = args
        inputs = cgutils.unpack_tuple(builder, inputs)
        inputs = [
            arrayobj.make_array(ty)(ctx, builder, val)
            for ty, val in zip(input_types, inputs)
        ]
        in_shapes = [cgutils.unpack_tuple(builder, obj.shape) for obj in inputs]

        iter_shape = vectorize_codegen.compute_itershape(
            ctx,
            builder,
            in_shapes,
            input_bc_patterns_val,
        )

        outputs, output_types = vectorize_codegen.make_outputs(
            ctx,
            builder,
            iter_shape,
            output_bc_patterns_val,
            output_dtypes_val,
            inplace_pattern_val,
            inputs,
            input_types,
        )

        vectorize_codegen.make_loop_call(
            typingctx,
            ctx,
            builder,
            scalar_func,
            scalar_signature,
            iter_shape,
            inputs,
            outputs,
            input_bc_patterns_val,
            output_bc_patterns_val,
            input_types,
            output_types,
        )

        if len(outputs) == 1:
            if inplace_pattern_val:
                assert inplace_pattern_val[0][0] == 0
                # The returned array is one of the inputs: take an extra
                # reference for the value handed back to the caller.
                ctx.nrt.incref(builder, sig.return_type, outputs[0]._getvalue())
            return outputs[0]._getvalue()

        for inplace_idx in dict(inplace_pattern_val):
            ctx.nrt.incref(
                builder,
                sig.return_type.types[inplace_idx],
                # Fixed: the struct proxy method is `_getvalue`, not
                # `_get_value`.
                outputs[inplace_idx]._getvalue(),
            )
        return ctx.make_tuple(
            builder, sig.return_type, [out._getvalue() for out in outputs]
        )

    ret_types = [
        types.Array(numba.from_dtype(np.dtype(dtype)), ndim, "C")
        for dtype in output_dtypes_val
    ]

    # Inplace outputs reuse the type of the input they overwrite.
    for output_idx, input_idx in inplace_pattern_val:
        ret_types[output_idx] = input_types[input_idx]

    ret_type = types.Tuple(ret_types)
    if len(output_dtypes_val) == 1:
        ret_type = ret_type.types[0]
    sig = ret_type(*arg_types)

    return sig, codegen
@numba_funcify.register(Elemwise) @numba_funcify.register(Elemwise)
def numba_funcify_Elemwise(op, node, **kwargs): def numba_funcify_Elemwise(op, node, **kwargs):
# Creating a new scalar node is more involved and unnecessary # Creating a new scalar node is more involved and unnecessary
...@@ -634,16 +470,12 @@ def numba_funcify_Elemwise(op, node, **kwargs): ...@@ -634,16 +470,12 @@ def numba_funcify_Elemwise(op, node, **kwargs):
scalar_inputs = [scalar(dtype=input.dtype) for input in node.inputs] scalar_inputs = [scalar(dtype=input.dtype) for input in node.inputs]
scalar_node = op.scalar_op.make_node(*scalar_inputs) scalar_node = op.scalar_op.make_node(*scalar_inputs)
flags = {
"arcp", # Allow Reciprocal
"contract", # Allow floating-point contraction
"afn", # Approximate functions
"reassoc",
"nsz", # TODO Do we want this one?
}
scalar_op_fn = numba_funcify( scalar_op_fn = numba_funcify(
op.scalar_op, node=scalar_node, parent_node=node, fastmath=flags, **kwargs op.scalar_op,
node=scalar_node,
parent_node=node,
fastmath=_jit_options["fastmath"],
**kwargs,
) )
ndim = node.outputs[0].ndim ndim = node.outputs[0].ndim
...@@ -700,14 +532,7 @@ def numba_funcify_Elemwise(op, node, **kwargs): ...@@ -700,14 +532,7 @@ def numba_funcify_Elemwise(op, node, **kwargs):
return tuple(outputs_summed) return tuple(outputs_summed)
return outputs_summed[0] return outputs_summed[0]
@overload( @overload(elemwise, jit_options=_jit_options)
elemwise,
jit_options={
"fastmath": flags,
"no_cpython_wrapper": True,
"no_cfunc_wrapper": True,
},
)
def ov_elemwise(*inputs): def ov_elemwise(*inputs):
return elemwise_wrapper return elemwise_wrapper
......
from __future__ import annotations from __future__ import annotations
import base64
import pickle
from typing import Any from typing import Any
import numba import numba
import numpy as np import numpy as np
from llvmlite import ir from llvmlite import ir
from numba import types from numba import TypingError, types
from numba.core import cgutils from numba.core import cgutils
from numba.core.base import BaseContext from numba.core.base import BaseContext
from numba.np import arrayobj from numba.np import arrayobj
_jit_options = {
"fastmath": {
"arcp", # Allow Reciprocal
"contract", # Allow floating-point contraction
"afn", # Approximate functions
"reassoc",
"nsz", # TODO Do we want this one?
},
"no_cpython_wrapper": True,
"no_cfunc_wrapper": True,
}
def _decode_literal_arg(arg_type, name):
    """Decode one compile-time constant passed to ``_vectorized``.

    The argument must be a numba literal type whose ``literal_value`` is a
    base64-encoded pickle string.

    Raises
    ------
    TypingError
        If ``arg_type`` is not a literal type.
    """
    if not isinstance(arg_type, types.Literal):
        raise TypingError(f"{name} must be literal.")
    # NOTE(review): pickle.loads must only ever see payloads produced by
    # trusted code. Presumably the Elemwise dispatch encodes these itself;
    # confirm no externally supplied string can reach this point.
    return pickle.loads(base64.decodebytes(arg_type.literal_value.encode()))


@numba.extending.intrinsic(jit_options=_jit_options, prefer_literal=True)
def _vectorized(
    typingctx,
    scalar_func,
    input_bc_patterns,
    output_bc_patterns,
    output_dtypes,
    inplace_pattern,
    inputs,
):
    """Numba intrinsic that applies ``scalar_func`` elementwise over ``inputs``.

    Typing phase: decodes the literal pattern/dtype arguments, validates the
    input array types and resolves the scalar function signature.

    Parameters
    ----------
    typingctx
        Numba typing context.
    scalar_func
        Numba type of the scalar function called for each element.
    input_bc_patterns, output_bc_patterns, output_dtypes, inplace_pattern
        Literal types carrying base64-encoded pickled constants.
        ``inplace_pattern`` decodes to ``(output_idx, input_idx)`` pairs.
    inputs
        Tuple type of the input arrays; all must be arrays of equal rank.

    Returns
    -------
    (sig, codegen)
        The call signature and the code generation callback. The return type
        is a single array type for one output, otherwise a tuple of them.

    Raises
    ------
    TypingError
        If a constant argument is not literal or the inputs/outputs are
        invalid.
    """
    # Keep the original numba types: they form the signature's argument list.
    arg_types = [
        scalar_func,
        input_bc_patterns,
        output_bc_patterns,
        output_dtypes,
        inplace_pattern,
        inputs,
    ]

    # Decoded constants, captured by the codegen closure below.
    input_bc_patterns_val = _decode_literal_arg(
        input_bc_patterns, "input_bc_patterns"
    )
    output_bc_patterns_val = _decode_literal_arg(
        output_bc_patterns, "output_bc_patterns"
    )
    output_dtypes_val = _decode_literal_arg(output_dtypes, "output_dtypes")
    inplace_pattern_val = _decode_literal_arg(inplace_pattern, "inplace_pattern")

    if len(inputs) == 0:
        raise TypingError("Empty argument list to elemwise op.")

    if len(output_bc_patterns_val) == 0:
        raise TypingError("Empty list of outputs for elemwise op.")

    if not all(isinstance(inp, types.Array) for inp in inputs):
        raise TypingError("Inputs to elemwise must be arrays.")
    ndim = inputs[0].ndim

    if not all(inp.ndim == ndim for inp in inputs):
        raise TypingError("Inputs to elemwise must have the same rank.")

    if not all(len(pattern) == ndim for pattern in output_bc_patterns_val):
        raise TypingError("Invalid output broadcasting pattern.")

    scalar_signature = typingctx.resolve_function_type(
        scalar_func, [in_type.dtype for in_type in inputs], {}
    )

    input_types = inputs

    def codegen(
        ctx,
        builder,
        sig,
        args,
    ):
        # Only the last argument carries runtime data; the others were
        # consumed as constants during typing.
        [_, _, _, _, _, inputs] = args
        inputs = cgutils.unpack_tuple(builder, inputs)
        inputs = [
            arrayobj.make_array(ty)(ctx, builder, val)
            for ty, val in zip(input_types, inputs)
        ]
        in_shapes = [cgutils.unpack_tuple(builder, obj.shape) for obj in inputs]

        iter_shape = compute_itershape(
            ctx,
            builder,
            in_shapes,
            input_bc_patterns_val,
        )

        outputs, output_types = make_outputs(
            ctx,
            builder,
            iter_shape,
            output_bc_patterns_val,
            output_dtypes_val,
            inplace_pattern_val,
            inputs,
            input_types,
        )

        make_loop_call(
            typingctx,
            ctx,
            builder,
            scalar_func,
            scalar_signature,
            iter_shape,
            inputs,
            outputs,
            input_bc_patterns_val,
            output_bc_patterns_val,
            input_types,
            output_types,
        )

        if len(outputs) == 1:
            if inplace_pattern_val:
                assert inplace_pattern_val[0][0] == 0
                # The returned array is one of the inputs: take an extra
                # reference for the value handed back to the caller.
                ctx.nrt.incref(builder, sig.return_type, outputs[0]._getvalue())
            return outputs[0]._getvalue()

        for inplace_idx in dict(inplace_pattern_val):
            ctx.nrt.incref(
                builder,
                sig.return_type.types[inplace_idx],
                # Fixed: the struct proxy method is `_getvalue`, not
                # `_get_value`.
                outputs[inplace_idx]._getvalue(),
            )
        return ctx.make_tuple(
            builder, sig.return_type, [out._getvalue() for out in outputs]
        )

    ret_types = [
        types.Array(numba.from_dtype(np.dtype(dtype)), ndim, "C")
        for dtype in output_dtypes_val
    ]

    # Inplace outputs reuse the type of the input they overwrite.
    for output_idx, input_idx in inplace_pattern_val:
        ret_types[output_idx] = input_types[input_idx]

    ret_type = types.Tuple(ret_types)
    if len(output_dtypes_val) == 1:
        ret_type = ret_type.types[0]
    sig = ret_type(*arg_types)

    return sig, codegen
def compute_itershape( def compute_itershape(
ctx: BaseContext, ctx: BaseContext,
builder: ir.IRBuilder, builder: ir.IRBuilder,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论