提交 20f280ff authored 作者: Olivier Breuleux's avatar Olivier Breuleux

fixed elemwise compiler to check the dimensionality of the inputs

上级 0a02fb52
...@@ -390,6 +390,26 @@ def elemwise_loopcode(loopcode, init_template, next_template, acquire_template, ...@@ -390,6 +390,26 @@ def elemwise_loopcode(loopcode, init_template, next_template, acquire_template,
def elemwise_wrap(beforeloop, inloop, afterloop, loop_vars, writable_loop_vars, aliases): def elemwise_wrap(beforeloop, inloop, afterloop, loop_vars, writable_loop_vars, aliases):
check_init = """
npy_intp nd = %(loop_var)s->nd;
npy_intp* dims = %(loop_var)s->dimensions;
npy_intp* dims2;
"""
check = """
if (%(loop_var)s->nd != nd) {
PyErr_SetString(PyExc_ValueError, \"The number of dimensions of the inputs do not match.\");
}
dims2 = %(loop_var)s->dimensions;
for (int i = 0; i < nd; i++) {
if (dims2[i] != dims[i]) {
PyErr_SetString(PyExc_ValueError, \"The dimensions of the inputs do not match.\");
return 1;
}
}
"""
general_init = "PyArrayIterObject* %(loop_var)s_iter = (PyArrayIterObject*)PyArray_IterNew((PyObject*)%(loop_var)s);\n" general_init = "PyArrayIterObject* %(loop_var)s_iter = (PyArrayIterObject*)PyArray_IterNew((PyObject*)%(loop_var)s);\n"
# "if (%(loop_var)s_iter == NULL) {\n" \ # "if (%(loop_var)s_iter == NULL) {\n" \
# " PyErr_SetString(PyExc_ValueError, \"Could not make an iterator over variable %(loop_var)s.\");\n" \ # " PyErr_SetString(PyExc_ValueError, \"Could not make an iterator over variable %(loop_var)s.\");\n" \
...@@ -405,8 +425,11 @@ def elemwise_wrap(beforeloop, inloop, afterloop, loop_vars, writable_loop_vars, ...@@ -405,8 +425,11 @@ def elemwise_wrap(beforeloop, inloop, afterloop, loop_vars, writable_loop_vars,
contiguous_cleanup = "" contiguous_cleanup = ""
all_loop_vars = loop_vars + writable_loop_vars all_loop_vars = loop_vars + writable_loop_vars
v1 = (loop_vars + writable_loop_vars)[0]
template = dict( template = dict(
v1 = (loop_vars + writable_loop_vars)[0], v1 = v1,
check_init = check_init % dict(loop_var = v1),
check = "\n".join([check % dict(loop_var = loop_var) for loop_var in loop_vars + writable_loop_vars if loop_var is not v1]),
beforeloop = beforeloop, beforeloop = beforeloop,
general_loop = elemwise_loopcode( general_loop = elemwise_loopcode(
inloop, inloop,
...@@ -423,6 +446,10 @@ def elemwise_wrap(beforeloop, inloop, afterloop, loop_vars, writable_loop_vars, ...@@ -423,6 +446,10 @@ def elemwise_wrap(beforeloop, inloop, afterloop, loop_vars, writable_loop_vars,
afterloop = afterloop) afterloop = afterloop)
code = """ code = """
{
%(check_init)s
%(check)s
}
npy_intp __elemwise_size = PyArray_SIZE(%(v1)s); npy_intp __elemwise_size = PyArray_SIZE(%(v1)s);
%(beforeloop)s %(beforeloop)s
bool all_c_contiguous = 1; bool all_c_contiguous = 1;
...@@ -727,6 +754,8 @@ class NumpyR(gof.PythonR): ...@@ -727,6 +754,8 @@ class NumpyR(gof.PythonR):
def wrap_producer(f): def wrap_producer(f):
class producer(omega_op): class producer(omega_op):
impl = f impl = f
def grad(*args):
return [UNDEFINED] * (len(args) - 1)
producer.__name__ = f.__name__ producer.__name__ = f.__name__
def ret(dim, dtype = 'float', order = 'C'): def ret(dim, dtype = 'float', order = 'C'):
return producer(dim, dtype, order) return producer(dim, dtype, order)
...@@ -1473,7 +1502,7 @@ class _testCase_slicing(unittest.TestCase): ...@@ -1473,7 +1502,7 @@ class _testCase_slicing(unittest.TestCase):
try: try:
err = wa1 + a err = wa1 + a
except ValueError, e: except ValueError, e:
self.failUnless(e.message == \ self.failUnless(str(e) == \
'The dimensions of the inputs do not match.', 'The dimensions of the inputs do not match.',
'Wrong ValueError') 'Wrong ValueError')
return return
...@@ -1489,7 +1518,7 @@ class _testCase_slicing(unittest.TestCase): ...@@ -1489,7 +1518,7 @@ class _testCase_slicing(unittest.TestCase):
try: try:
wa1 = wrap(a)[:] wa1 = wrap(a)[:]
except IndexError, e: except IndexError, e:
self.failUnless(e.message == "0-d arrays can't be indexed.") self.failUnless(str(e) == "0-d arrays can't be indexed.")
return return
self.fail() self.fail()
def test_getslice_1d_all(self): def test_getslice_1d_all(self):
...@@ -1505,7 +1534,7 @@ class _testCase_slicing(unittest.TestCase): ...@@ -1505,7 +1534,7 @@ class _testCase_slicing(unittest.TestCase):
try: try:
wa1[2] = 2.5 wa1[2] = 2.5
except TypeError, e: except TypeError, e:
self.failUnless(e.message == "'NumpyR' object does not support item assignment") self.failUnless("object does not support item assignment" in str(e))
return return
self.fail() self.fail()
def test_getslice_3d_all(self): def test_getslice_3d_all(self):
......
...@@ -236,7 +236,7 @@ class _testCase (unittest.TestCase): ...@@ -236,7 +236,7 @@ class _testCase (unittest.TestCase):
gb(a) gb(a)
self.assertEqual('should have raised',0) self.assertEqual('should have raised',0)
except Exception, e: except Exception, e:
self.assertEqual(e.message, 'Grad.__call__ only makes sense after a bprop') self.assertEqual(str(e), 'Grad.__call__ only makes sense after a bprop')
return return
self.assertEqual('should have caught, returned',0) self.assertEqual('should have caught, returned',0)
...@@ -264,7 +264,7 @@ class _testCase (unittest.TestCase): ...@@ -264,7 +264,7 @@ class _testCase (unittest.TestCase):
gb.bprop() gb.bprop()
self.assertEqual('should have raised',0) self.assertEqual('should have raised',0)
except AttributeError, e: except AttributeError, e:
self.assertEqual(e.message, "Keyword instance has no attribute 'shape'") self.assertEqual(str(e), "Keyword instance has no attribute 'shape'")
return return
self.assertEqual("Should have been error", 0) self.assertEqual("Should have been error", 0)
...@@ -278,7 +278,7 @@ class _testCase (unittest.TestCase): ...@@ -278,7 +278,7 @@ class _testCase (unittest.TestCase):
gc.bprop() gc.bprop()
self.assertEqual('should have raised',0) self.assertEqual('should have raised',0)
except AttributeError, e: except AttributeError, e:
self.assertEqual(e.message, "Keyword instance has no attribute 'shape'") self.assertEqual(str(e), "Keyword instance has no attribute 'shape'")
return return
self.assertEqual("Should have been error", 0) self.assertEqual("Should have been error", 0)
...@@ -307,7 +307,7 @@ class _testCase (unittest.TestCase): ...@@ -307,7 +307,7 @@ class _testCase (unittest.TestCase):
g.bprop() g.bprop()
self.assertEqual('should have raised') self.assertEqual('should have raised')
except Exception, e: except Exception, e:
self.assertEqual(e.message, 'bprop has already been done. Consider calling with maybe_redo=True.') self.assertEqual(str(e), 'bprop has already been done. Consider calling with maybe_redo=True.')
return return
self.assertEqual('should have caught') self.assertEqual('should have caught')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论