提交 f8549b02 authored 作者: Frederic Bastien's avatar Frederic Bastien

Use params for ignore border in Pool op.

上级 d7cd25c5
...@@ -14,7 +14,9 @@ from six.moves import xrange ...@@ -14,7 +14,9 @@ from six.moves import xrange
import six.moves.builtins as builtins import six.moves.builtins as builtins
import theano import theano
from theano import gof, OpenMPOp, tensor, Variable, Apply from theano import gof, OpenMPOp, tensor, Variable, Apply
from theano.gof.params_type import ParamsType
from theano.gradient import DisconnectedType from theano.gradient import DisconnectedType
from theano.scalar import bool as bool_t
def max_pool_2d_same_size(input, patch_size): def max_pool_2d_same_size(input, patch_size):
...@@ -294,6 +296,7 @@ class Pool(OpenMPOp): ...@@ -294,6 +296,7 @@ class Pool(OpenMPOp):
""" """
__props__ = ('ignore_border', 'mode', 'ndim') __props__ = ('ignore_border', 'mode', 'ndim')
params_type = ParamsType(ignore_border=bool_t,)
@staticmethod @staticmethod
def out_shape(imgshape, ws=None, ignore_border=False, stride=None, pad=None, def out_shape(imgshape, ws=None, ignore_border=False, stride=None, pad=None,
...@@ -508,7 +511,7 @@ class Pool(OpenMPOp): ...@@ -508,7 +511,7 @@ class Pool(OpenMPOp):
out = tensor.TensorType(x.dtype, broad) out = tensor.TensorType(x.dtype, broad)
return gof.Apply(self, [x, ws, stride, pad], [out()]) return gof.Apply(self, [x, ws, stride, pad], [out()])
def perform(self, node, inp, out): def perform(self, node, inp, out, params):
x, ws, stride, pad = inp x, ws, stride, pad = inp
z, = out z, = out
nd = self.ndim nd = self.ndim
...@@ -516,8 +519,8 @@ class Pool(OpenMPOp): ...@@ -516,8 +519,8 @@ class Pool(OpenMPOp):
if len(x.shape) < nd: if len(x.shape) < nd:
raise NotImplementedError( raise NotImplementedError(
'Pool requires input with {} or more dimensions'.format(nd)) 'Pool requires input with {} or more dimensions'.format(nd))
z_shape = self.out_shape(x.shape, ws, self.ignore_border, stride, pad, nd) z_shape = self.out_shape(x.shape, ws, params.ignore_border, stride, pad, nd)
if not self.ignore_border: if not params.ignore_border:
assert all(z > 0 for z in z_shape[-nd:]) assert all(z > 0 for z in z_shape[-nd:])
if (z[0] is None) or (z[0].shape != z_shape): if (z[0] is None) or (z[0].shape != z_shape):
z[0] = np.empty(z_shape, dtype=x.dtype) z[0] = np.empty(z_shape, dtype=x.dtype)
...@@ -617,7 +620,7 @@ class Pool(OpenMPOp): ...@@ -617,7 +620,7 @@ class Pool(OpenMPOp):
total_ndim = node.inputs[0].ndim total_ndim = node.inputs[0].ndim
non_pool_ndim = total_ndim - nd non_pool_ndim = total_ndim - nd
fail = sub['fail'] fail = sub['fail']
ignore_border = int(self.ignore_border) params = sub['params']
if self.openmp: if self.openmp:
# run in parallel over each pooling block # run in parallel over each pooling block
omp_parallel = '#pragma omp parallel for private(r_st, r_end, r_idx, i_idx, o_idx, collector) schedule(static)' omp_parallel = '#pragma omp parallel for private(r_st, r_end, r_idx, i_idx, o_idx, collector) schedule(static)'
...@@ -661,13 +664,13 @@ class Pool(OpenMPOp): ...@@ -661,13 +664,13 @@ class Pool(OpenMPOp):
if (pd[i]>0) if (pd[i]>0)
nonzero_padding = 1; nonzero_padding = 1;
} }
if (!%(ignore_border)s && nonzero_padding) if (!%(params)s->ignore_border && nonzero_padding)
{ {
PyErr_SetString(PyExc_ValueError, PyErr_SetString(PyExc_ValueError,
"padding must be zero when ignore border is False"); "padding must be zero when ignore border is False");
%(fail)s; %(fail)s;
} }
if (%(ignore_border)s) if (%(params)s->ignore_border)
{ {
for (int i=0; i<%(nd)s; i++) for (int i=0; i<%(nd)s; i++)
{ {
...@@ -801,13 +804,13 @@ class Pool(OpenMPOp): ...@@ -801,13 +804,13 @@ class Pool(OpenMPOp):
r_st[%(i)s] -= pd[%(i)s]; r_st[%(i)s] -= pd[%(i)s];
r_end[%(i)s] -= pd[%(i)s]; r_end[%(i)s] -= pd[%(i)s];
// handle the case where no padding, ignore border is True // handle the case where no padding, ignore border is True
if (%(ignore_border)s) if (%(params)s->ignore_border)
{ {
r_end[%(i)s] = r_end[%(i)s] > r[%(i)s] ? r[%(i)s] : r_end[%(i)s]; r_end[%(i)s] = r_end[%(i)s] > r[%(i)s] ? r[%(i)s] : r_end[%(i)s];
} }
// use the index to find the correct position in the output // use the index to find the correct position in the output
o_idx[%(non_pool_ndim)s + %(i)s] = r_idx[%(i)s]; o_idx[%(non_pool_ndim)s + %(i)s] = r_idx[%(i)s];
""" % dict(i=i, ignore_border=ignore_border, non_pool_ndim=non_pool_ndim) """ % dict(i=i, non_pool_ndim=non_pool_ndim, params=sub['params'])
ccode += """ ccode += """
// get a pointer to the correct position in the output // get a pointer to the correct position in the output
...@@ -907,7 +910,7 @@ class Pool(OpenMPOp): ...@@ -907,7 +910,7 @@ class Pool(OpenMPOp):
return ccode % locals() return ccode % locals()
def c_code_cache_version(self): def c_code_cache_version(self):
return (0, 6, 8, 7, self.openmp) return (9, self.openmp)
class PoolGrad(OpenMPOp): class PoolGrad(OpenMPOp):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论