Unverified 提交 a5fb9113 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: GitHub

Fix error formatting in C-Elemwise (#1749)

上级 e274e0d2
...@@ -1099,7 +1099,7 @@ class Elemwise(OpenMPOp): ...@@ -1099,7 +1099,7 @@ class Elemwise(OpenMPOp):
return support_code return support_code
def c_code_cache_version_apply(self, node): def c_code_cache_version_apply(self, node):
version = [16] # the version corresponding to the c code in this Op version = [17] # the version corresponding to the c code in this Op
# now we insert versions for the ops on which we depend... # now we insert versions for the ops on which we depend...
scalar_node = Apply( scalar_node = Apply(
......
...@@ -85,7 +85,7 @@ def make_checks(loop_orders, dtypes, sub, compute_stride_jump=True): ...@@ -85,7 +85,7 @@ def make_checks(loop_orders, dtypes, sub, compute_stride_jump=True):
runtime_broadcast_error_msg = ( runtime_broadcast_error_msg = (
"Runtime broadcasting not allowed. " "Runtime broadcasting not allowed. "
"One input had a distinct dimension length of 1, but was not marked as broadcastable: " "One input had a distinct dimension length of 1, but was not marked as broadcastable: "
"(input[%%i].shape[%%i] = %%lld, input[%%i].shape[%%i] = %%lld). " "(input[%i].shape[%i] = %lld, input[%i].shape[%i] = %lld). "
"If broadcasting was intended, use `specify_broadcastable` on the relevant input." "If broadcasting was intended, use `specify_broadcastable` on the relevant input."
) )
...@@ -113,7 +113,7 @@ def make_checks(loop_orders, dtypes, sub, compute_stride_jump=True): ...@@ -113,7 +113,7 @@ def make_checks(loop_orders, dtypes, sub, compute_stride_jump=True):
(long long int) {sub[f"lv{j}"]}_n{x} (long long int) {sub[f"lv{j}"]}_n{x}
); );
}} else {{ }} else {{
PyErr_Format(PyExc_ValueError, "Input dimension mismatch: (input[%%i].shape[%%i] = %%lld, input[%%i].shape[%%i] = %%lld)", PyErr_Format(PyExc_ValueError, "Input dimension mismatch: (input[%i].shape[%i] = %lld, input[%i].shape[%i] = %lld)",
{j0}, {j0},
{x0}, {x0},
(long long int) {sub[f"lv{j0}"]}_n{x0}, (long long int) {sub[f"lv{j0}"]}_n{x0},
......
...@@ -32,6 +32,7 @@ from pytensor.tensor.type import ( ...@@ -32,6 +32,7 @@ from pytensor.tensor.type import (
bmatrix, bmatrix,
bscalar, bscalar,
discrete_dtypes, discrete_dtypes,
dmatrix,
lscalar, lscalar,
matrix, matrix,
scalar, scalar,
...@@ -832,7 +833,26 @@ class TestElemwise(unittest_tools.InferShapeTester): ...@@ -832,7 +833,26 @@ class TestElemwise(unittest_tools.InferShapeTester):
reason="G++ not available, so we need to skip this test.", reason="G++ not available, so we need to skip this test.",
) )
def test_runtime_broadcast_c(self): def test_runtime_broadcast_c(self):
check_elemwise_runtime_broadcast(Mode(linker="c")) c_mode = Mode(linker="cvm")
check_elemwise_runtime_broadcast(c_mode)
# Test C-backend specific error formatting
x = dmatrix("x")
y = dmatrix("y")
fn = function([x, y], x * y, mode=c_mode)
with pytest.raises(
ValueError,
match=r"Runtime broadcasting not allowed.*\(input\[0\]\.shape\[1\] = 4, input\[1\]\.shape\[1\] = 1\)",
):
fn(np.zeros((5, 4)), np.zeros((5, 1)))
with pytest.raises(
ValueError,
match=re.escape(
"Input dimension mismatch: (input[0].shape[1] = 4, input[1].shape[1] = 3)"
),
):
fn(np.zeros((5, 4)), np.zeros((5, 3)))
def test_str(self): def test_str(self):
op = Elemwise(ps.add, inplace_pattern={0: 0}, name=None) op = Elemwise(ps.add, inplace_pattern={0: 0}, name=None)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论