提交 2ffd049c authored 作者: Frederic Bastien's avatar Frederic Bastien

Use params for MaxPoolRop ignore_border

上级 59bdbde6
......@@ -2071,6 +2071,7 @@ class MaxPoolRop(OpenMPOp):
"""
__props__ = ('ignore_border', 'mode', 'ndim')
params_type = ParamsType(ignore_border=bool_t,)
def __init__(self, ignore_border=False, mode='max', ndim=2, openmp=None):
super(MaxPoolRop, self).__init__(openmp=openmp)
......@@ -2115,7 +2116,7 @@ class MaxPoolRop(OpenMPOp):
out = tensor.TensorType(eval_point.dtype, broad)
return gof.Apply(self, [x, eval_point, ws, stride, pad], [out()])
def perform(self, node, inp, out):
def perform(self, node, inp, out, params):
x, ex, ws, stride, pad = inp
z, = out
nd = self.ndim
......@@ -2123,7 +2124,7 @@ class MaxPoolRop(OpenMPOp):
if len(x.shape) < nd:
raise NotImplementedError(
'Pool requires input with {} or more dimensions'.format(nd))
z_shape = Pool.out_shape(x.shape, ws, self.ignore_border, stride, pad, nd)
z_shape = Pool.out_shape(x.shape, ws, params.ignore_border, stride, pad, nd)
if not self.ignore_border:
assert all(z > 0 for z in z_shape[-nd:])
if (z[0] is None) or (z[0].shape != z_shape):
......@@ -2186,7 +2187,8 @@ class MaxPoolRop(OpenMPOp):
total_ndim = node.inputs[0].ndim
non_pool_ndim = total_ndim - nd
fail = sub['fail']
ignore_border = int(self.ignore_border)
params = sub['params']
if self.openmp:
# run in parallel over each pooling block
omp_parallel = '#pragma omp parallel for private(r_st, r_end, r_idx, i_idx, o_idx, collector, eval_collector) schedule(static)'
......@@ -2235,13 +2237,13 @@ class MaxPoolRop(OpenMPOp):
if (pd[i]>0)
nonzero_padding = 1;
}
if (!%(ignore_border)s && nonzero_padding)
if (!%(params)s->ignore_border && nonzero_padding)
{
PyErr_SetString(PyExc_ValueError,
"padding must be zero when ignore border is False");
%(fail)s;
}
if (%(ignore_border)s)
if (%(params)s->ignore_border)
{
for (int i=0; i<%(nd)s; i++)
{
......@@ -2376,7 +2378,7 @@ class MaxPoolRop(OpenMPOp):
r_st[%(i)s] -= pd[%(i)s];
r_end[%(i)s] -= pd[%(i)s];
// handle the case where no padding, ignore border is True
if (%(ignore_border)s)
if (%(params)s->ignore_border)
{
r_end[%(i)s] = r_end[%(i)s] > r[%(i)s] ? r[%(i)s] : r_end[%(i)s];
}
......@@ -2451,4 +2453,4 @@ class MaxPoolRop(OpenMPOp):
return ccode % locals()
def c_code_cache_version(self):
return (0, self.openmp)
return (1, self.openmp)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论