提交 a0dadf5d authored 作者: Gijs van Tulder's avatar Gijs van Tulder

C implementation for downsample average

上级 19515020
......@@ -317,7 +317,7 @@ class DownsampleFactorMax(Op):
return ['<algorithm>']
def c_code(self, node, name, inp, out, sub):
if self.mode != 'max':
if self.mode not in ('max', 'average_exc_pad', 'average_inc_pad'):
raise theano.gof.utils.MethodNotDefined()
x, = inp
z, = out
......@@ -326,7 +326,7 @@ class DownsampleFactorMax(Op):
ds0, ds1 = self.ds
st0, st1 = self.st
pd0, pd1 = self.padding
return """
ccode = """
int typenum = PyArray_ObjectType((PyObject*)%(x)s, 0);
int z_r, z_c; // shape of the output
int r, c; // shape of the padded_input
......@@ -409,7 +409,7 @@ class DownsampleFactorMax(Op):
// used for indexing a pool region inside the input
int r_st, r_end, c_st, c_end;
dtype_%(x)s maximum; // temp var for maximum value in a region
dtype_%(x)s collector; // temp var for the value in a region
if (z_r && z_c)
{
for(int b=0; b<PyArray_DIMS(%(x)s)[0]; b++){
......@@ -445,28 +445,55 @@ class DownsampleFactorMax(Op):
{
c_end = c_end > c ? c : c_end;
}
"""
if self.mode == 'max':
ccode += """
// use the first element as the initial value of maximum
maximum = ((dtype_%(x)s*)(PyArray_GETPTR4(%(x)s,b,k,r_st,c_st)))[0];
collector = ((dtype_%(x)s*)(PyArray_GETPTR4(%(x)s,b,k,r_st,c_st)))[0];
// go through the pooled region in the unpadded input
for(int m=r_st; m<r_end; m++)
{
for(int n=c_st; n<c_end; n++)
{
dtype_%(x)s a = ((dtype_%(x)s*)(PyArray_GETPTR4(%(x)s,b,k,m,n)))[0];
collector = (a > collector) ? a : collector;
}
}
z[0] = collector;
"""
elif self.mode == 'average_exc_pad' or self.mode == 'average_inc_pad':
ccode += """
// initialize the sum at zero
collector = ((dtype_%(x)s)(0));
// go through the pooled region in the unpadded input
for(int m=r_st; m<r_end; m++)
{
for(int n=c_st; n<c_end; n++)
{
dtype_%(x)s a = ((dtype_%(x)s*)(PyArray_GETPTR4(%(x)s,b,k,m,n)))[0];
maximum = (a > maximum) ? a : maximum;
collector += a;
}
}
z[0] = maximum;
"""
if self.mode == 'average_inc_pad' and self.ignore_border:
ccode += """
z[0] = collector / (%(ds0)s * %(ds1)s);
"""
else:
ccode += """
z[0] = collector / ((r_end-r_st)*(c_end-c_st));
"""
ccode += """
}
}
}
}
}
""" % locals()
"""
return ccode % locals()
def c_code_cache_version(self):
return (0, 6)
return (0, 6, 8, 1)
class DownsampleFactorMaxGrad(Op):
__props__ = ('ds', 'ignore_border', 'st', 'padding', 'mode')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论