提交 f8549b02 authored 作者: Frederic Bastien's avatar Frederic Bastien

Use params for ignore border in Pool op.

上级 d7cd25c5
......@@ -14,7 +14,9 @@ from six.moves import xrange
import six.moves.builtins as builtins
import theano
from theano import gof, OpenMPOp, tensor, Variable, Apply
from theano.gof.params_type import ParamsType
from theano.gradient import DisconnectedType
from theano.scalar import bool as bool_t
def max_pool_2d_same_size(input, patch_size):
......@@ -294,6 +296,7 @@ class Pool(OpenMPOp):
"""
__props__ = ('ignore_border', 'mode', 'ndim')
params_type = ParamsType(ignore_border=bool_t,)
@staticmethod
def out_shape(imgshape, ws=None, ignore_border=False, stride=None, pad=None,
......@@ -508,7 +511,7 @@ class Pool(OpenMPOp):
out = tensor.TensorType(x.dtype, broad)
return gof.Apply(self, [x, ws, stride, pad], [out()])
def perform(self, node, inp, out):
def perform(self, node, inp, out, params):
x, ws, stride, pad = inp
z, = out
nd = self.ndim
......@@ -516,8 +519,8 @@ class Pool(OpenMPOp):
if len(x.shape) < nd:
raise NotImplementedError(
'Pool requires input with {} or more dimensions'.format(nd))
z_shape = self.out_shape(x.shape, ws, self.ignore_border, stride, pad, nd)
if not self.ignore_border:
z_shape = self.out_shape(x.shape, ws, params.ignore_border, stride, pad, nd)
if not params.ignore_border:
assert all(z > 0 for z in z_shape[-nd:])
if (z[0] is None) or (z[0].shape != z_shape):
z[0] = np.empty(z_shape, dtype=x.dtype)
......@@ -617,7 +620,7 @@ class Pool(OpenMPOp):
total_ndim = node.inputs[0].ndim
non_pool_ndim = total_ndim - nd
fail = sub['fail']
ignore_border = int(self.ignore_border)
params = sub['params']
if self.openmp:
# run in parallel over each pooling block
omp_parallel = '#pragma omp parallel for private(r_st, r_end, r_idx, i_idx, o_idx, collector) schedule(static)'
......@@ -661,13 +664,13 @@ class Pool(OpenMPOp):
if (pd[i]>0)
nonzero_padding = 1;
}
if (!%(ignore_border)s && nonzero_padding)
if (!%(params)s->ignore_border && nonzero_padding)
{
PyErr_SetString(PyExc_ValueError,
"padding must be zero when ignore border is False");
%(fail)s;
}
if (%(ignore_border)s)
if (%(params)s->ignore_border)
{
for (int i=0; i<%(nd)s; i++)
{
......@@ -801,13 +804,13 @@ class Pool(OpenMPOp):
r_st[%(i)s] -= pd[%(i)s];
r_end[%(i)s] -= pd[%(i)s];
// handle the case where no padding, ignore border is True
if (%(ignore_border)s)
if (%(params)s->ignore_border)
{
r_end[%(i)s] = r_end[%(i)s] > r[%(i)s] ? r[%(i)s] : r_end[%(i)s];
}
// use the index to find the correct position in the output
o_idx[%(non_pool_ndim)s + %(i)s] = r_idx[%(i)s];
""" % dict(i=i, ignore_border=ignore_border, non_pool_ndim=non_pool_ndim)
""" % dict(i=i, non_pool_ndim=non_pool_ndim, params=sub['params'])
ccode += """
// get a pointer to the correct position in the output
......@@ -907,7 +910,7 @@ class Pool(OpenMPOp):
return ccode % locals()
def c_code_cache_version(self):
return (0, 6, 8, 7, self.openmp)
return (9, self.openmp)
class PoolGrad(OpenMPOp):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论