提交 9a55cfe6 authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: Adrian Seyboldt

Add shape checking in numba elemwise

上级 69eb09ad
......@@ -454,10 +454,6 @@ def _vectorized(
inplace_pattern,
inputs,
):
#if not isinstance(scalar_func, types.Literal):
# raise TypingError("scalar func must be literal.")
#scalar_func = scalar_func.literal_value
arg_types = [
scalar_func,
input_bc_patterns,
......@@ -516,8 +512,6 @@ def _vectorized(
inplace_pattern_val = inplace_pattern
input_types = inputs
#assert not inplace_pattern_val
def codegen(
ctx,
builder,
......@@ -551,18 +545,6 @@ def _vectorized(
input_types,
)
def _check_input_shapes(*_):
# TODO impl
return
_check_input_shapes(
ctx,
builder,
iter_shape,
inputs,
input_bc_patterns_val,
)
elemwise_codegen.make_loop_call(
typingctx,
ctx,
......@@ -594,7 +576,6 @@ def _vectorized(
builder, sig.return_type, [out._getvalue() for out in outputs]
)
# TODO check inplace_pattern
ret_type = types.Tuple(
[
types.Array(numba.from_dtype(np.dtype(dtype)), ndim, "C")
......
......@@ -3,32 +3,55 @@ import numpy as np
from llvmlite import ir
from numba import types
from numba.core import cgutils
from numba.core.base import BaseContext
from numba.np import arrayobj
def compute_itershape(
ctx,
ctx: BaseContext,
builder: ir.IRBuilder,
in_shapes,
broadcast_pattern,
):
one = ir.IntType(64)(1)
ndim = len(in_shapes[0])
#shape = [ir.IntType(64)(1) for _ in range(ndim)]
shape = [None] * ndim
for i in range(ndim):
# TODO Error checking...
# What if all shapes are 0?
for bc, in_shape in zip(broadcast_pattern, in_shapes):
for j, (bc, in_shape) in enumerate(
zip(broadcast_pattern, in_shapes, strict=True)
):
length = in_shape[i]
if bc[i]:
# TODO
# raise error if length != 1
pass
with builder.if_then(
builder.icmp_unsigned("!=", length, one), likely=False
):
msg = (
f"Input {j} to elemwise is expected to have shape 1 in axis {i}"
)
ctx.call_conv.return_user_exc(builder, ValueError, (msg,))
elif shape[i] is not None:
with builder.if_then(
builder.icmp_unsigned("!=", length, shape[i]), likely=False
):
with builder.if_else(builder.icmp_unsigned("==", length, one)) as (
then,
otherwise,
):
with then:
msg = (
f"Incompative shapes for input {j} and axis {i} of "
f"elemwise. Input {j} has shape 1, but is not statically "
"known to have shape 1, and thus not broadcastable."
)
ctx.call_conv.return_user_exc(builder, ValueError, (msg,))
with otherwise:
msg = (
f"Input {j} to elemwise has an incompatible "
f"shape in axis {i}."
)
ctx.call_conv.return_user_exc(builder, ValueError, (msg,))
else:
# TODO
# if shape[i] is not None:
# raise Error if !=
shape[i] = in_shape[i]
shape[i] = length
for i in range(ndim):
if shape[i] is None:
shape[i] = one
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论