提交 20f280ff authored 作者: Olivier Breuleux's avatar Olivier Breuleux

fixed elemwise compiler to check the dimensionality of the inputs

上级 0a02fb52
......@@ -390,6 +390,26 @@ def elemwise_loopcode(loopcode, init_template, next_template, acquire_template,
def elemwise_wrap(beforeloop, inloop, afterloop, loop_vars, writable_loop_vars, aliases):
check_init = """
npy_intp nd = %(loop_var)s->nd;
npy_intp* dims = %(loop_var)s->dimensions;
npy_intp* dims2;
"""
check = """
if (%(loop_var)s->nd != nd) {
PyErr_SetString(PyExc_ValueError, \"The number of dimensions of the inputs do not match.\");
}
dims2 = %(loop_var)s->dimensions;
for (int i = 0; i < nd; i++) {
if (dims2[i] != dims[i]) {
PyErr_SetString(PyExc_ValueError, \"The dimensions of the inputs do not match.\");
return 1;
}
}
"""
general_init = "PyArrayIterObject* %(loop_var)s_iter = (PyArrayIterObject*)PyArray_IterNew((PyObject*)%(loop_var)s);\n"
# "if (%(loop_var)s_iter == NULL) {\n" \
# " PyErr_SetString(PyExc_ValueError, \"Could not make an iterator over variable %(loop_var)s.\");\n" \
......@@ -405,8 +425,11 @@ def elemwise_wrap(beforeloop, inloop, afterloop, loop_vars, writable_loop_vars,
contiguous_cleanup = ""
all_loop_vars = loop_vars + writable_loop_vars
v1 = (loop_vars + writable_loop_vars)[0]
template = dict(
v1 = (loop_vars + writable_loop_vars)[0],
v1 = v1,
check_init = check_init % dict(loop_var = v1),
check = "\n".join([check % dict(loop_var = loop_var) for loop_var in loop_vars + writable_loop_vars if loop_var is not v1]),
beforeloop = beforeloop,
general_loop = elemwise_loopcode(
inloop,
......@@ -423,6 +446,10 @@ def elemwise_wrap(beforeloop, inloop, afterloop, loop_vars, writable_loop_vars,
afterloop = afterloop)
code = """
{
%(check_init)s
%(check)s
}
npy_intp __elemwise_size = PyArray_SIZE(%(v1)s);
%(beforeloop)s
bool all_c_contiguous = 1;
......@@ -727,6 +754,8 @@ class NumpyR(gof.PythonR):
def wrap_producer(f):
class producer(omega_op):
impl = f
def grad(*args):
return [UNDEFINED] * (len(args) - 1)
producer.__name__ = f.__name__
def ret(dim, dtype = 'float', order = 'C'):
return producer(dim, dtype, order)
......@@ -1473,7 +1502,7 @@ class _testCase_slicing(unittest.TestCase):
try:
err = wa1 + a
except ValueError, e:
self.failUnless(e.message == \
self.failUnless(str(e) == \
'The dimensions of the inputs do not match.',
'Wrong ValueError')
return
......@@ -1489,7 +1518,7 @@ class _testCase_slicing(unittest.TestCase):
try:
wa1 = wrap(a)[:]
except IndexError, e:
self.failUnless(e.message == "0-d arrays can't be indexed.")
self.failUnless(str(e) == "0-d arrays can't be indexed.")
return
self.fail()
def test_getslice_1d_all(self):
......@@ -1505,7 +1534,7 @@ class _testCase_slicing(unittest.TestCase):
try:
wa1[2] = 2.5
except TypeError, e:
self.failUnless(e.message == "'NumpyR' object does not support item assignment")
self.failUnless("object does not support item assignment" in str(e))
return
self.fail()
def test_getslice_3d_all(self):
......
......@@ -236,7 +236,7 @@ class _testCase (unittest.TestCase):
gb(a)
self.assertEqual('should have raised',0)
except Exception, e:
self.assertEqual(e.message, 'Grad.__call__ only makes sense after a bprop')
self.assertEqual(str(e), 'Grad.__call__ only makes sense after a bprop')
return
self.assertEqual('should have caught, returned',0)
......@@ -264,7 +264,7 @@ class _testCase (unittest.TestCase):
gb.bprop()
self.assertEqual('should have raised',0)
except AttributeError, e:
self.assertEqual(e.message, "Keyword instance has no attribute 'shape'")
self.assertEqual(str(e), "Keyword instance has no attribute 'shape'")
return
self.assertEqual("Should have been error", 0)
......@@ -278,7 +278,7 @@ class _testCase (unittest.TestCase):
gc.bprop()
self.assertEqual('should have raised',0)
except AttributeError, e:
self.assertEqual(e.message, "Keyword instance has no attribute 'shape'")
self.assertEqual(str(e), "Keyword instance has no attribute 'shape'")
return
self.assertEqual("Should have been error", 0)
......@@ -307,7 +307,7 @@ class _testCase (unittest.TestCase):
g.bprop()
self.assertEqual('should have raised')
except Exception, e:
self.assertEqual(e.message, 'bprop has already been done. Consider calling with maybe_redo=True.')
self.assertEqual(str(e), 'bprop has already been done. Consider calling with maybe_redo=True.')
return
self.assertEqual('should have caught')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论