提交 981be2a6 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Revert numba runtime broadcast check

上级 5bbfc964
...@@ -35,20 +35,15 @@ def compute_itershape( ...@@ -35,20 +35,15 @@ def compute_itershape(
with builder.if_then( with builder.if_then(
builder.icmp_unsigned("!=", length, shape[i]), likely=False builder.icmp_unsigned("!=", length, shape[i]), likely=False
): ):
with builder.if_else( with builder.if_else(builder.icmp_unsigned("==", length, one)) as (
builder.or_(
builder.icmp_unsigned("==", length, one),
builder.icmp_unsigned("==", shape[i], one),
)
) as (
then, then,
otherwise, otherwise,
): ):
with then: with then:
msg = ( msg = (
"Runtime broadcasting not allowed. " f"Incompatible shapes for input {j} and axis {i} of "
"One input had a distinct dimension length of 1, but was not marked as broadcastable.\n" f"elemwise. Input {j} has shape 1, but is not statically "
"If broadcasting was intended, use `specify_broadcastable` on the relevant input." "known to have shape 1, and thus not broadcastable."
) )
ctx.call_conv.return_user_exc(builder, ValueError, (msg,)) ctx.call_conv.return_user_exc(builder, ValueError, (msg,))
with otherwise: with otherwise:
......
...@@ -121,6 +121,7 @@ def test_Elemwise(inputs, input_vals, output_fn, exc): ...@@ -121,6 +121,7 @@ def test_Elemwise(inputs, input_vals, output_fn, exc):
compare_numba_and_py(out_fg, input_vals) compare_numba_and_py(out_fg, input_vals)
@pytest.mark.xfail(reason="Logic had to be reversed due to surprising segfaults")
def test_elemwise_runtime_shape_error(): def test_elemwise_runtime_shape_error():
TestElemwise.check_runtime_shapes_error(get_mode("NUMBA")) TestElemwise.check_runtime_shapes_error(get_mode("NUMBA"))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论