提交 75d7e687 authored 作者: James Bergstra's avatar James Bergstra

Made nicer exception feedback from elemwise_cgen.

上级 aa048697
...@@ -688,6 +688,9 @@ class Elemwise(Op): ...@@ -688,6 +688,9 @@ class Elemwise(Op):
def c_support_code(self): def c_support_code(self):
return self.scalar_op.c_support_code() return self.scalar_op.c_support_code()
def c_code_cache_version(self):
return (4,)
# def elemwise_to_scal(env): # def elemwise_to_scal(env):
# mapping = {} # mapping = {}
# inputs = [] # inputs = []
......
...@@ -57,20 +57,48 @@ def make_checks(loop_orders, dtypes, sub): ...@@ -57,20 +57,48 @@ def make_checks(loop_orders, dtypes, sub):
""" % locals() """ % locals()
adjust = [] adjust = []
check = "" check = ""
for matches in zip(*loop_orders): if 0:
to_compare = [(j, x) for j, x in enumerate(matches) if x != "x"] # original dimension-checking loop builds a single if condition, and if it is true, it
if len(to_compare) < 2: # raises a generic error message
continue for matches in zip(*loop_orders):
j, x = to_compare[0] to_compare = [(j, x) for j, x in enumerate(matches) if x != "x"]
first = "%%(lv%(j)s)s_n%(x)s" % locals() if len(to_compare) < 2:
cond = " || ".join(["%(first)s != %%(lv%(j)s)s_n%(x)s" % locals() for j, x in to_compare[1:]]) continue
if cond: j, x = to_compare[0]
check += """ first = "%%(lv%(j)s)s_n%(x)s" % locals()
if (%(cond)s) { cond = " || ".join(["%(first)s != %%(lv%(j)s)s_n%(x)s" % locals() for j, x in to_compare[1:]])
PyErr_SetString(PyExc_ValueError, "Input dimensions do not match (Try re-running with py linker for more information)."); if cond:
%%(fail)s check += """
} if (%(cond)s) {
""" % locals() PyErr_SetString(PyExc_ValueError, "Input dimensions do not match (Try re-running with py linker for more information).");
%%(fail)s
}
""" % locals()
else:
# revised dimension-checking loop build multiple if conditions, and the first one that
# is true raises a more informative error message
for matches in zip(*loop_orders):
to_compare = [(j, x) for j, x in enumerate(matches) if x != "x"]
#elements of to_compare are pairs ( input_variable_idx, input_variable_dim_idx )
if len(to_compare) < 2:
continue
j0, x0 = to_compare[0]
for (j, x) in to_compare[1:]:
check += """
if (%%(lv%(j0)s)s_n%(x0)s != %%(lv%(j)s)s_n%(x)s)
{
PyErr_Format(PyExc_ValueError, "Input dimension mis-match. (input[%%%%i].shape[%%%%i] = %%%%i, input[%%%%i].shape[%%%%i] = %%%%i)",
%(j0)s,
%(x0)s,
%%(lv%(j0)s)s_n%(x0)s,
%(j)s,
%(x)s,
%%(lv%(j)s)s_n%(x)s
);
%%(fail)s
}
""" % locals()
return init % sub + check % sub return init % sub + check % sub
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论