提交 cdfa4964 authored 作者: Frederic Bastien's avatar Frederic Bastien

Add test for the previous fix and add more check to prevent other crash with bad…

Add test for the previous fix and add more check to prevent other crash with bad neib_shape or neib_step input value.
上级 2ce0646e
......@@ -152,7 +152,7 @@ class Images2Neibs(Op):
grad_undefined(self, 2, neib_step)]
def c_code_cache_version(self):
return (5,)
return (6,)
def perform(self, node, inp, out_):
ten4, neib_shape, neib_step = inp
......@@ -317,8 +317,24 @@ class Images2Neibs(Op):
const npy_intp c = (npy_intp) *(dtype_%(neib_shape)s*) PyArray_GETPTR1(%(neib_shape)s, 0);
const npy_intp d = (npy_intp) *(dtype_%(neib_shape)s*) PyArray_GETPTR1(%(neib_shape)s, 1);
// (step_x,step_y) = neib_step
const npy_intp step_x = (npy_intp) *(dtype_%(neib_step)s*) PyArray_GETPTR1(%(neib_step)s, 0);
const npy_intp step_y = (npy_intp) *(dtype_%(neib_step)s*) PyArray_GETPTR1(%(neib_step)s, 1);
const dtype_%(neib_step)s step_x = *(dtype_%(neib_step)s*) PyArray_GETPTR1(%(neib_step)s, 0);
const dtype_%(neib_step)s step_y = *(dtype_%(neib_step)s*) PyArray_GETPTR1(%(neib_step)s, 1);
if (step_x <=0 || step_y <=0)
{
PyErr_Format(PyExc_ValueError,
"neib_step wrong step ; values <= 0. Got %%d %%d.",
step_x, step_y);
%(fail)s;
}
if (c <=0 || d <=0)
{
PyErr_Format(PyExc_ValueError,
"neib_shape values <= 0. Got %%d %%d.",
c, d);
%(fail)s;
}
if ( "%(mode)s" == "wrap_centered") {
if (c%%2!=1 || d%%2!=1){
......
......@@ -340,6 +340,31 @@ class T_Images2Neibs(unittest_tools.InferShapeTester):
mode=self.mode)
self.assertRaises(TypeError, f, images_val)
def test_can_not_infer_nb_dim(self):
# Was reported in gh-5613. Test that we do not crash
# or that we crash in a few other case found while
# investigating that case
img = T.tensor4('img')
patches = T.nnet.neighbours.images2neibs(img, [16, 16])
extractPatches = theano.function([img], patches)
patsRecovery = T.matrix('patsRecovery')
original_size = T.ivector('original_size')
out = neibs2images(patsRecovery, (16, 16), original_size)
f = theano.function([patsRecovery, original_size], out)
im_val = numpy.ones((1, 3, 320, 320), dtype=numpy.float32)
neibs = extractPatches(im_val)
f(neibs, im_val.shape)
# Wrong number of dimensions
self.assertRaises(ValueError, f, neibs,
(1, 1, 3, 320, 320))
# End up with a step of 0
self.assertRaises(ValueError, f, neibs,
(3, 320, 320, 1))
def speed_neibs(self):
shape = (100, 40, 18, 18)
images = shared(numpy.arange(numpy.prod(shape),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论