提交 a22086b0 authored 作者: Frederic's avatar Frederic 提交者: David Warde-Farley

Add check that shape gived to ConvOp are the one reveiced in input.

上级 bbd9e02f
...@@ -858,7 +858,7 @@ class ConvOp(Op): ...@@ -858,7 +858,7 @@ class ConvOp(Op):
return ['<numpy/noprefix.h>', '<iostream>', '<sstream>' ] return ['<numpy/noprefix.h>', '<iostream>', '<sstream>' ]
def c_code_cache_version(self): def c_code_cache_version(self):
return (4) return (5)
def c_support_code(self): def c_support_code(self):
return """ return """
...@@ -942,6 +942,76 @@ using namespace std; ...@@ -942,6 +942,76 @@ using namespace std;
d["all_shape"]="1" d["all_shape"]="1"
d["dim_zz_const"]="const" d["dim_zz_const"]="const"
d["dim_zz_affect"]="" d["dim_zz_affect"]=""
d["assert_size"]="""
// Check the batch size and the number of kernel (sometimes constant in the graph)
if(img2d_dim[0] != %(self_bsize)s!=0){
PyErr_Format(PyExc_ValueError,
"the batch size in the image(%%ld) at run time is different then at build time(%%ld) for the ConvOp.",
(long)img2d_dim[0], (long)%(self_bsize)s);
%(fail)s;
}
if(kerns_dim[0] != %(self_nkern)s!=0){
PyErr_Format(PyExc_ValueError,
"the number of kernel in the filter(%%ld) at run time is different then at build time(%%ld) for the ConvOp.",
(long)kerns_dim[0], (long)%(self_nkern)s);
%(fail)s;
}
// Check the size of the image (sometimes constant in the graph)
if(img2d_dim[1] != %(self_imshp0)s){
PyErr_Format(PyExc_ValueError,
"the stack size in the image(%%ld) at run time is different then at build time(%%ld) for the ConvOp.",
(long)img2d_dim[1], (long)%(self_imshp0)s);
%(fail)s;
}
if(img2d_dim[2] != %(self_imshp1)s){
PyErr_Format(PyExc_ValueError,
"the number of row in the image(%%ld) at run time is different then at build time(%%ld) for the ConvOp.",
(long)img2d_dim[2], (long)%(self_imshp1)s);
%(fail)s;
}
if(img2d_dim[3] != %(self_imshp2)s){
PyErr_Format(PyExc_ValueError,
"the number of col in the image(%%ld) at run time is different then at build time(%%ld) for the ConvOp.",
(long)img2d_dim[3], (long)%(self_imshp2)s);
%(fail)s;
}
// Check the size of the output (sometimes constant in the graph)
if(dim_zz[0] != %(self_outshp0)s!=0){
PyErr_Format(PyExc_ValueError,
"the precomputed number of row in the output(%%ld) at run time is different then at build time(%%ld) for the ConvOp.",
(long)dim_zz[0], (long)%(self_outshp0)s);
%(fail)s;
}
if(dim_zz[1] != %(self_outshp1)s!=0){
PyErr_Format(PyExc_ValueError,
"the precomputed number of col in the output(%%ld) at run time is different then at build time(%%ld) for the ConvOp.",
(long)dim_zz[1], (long)%(self_outshp1)s);
%(fail)s;
}
// Check the size of the filter (sometimes constant in the graph)
if(kerns_dim[1] %% %(self_imshp0)s!=0){
PyErr_Format(PyExc_ValueError,
"the stack size in the filter(%%ld) at run time is different then at build time(%%ld) for the ConvOp.",
(long)kerns_dim[1], (long)%(self_imshp0)s);
%(fail)s;
}
if(kerns_dim[2] %% %(self_kshp0)s!=0){
PyErr_Format(PyExc_ValueError,
"the number of row in the filter(%%ld) at run time is different then at build time(%%ld) for the ConvOp.",
(long)kerns_dim[2], (long)%(self_kshp0)s);
%(fail)s;
}
if(kerns_dim[3] %% %(self_kshp1)s!=0){
PyErr_Format(PyExc_ValueError,
"the number of columns in the filter(%%ld) at run time is different then at build time(%%ld) for the ConvOp.",
(long)kerns_dim[3], (long)%(self_kshp1)s);
%(fail)s;
}
"""%(locals())
else: else:
d["self_bsize"]="%(img2d)s->dimensions[0]"%d d["self_bsize"]="%(img2d)s->dimensions[0]"%d
d["self_nkern"]="%(filtersflipped)s->dimensions[0]"%d d["self_nkern"]="%(filtersflipped)s->dimensions[0]"%d
...@@ -964,6 +1034,7 @@ using namespace std; ...@@ -964,6 +1034,7 @@ using namespace std;
dim_zz[1] = (int)ceil((dim_im[1]-dim_ker1+1)/float(%(self_dy)s)); dim_zz[1] = (int)ceil((dim_im[1]-dim_ker1+1)/float(%(self_dy)s));
} }
"""% d """% d
d["assert_size"]=""
if self.kshp_logical_top_aligned: if self.kshp_logical_top_aligned:
d["self_kshp_logical_offset_r"] = 0 d["self_kshp_logical_offset_r"] = 0
...@@ -1072,6 +1143,8 @@ if(%(filtersflipped)s->nd==3){ ...@@ -1072,6 +1143,8 @@ if(%(filtersflipped)s->nd==3){
%(fail)s; %(fail)s;
} }
%(assert_size)s
img2d = PyArray_Newshape(%(img2d)s,&img2d_shape, PyArray_CORDER); img2d = PyArray_Newshape(%(img2d)s,&img2d_shape, PyArray_CORDER);
img2d_arr = (PyArrayObject*)img2d; img2d_arr = (PyArrayObject*)img2d;
if ((img2d_arr->strides[3] != (npy_intp)sizeof(%(type)s)) if ((img2d_arr->strides[3] != (npy_intp)sizeof(%(type)s))
...@@ -1355,6 +1428,8 @@ if ((!%(z)s) ...@@ -1355,6 +1428,8 @@ if ((!%(z)s)
PyArray_FILLWBYTE((PyObject*)%(z)s,0); PyArray_FILLWBYTE((PyObject*)%(z)s,0);
} }
%(assert_size)s
int Os[2]; int Os[2];
Os[0] = dim_im[0]-dim_ker0+1; Os[0] = dim_im[0]-dim_ker0+1;
Os[1] = dim_im[1]-dim_ker1+1; Os[1] = dim_im[1]-dim_ker1+1;
...@@ -1555,18 +1630,7 @@ if(%(filtersflipped)s->nd==3){ ...@@ -1555,18 +1630,7 @@ if(%(filtersflipped)s->nd==3){
%(fail)s; %(fail)s;
} }
if(img2d_dim[0] %% %(self_bsize)s!=0){ %(assert_size)s
PyErr_Format(PyExc_ValueError,
"the batch size of the image(%%ld) must be a multiple of the bsize value at ConvOp construction(%%ld).",
(long)img2d_dim[0],(long)%(self_bsize)s);
%(fail)s;
}
if(kerns_dim[0] %% %(self_nkern)s!=0){
PyErr_Format(PyExc_ValueError,
"the number of kernel(%%ld) must be a multiple of the nkern value at ConvOp construction(%%ld).",
(long)kerns_dim[0], (long)%(self_nkern)s);
%(fail)s;
}
img2d = PyArray_Newshape(%(img2d)s,&img2d_shape, PyArray_CORDER); img2d = PyArray_Newshape(%(img2d)s,&img2d_shape, PyArray_CORDER);
img2d_arr = (PyArrayObject*)img2d; img2d_arr = (PyArrayObject*)img2d;
...@@ -1799,18 +1863,7 @@ if(%(filtersflipped)s->nd==3){ ...@@ -1799,18 +1863,7 @@ if(%(filtersflipped)s->nd==3){
%(fail)s; %(fail)s;
} }
if(img2d_dim[0] != %(self_bsize)s){ %(assert_size)s
PyErr_Format(PyExc_ValueError,
"the batch size of the image(%%ld) must be a multiple of the bsize value at ConvOp construction(%%ld).",
(long)img2d_dim[0],(long)%(self_bsize)s);
%(fail)s;
}
if(kerns_dim[0] != %(self_nkern)s){
PyErr_Format(PyExc_ValueError,
"the number of kernel(%%ld) must be a multiple of the nkern value at ConvOp construction(%%ld).",
(long)kerns_dim[0], (long)%(self_nkern)s);
%(fail)s;
}
img2d = PyArray_Newshape(%(img2d)s,&img2d_shape, PyArray_CORDER); img2d = PyArray_Newshape(%(img2d)s,&img2d_shape, PyArray_CORDER);
img2d_arr = (PyArrayObject*)img2d; img2d_arr = (PyArrayObject*)img2d;
......
...@@ -220,9 +220,58 @@ class TestConv2D(unittest.TestCase): ...@@ -220,9 +220,58 @@ class TestConv2D(unittest.TestCase):
""" """
Tests scenario where filter_shape[1] != input_shape[1] Tests scenario where filter_shape[1] != input_shape[1]
""" """
def f(): self.assertRaises(AssertionError, self.validate, (3,2,8,8), (4,3,5,5),
self.validate((3,2,8,8), (4,3,5,5), 'valid') 'valid')
self.assertRaises(AssertionError, f) def test_invalid_input_shape(self):
"""
Tests that when the shape gived at build time is not the same as
run time we raise an error
"""
for unroll_batch in [None, 1, 3]:
for unroll_kern in [None, 2, 4]:
for unroll_patch in [None, True, False]:
for mode in ['valid', 'full']:
self.assertRaises(ValueError, self.validate, (3,2,8,8), (4,2,5,5),
mode, N_image_shape = (2,2,8,8),
unroll_batch=unroll_batch,
unroll_kern=unroll_kern,
unroll_patch=unroll_patch)
self.assertRaises(ValueError, self.validate, (3,2,8,8), (4,2,5,5),
mode, N_image_shape = (3,1,8,8),
unroll_batch=unroll_batch,
unroll_kern=unroll_kern,
unroll_patch=unroll_patch)
self.assertRaises(ValueError, self.validate, (3,2,8,8), (4,2,5,5),
mode, N_image_shape = (3,2,7,8),
unroll_batch=unroll_batch,
unroll_kern=unroll_kern,
unroll_patch=unroll_patch)
self.assertRaises(ValueError, self.validate, (3,2,8,8), (4,2,5,5),
mode, N_image_shape = (3,2,8,7),
unroll_batch=unroll_batch,
unroll_kern=unroll_kern,
unroll_patch=unroll_patch)
self.assertRaises(ValueError, self.validate, (3,2,8,8), (4,2,5,5),
mode, N_filter_shape = (3,2,5,5),
unroll_batch=unroll_batch,
unroll_kern=unroll_kern,
unroll_patch=unroll_patch)
self.assertRaises(ValueError, self.validate, (3,2,8,8), (4,2,5,5),
mode, N_filter_shape = (4,1,5,5),
unroll_batch=unroll_batch,
unroll_kern=unroll_kern,
unroll_patch=unroll_patch)
self.assertRaises(ValueError, self.validate, (3,2,8,8), (4,2,5,5),
mode, N_filter_shape = (4,2,6,5),
unroll_batch=unroll_batch,
unroll_kern=unroll_kern,
unroll_patch=unroll_patch)
self.assertRaises(ValueError, self.validate, (3,2,8,8), (4,2,5,5),
mode, N_filter_shape = (4,2,5,6),
unroll_batch=unroll_batch,
unroll_kern=unroll_kern,
unroll_patch=unroll_patch)
def test_missing_info(self): def test_missing_info(self):
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论