提交 b3abc664 authored 作者: Li Yao's avatar Li Yao

maxpool grad grad c code

上级 efdaacec
...@@ -694,7 +694,7 @@ class DownsampleFactorMaxGrad(Op): ...@@ -694,7 +694,7 @@ class DownsampleFactorMaxGrad(Op):
return (0, 7) return (0, 7)
class DownsampleFactorMaxGradGrad(Op): class DownsampleFactorMaxGradGrad(Op):
__props__ = ('ds', 'ignore_border', 'st', 'padding') __props__ = ('ds', 'ignore_border', 'st', 'padding', 'mode')
@staticmethod @staticmethod
def out_shape(imgshape, ds, ignore_border=False, st=None, padding=(0, 0)): def out_shape(imgshape, ds, ignore_border=False, st=None, padding=(0, 0)):
...@@ -773,7 +773,7 @@ class DownsampleFactorMaxGradGrad(Op): ...@@ -773,7 +773,7 @@ class DownsampleFactorMaxGradGrad(Op):
rval = list(imgshape[:-2]) + [nr, nc] rval = list(imgshape[:-2]) + [nr, nc]
return rval return rval
def __init__(self, ds, ignore_border, st=None, padding=(0,0)): def __init__(self, ds, ignore_border, st=None, padding=(0,0), mode='max'):
self.ds = tuple(ds) self.ds = tuple(ds)
if not all([isinstance(d, int) for d in ds]): if not all([isinstance(d, int) for d in ds]):
raise ValueError( raise ValueError(
...@@ -791,7 +791,7 @@ class DownsampleFactorMaxGradGrad(Op): ...@@ -791,7 +791,7 @@ class DownsampleFactorMaxGradGrad(Op):
if self.padding[0] >= self.ds[0] or self.padding[1] >= self.ds[1]: if self.padding[0] >= self.ds[0] or self.padding[1] >= self.ds[1]:
raise NotImplementedError( raise NotImplementedError(
'padding_h and padding_w must be smaller than strides') 'padding_h and padding_w must be smaller than strides')
self.mode = mode
def make_node(self, x, maxout, gz): def make_node(self, x, maxout, gz):
# make_node should only be called by the grad function of # make_node should only be called by the grad function of
...@@ -806,6 +806,8 @@ class DownsampleFactorMaxGradGrad(Op): ...@@ -806,6 +806,8 @@ class DownsampleFactorMaxGradGrad(Op):
return Apply(self, [x, maxout, gz], [x.type()]) return Apply(self, [x, maxout, gz], [x.type()])
def perform(self, node, inp, out): def perform(self, node, inp, out):
if self.mode != 'max':
raise theano.gof.utils.MethodNotDefined()
x, maxout, ggx = inp x, maxout, ggx = inp
z, = out z, = out
if len(x.shape) != 4: if len(x.shape) != 4:
...@@ -815,7 +817,7 @@ class DownsampleFactorMaxGradGrad(Op): ...@@ -815,7 +817,7 @@ class DownsampleFactorMaxGradGrad(Op):
self.st, self.padding) self.st, self.padding)
if (z[0] is None) or (z[0].shape != z_shape): if (z[0] is None) or (z[0].shape != z_shape):
z[0] = numpy.zeros(z_shape, dtype=x.dtype) z[0] = numpy.zeros(z_shape, dtype=x.dtype)
ggz = z[0] ggz = z[0] # grad wrt maxout_grad has the same shape as maxout
# number of pooling output rows # number of pooling output rows
pr = ggz.shape[-2] pr = ggz.shape[-2]
# number of pooling output cols # number of pooling output cols
...@@ -855,3 +857,89 @@ class DownsampleFactorMaxGradGrad(Op): ...@@ -855,3 +857,89 @@ class DownsampleFactorMaxGradGrad(Op):
def infer_shape(self, node, in_shapes): def infer_shape(self, node, in_shapes):
return [in_shapes[0]] return [in_shapes[0]]
def c_code(self, node, name, inp, out, sub):
if self.mode != 'max':
raise theano.gof.utils.MethodNotDefined()
x, maxout, ggx = inp
z, = out # the grad of grad
fail = sub['fail']
ignore_border = int(self.ignore_border)
ds0, ds1 = self.ds
st0, st1 = self.st
pd0, pd1 = self.padding
return """
int z_typenum = PyArray_ObjectType((PyObject*)%(maxout)s, 0);
int z_r, z_c;
z_r = PyArray_DIMS(%(maxout)s)[2];
z_c = PyArray_DIMS(%(maxout)s)[3];
int r, c; // shape of the padded_input
r = PyArray_DIMS(%(x)s)[2];
c = PyArray_DIMS(%(x)s)[3];
r += %(pd0)s * 2;
c += %(pd1)s * 2;
// allocating memory for output
if ((!%(z)s)
|| !PyArray_ISCONTIGUOUS(%(z)s)
|| *PyArray_DIMS(%(z)s)!=4
||(PyArray_DIMS(%(z)s)[0] != PyArray_DIMS(%(maxout)s)[0])
||(PyArray_DIMS(%(z)s)[1] != PyArray_DIMS(%(maxout)s)[1])
||(PyArray_DIMS(%(z)s)[2] != PyArray_DIMS(%(maxout)s)[2])
||(PyArray_DIMS(%(z)s)[3] != PyArray_DIMS(%(maxout)s)[3])
)
{
Py_XDECREF(%(z)s);
%(z)s = (PyArrayObject*) PyArray_ZEROS(4, PyArray_DIMS(%(maxout)s), z_typenum,0);
}
else {
PyArray_FILLWBYTE(%(z)s, 0);
}
dtype_%(maxout)s maximum; // temp var for maximum value in a region
int r_st, r_end, c_st, c_end; // used to index into the input img x
for(int b=0; b<PyArray_DIMS(%(x)s)[0]; b++){
for(int k=0; k<PyArray_DIMS(%(x)s)[1]; k++){
for(int i=0; i< z_r; i++){
r_st = i * %(st0)s;
r_end = r_st + %(ds0)s;
// skip the padding
r_st = r_st < %(pd0)s ? %(pd0)s : r_st;
r_end = r_end > (r - %(pd0)s) ? r - %(pd0)s : r_end;
// from padded_img space to img space
r_st -= %(pd0)s;
r_end -= %(pd0)s;
for(int j=0; j<z_c; j++){
c_st = j * %(st1)s;
c_end = c_st + %(ds1)s;
// skip the padding
c_st = c_st < %(pd1)s ? %(pd1)s : c_st;
c_end = c_end > (c - %(pd1)s) ? c - %(pd1)s : c_end;
// from padding_img space into img space
c_st -= %(pd1)s;
c_end -= %(pd1)s;
// the maximum value
maximum = ((dtype_%(maxout)s*)(PyArray_GETPTR4(%(maxout)s,b,k,i,j)))[0];
// z at this position
dtype_%(z)s * z = ((dtype_%(z)s*)(PyArray_GETPTR4(%(z)s, b, k, i, j)));
// go through the pooled region in the unpadded input
for(int m=r_st; m<r_end; m++)
{
for(int n=c_st; n<c_end; n++)
{
dtype_%(x)s a = ((dtype_%(x)s*)(PyArray_GETPTR4(%(x)s,b,k,m,n)))[0];
dtype_%(ggx)s * ggx = (
(dtype_%(ggx)s*)(PyArray_GETPTR4(%(ggx)s, b, k, m, n)));
if (a == maximum){
z[0] += ggx[0];
}
}
}
}
}
}
}
"""%locals()
def c_code_cache_version(self):
return (0,1)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论