提交 9a55cfe6 authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: Adrian Seyboldt

Add shape checking in numba elemwise

上级 69eb09ad
...@@ -454,10 +454,6 @@ def _vectorized( ...@@ -454,10 +454,6 @@ def _vectorized(
inplace_pattern, inplace_pattern,
inputs, inputs,
): ):
#if not isinstance(scalar_func, types.Literal):
# raise TypingError("scalar func must be literal.")
#scalar_func = scalar_func.literal_value
arg_types = [ arg_types = [
scalar_func, scalar_func,
input_bc_patterns, input_bc_patterns,
...@@ -516,8 +512,6 @@ def _vectorized( ...@@ -516,8 +512,6 @@ def _vectorized(
inplace_pattern_val = inplace_pattern inplace_pattern_val = inplace_pattern
input_types = inputs input_types = inputs
#assert not inplace_pattern_val
def codegen( def codegen(
ctx, ctx,
builder, builder,
...@@ -551,18 +545,6 @@ def _vectorized( ...@@ -551,18 +545,6 @@ def _vectorized(
input_types, input_types,
) )
def _check_input_shapes(*_):
# TODO impl
return
_check_input_shapes(
ctx,
builder,
iter_shape,
inputs,
input_bc_patterns_val,
)
elemwise_codegen.make_loop_call( elemwise_codegen.make_loop_call(
typingctx, typingctx,
ctx, ctx,
...@@ -594,7 +576,6 @@ def _vectorized( ...@@ -594,7 +576,6 @@ def _vectorized(
builder, sig.return_type, [out._getvalue() for out in outputs] builder, sig.return_type, [out._getvalue() for out in outputs]
) )
# TODO check inplace_pattern
ret_type = types.Tuple( ret_type = types.Tuple(
[ [
types.Array(numba.from_dtype(np.dtype(dtype)), ndim, "C") types.Array(numba.from_dtype(np.dtype(dtype)), ndim, "C")
......
...@@ -3,32 +3,55 @@ import numpy as np ...@@ -3,32 +3,55 @@ import numpy as np
from llvmlite import ir from llvmlite import ir
from numba import types from numba import types
from numba.core import cgutils from numba.core import cgutils
from numba.core.base import BaseContext
from numba.np import arrayobj from numba.np import arrayobj
def compute_itershape( def compute_itershape(
ctx, ctx: BaseContext,
builder: ir.IRBuilder, builder: ir.IRBuilder,
in_shapes, in_shapes,
broadcast_pattern, broadcast_pattern,
): ):
one = ir.IntType(64)(1) one = ir.IntType(64)(1)
ndim = len(in_shapes[0]) ndim = len(in_shapes[0])
#shape = [ir.IntType(64)(1) for _ in range(ndim)]
shape = [None] * ndim shape = [None] * ndim
for i in range(ndim): for i in range(ndim):
# TODO Error checking... for j, (bc, in_shape) in enumerate(
# What if all shapes are 0? zip(broadcast_pattern, in_shapes, strict=True)
for bc, in_shape in zip(broadcast_pattern, in_shapes): ):
length = in_shape[i]
if bc[i]: if bc[i]:
# TODO with builder.if_then(
# raise error if length != 1 builder.icmp_unsigned("!=", length, one), likely=False
pass ):
msg = (
f"Input {j} to elemwise is expected to have shape 1 in axis {i}"
)
ctx.call_conv.return_user_exc(builder, ValueError, (msg,))
elif shape[i] is not None:
with builder.if_then(
builder.icmp_unsigned("!=", length, shape[i]), likely=False
):
with builder.if_else(builder.icmp_unsigned("==", length, one)) as (
then,
otherwise,
):
with then:
msg = (
f"Incompative shapes for input {j} and axis {i} of "
f"elemwise. Input {j} has shape 1, but is not statically "
"known to have shape 1, and thus not broadcastable."
)
ctx.call_conv.return_user_exc(builder, ValueError, (msg,))
with otherwise:
msg = (
f"Input {j} to elemwise has an incompatible "
f"shape in axis {i}."
)
ctx.call_conv.return_user_exc(builder, ValueError, (msg,))
else: else:
# TODO shape[i] = length
# if shape[i] is not None:
# raise Error if !=
shape[i] = in_shape[i]
for i in range(ndim): for i in range(ndim):
if shape[i] is None: if shape[i] is None:
shape[i] = one shape[i] = one
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论