提交 61cd924c authored 作者: Chiheb Trabelsi's avatar Chiheb Trabelsi

changes of pool.py and test_pool.py have been made based on the Pull request #5138.

上级 ee75b95b
...@@ -71,29 +71,60 @@ def pool_2d(input, ws=None, ignore_border=None, stride=None, pad=(0, 0), ...@@ -71,29 +71,60 @@ def pool_2d(input, ws=None, ignore_border=None, stride=None, pad=(0, 0),
ds ds
*deprecated*, use parameter ws instead. *deprecated*, use parameter ws instead.
st st
*deprecated*, use parameter st instead. *deprecated*, use parameter stride instead.
padding padding
*deprecated*, use parameter pad instead. *deprecated*, use parameter pad instead.
""" """
# check for deprecated parameter names # check for deprecated parameter names
if ds is not None: if ds is not None:
if ws is not None:
raise Exception(
"You can't provide a tuple value to both 'ws' and 'ds'."
" Please provide a value only to 'ws'."
)
else:
warnings.warn( warnings.warn(
"pool_2d() ds parameter is deprecated, please use ws", "DEPRECATION: the 'ds' parameter is not going to exist"
stacklevel=2) " anymore as it is going to be replaced by the parameter"
" 'ws'.",
stacklevel=2
)
ws = ds ws = ds
elif ds is None and ws is None:
raise Exception(
"You must provide a tuple value for the window size."
)
if st is not None: if st is not None:
if stride is not None:
raise Exception(
"You can't provide a tuple value to both 'st and 'stride'."
" Please provide a value only to 'stride'."
)
else:
warnings.warn( warnings.warn(
"pool_2d() st parameter is deprecated, please use stride", "DEPRECATION: the 'st' parameter is not going to exist"
stacklevel=2) " anymore as it is going to be replaced by the parameter"
" 'stride'.",
stacklevel=2
)
stride = st stride = st
if padding is not None: if padding is not None:
if pad not in {None, (0, 0)}:
raise Exception(
"You can't provide a tuple value to both 'padding' and pad."
" Please provide a value only to pad."
)
else:
warnings.warn( warnings.warn(
"pool_2d() padding parameter is deprecated, please use pad", "DEPRECATION: the 'padding' parameter is not going to exist"
stacklevel=2) " anymore as it is going to be replaced by the parameter"
" 'pad'.",
stacklevel=2
)
pad = padding pad = padding
if ws is None:
raise ValueError('pool_2d() ws parameter can not be None')
if input.ndim < 2: if input.ndim < 2:
raise NotImplementedError('pool_2d requires a dimension >= 2') raise NotImplementedError('pool_2d requires a dimension >= 2')
...@@ -156,22 +187,53 @@ def pool_3d(input, ws=None, ignore_border=None, stride=None, pad=(0, 0, 0), ...@@ -156,22 +187,53 @@ def pool_3d(input, ws=None, ignore_border=None, stride=None, pad=(0, 0, 0),
""" """
# check for deprecated parameter names # check for deprecated parameter names
if ds is not None: if ds is not None:
if ws is not None:
raise Exception(
"You can't provide a tuple value to both 'ws' and 'ds'."
" Please provide a value only to 'ws'."
)
else:
warnings.warn( warnings.warn(
"pool_3d() ds parameter is deprecated, please use ws", "DEPRECATION: the 'ds' parameter is not going to exist"
stacklevel=2) " anymore as it is going to be replaced by the parameter"
" 'ws'.",
stacklevel=2
)
ws = ds ws = ds
elif ds is None and ws is None:
raise Exception(
"You must provide a tuple value for the window size."
)
if st is not None: if st is not None:
if stride is not None:
raise Exception(
"You can't provide a tuple value to both 'st and 'stride'."
" Please provide a value only to 'stride'."
)
else:
warnings.warn( warnings.warn(
"pool_3d() st parameter is deprecated, please use stride", "DEPRECATION: the 'st' parameter is not going to exist"
stacklevel=2) " anymore as it is going to be replaced by the parameter"
" 'stride'.",
stacklevel=2
)
stride = st stride = st
if padding is not None: if padding is not None:
if pad not in {None, (0, 0, 0)}:
raise Exception(
"You can't provide a tuple value to both 'padding' and pad."
" Please provide a value only to pad."
)
else:
warnings.warn( warnings.warn(
"pool_3d() padding parameter is deprecated, please use pad", "DEPRECATION: the 'padding' parameter is not going to exist"
stacklevel=2) " anymore as it is going to be replaced by the parameter"
" 'pad'.",
stacklevel=2
)
pad = padding pad = padding
if ws is None:
raise ValueError('pool_3d() ws parameter can not be None')
if input.ndim < 3: if input.ndim < 3:
raise NotImplementedError('pool_3d requires a dimension >= 3') raise NotImplementedError('pool_3d requires a dimension >= 3')
...@@ -220,14 +282,21 @@ class Pool(OpenMPOp): ...@@ -220,14 +282,21 @@ class Pool(OpenMPOp):
ndim : int ndim : int
The number of pooling dimensions N. The number of pooling dimensions N.
The default is 2. The default is 2.
ds
*deprecated*, use parameter ws instead.
st
*deprecated*, use parameter st instead.
padding
*deprecated*, use parameter pad instead.
""" """
__props__ = ('ignore_border', 'mode', 'ndim') __props__ = ('ignore_border', 'mode', 'ndim')
@staticmethod @staticmethod
def out_shape(imgshape, ws=None, ignore_border=False, stride=None, pad=None, ndim=2, def out_shape(imgshape, ws=None, ignore_border=False, stride=None, pad=None,
ds=None, st=None, padding=None): ndim=2, ds=None, st=None, padding=None):
""" """
Return the shape of the output from this op, for input of given Return the shape of the output from this op, for input of given
shape and flags. shape and flags.
...@@ -255,6 +324,12 @@ class Pool(OpenMPOp): ...@@ -255,6 +324,12 @@ class Pool(OpenMPOp):
ndim : int ndim : int
The number of pooling dimensions N. The number of pooling dimensions N.
The default is 2. The default is 2.
ds
*deprecated*, use parameter ws instead.
st
*deprecated*, use parameter st instead.
padding
*deprecated*, use parameter pad instead.
Returns Returns
------- -------
...@@ -266,22 +341,53 @@ class Pool(OpenMPOp): ...@@ -266,22 +341,53 @@ class Pool(OpenMPOp):
""" """
# check for deprecated parameter names # check for deprecated parameter names
if ds is not None: if ds is not None:
if ws is not None:
raise Exception(
"You can't provide a tuple value to both 'ws' and 'ds'."
" Please provide a value only to 'ws'."
)
else:
warnings.warn( warnings.warn(
"Pool ds parameter is deprecated, please use ws", "DEPRECATION: the 'ds' parameter is not going to exist"
stacklevel=2) " anymore as it is going to be replaced by the parameter"
" 'ws'.",
stacklevel=2
)
ws = ds ws = ds
elif ds is None and ws is None:
raise Exception(
"You must provide a tuple value for the window size."
)
if st is not None: if st is not None:
if stride is not None:
raise Exception(
"You can't provide a tuple value to both 'st and 'stride'."
" Please provide a value only to 'stride'."
)
else:
warnings.warn( warnings.warn(
"Pool st parameter is deprecated, please use stride", "DEPRECATION: the 'st' parameter is not going to exist"
stacklevel=2) " anymore as it is going to be replaced by the parameter"
" 'stride'.",
stacklevel=2
)
stride = st stride = st
if padding is not None: if padding is not None:
if pad is not None:
raise Exception(
"You can't provide a tuple value to both 'padding' and pad."
" Please provide a value only to pad."
)
else:
warnings.warn( warnings.warn(
"Pool padding parameter is deprecated, please use pad", "DEPRECATION: the 'padding' parameter is not going to"
stacklevel=2) " exist anymore as it is going to be replaced by the"
" parameter 'pad'.",
stacklevel=2
)
pad = padding pad = padding
if ws is None:
raise ValueError('Pool ws parameter can not be None')
if ndim is None: if ndim is None:
ndim = 2 ndim = 2
...@@ -812,6 +918,12 @@ class PoolGrad(OpenMPOp): ...@@ -812,6 +918,12 @@ class PoolGrad(OpenMPOp):
ndim : int ndim : int
The number of pooling dimensions N. The number of pooling dimensions N.
The default is 2. The default is 2.
ds
*deprecated*, use parameter ws instead.
st
*deprecated*, use parameter st instead.
padding
*deprecated*, use parameter pad instead.
Returns Returns
------- -------
...@@ -824,22 +936,53 @@ class PoolGrad(OpenMPOp): ...@@ -824,22 +936,53 @@ class PoolGrad(OpenMPOp):
""" """
# check for deprecated parameter names # check for deprecated parameter names
if ds is not None: if ds is not None:
if ws is not None:
raise Exception(
"You can't provide a tuple value to both 'ws' and 'ds'."
" Please provide a value only to 'ws'."
)
else:
warnings.warn( warnings.warn(
"PoolGrad ds parameter is deprecated, please use ws", "DEPRECATION: the 'ds' parameter in PoolGrad is not going"
stacklevel=2) " to exist anymore as it is going to be replaced by the"
" parameter 'ws'.",
stacklevel=2
)
ws = ds ws = ds
elif ds is None and ws is None:
raise Exception(
"You must provide a tuple value for the window size."
)
if st is not None: if st is not None:
if stride is not None:
raise Exception(
"You can't provide a tuple value to both 'st and 'stride'."
" Please provide a value only to 'stride'."
)
else:
warnings.warn( warnings.warn(
"PoolGrad st parameter is deprecated, please use stride", "DEPRECATION: the 'st' parameter in PoolGrad is not going"
stacklevel=2) " to exist anymore as it is going to be replaced by the"
" parameter 'stride'.",
stacklevel=2
)
stride = st stride = st
if padding is not None: if padding is not None:
if pad is not (0, 0):
raise Exception(
"You can't provide a tuple value to both 'padding' and pad."
" Please provide a value only to pad."
)
else:
warnings.warn( warnings.warn(
"PoolGrad padding parameter is deprecated, please use pad", "DEPRECATION: the 'padding' parameter in PoolGrad is not"
stacklevel=2) " going to exist anymore as it is going to be replaced"
" by the parameter 'pad'.",
stacklevel=2
)
pad = padding pad = padding
if ws is None:
raise ValueError('PoolGrad ws parameter can not be None')
if len(imgshape) < ndim: if len(imgshape) < ndim:
raise TypeError('imgshape must have at least {} dimensions'.format(ndim)) raise TypeError('imgshape must have at least {} dimensions'.format(ndim))
......
...@@ -915,6 +915,34 @@ class TestDownsampleFactorMax(utt.InferShapeTester): ...@@ -915,6 +915,34 @@ class TestDownsampleFactorMax(utt.InferShapeTester):
mode=mode) mode=mode)
utt.verify_grad(mp, [imval], rng=rng) utt.verify_grad(mp, [imval], rng=rng)
def test_max_pool_3d_3D_deprecated_interface(self):
rng = numpy.random.RandomState(utt.fetch_seed())
maxpoolshps = ((1, 1, 1), (3, 2, 1))
imval = rng.rand(4, 5, 6)
images = tensor.dtensor3()
for maxpoolshp, ignore_border, mode in product(maxpoolshps,
[True, False],
['max', 'sum',
'average_inc_pad',
'average_exc_pad']):
# print 'maxpoolshp =', maxpoolshp
# print 'ignore_border =', ignore_border
numpy_output_val = self.numpy_max_pool_nd(imval, maxpoolshp,
ignore_border,
mode=mode)
output = pool_3d(input=images,
ds=maxpoolshp,
ignore_border=ignore_border,
mode=mode)
output_val = function([images], output)(imval)
utt.assert_allclose(output_val, numpy_output_val)
def mp(input):
return pool_3d(input, maxpoolshp, ignore_border,
mode=mode)
utt.verify_grad(mp, [imval], rng=rng)
def test_max_pool_2d_2D_same_size(self): def test_max_pool_2d_2D_same_size(self):
rng = numpy.random.RandomState(utt.fetch_seed()) rng = numpy.random.RandomState(utt.fetch_seed())
test_input_array = numpy.array([[[ test_input_array = numpy.array([[[
...@@ -1076,6 +1104,47 @@ class TestDownsampleFactorMax(utt.InferShapeTester): ...@@ -1076,6 +1104,47 @@ class TestDownsampleFactorMax(utt.InferShapeTester):
utt.assert_allclose(var_y, fix_y) utt.assert_allclose(var_y, fix_y)
utt.assert_allclose(var_dx, fix_dx) utt.assert_allclose(var_dx, fix_dx)
def test_pooling_with_tensor_vars_deprecated_interface(self):
x = tensor.ftensor4()
window_size = tensor.ivector()
stride = tensor.ivector()
padding = tensor.ivector()
data = numpy.random.normal(0, 1, (1, 1, 5, 5)).astype('float32')
# checking variable params vs fixed params
for ignore_border in [True, False]:
for mode in ['max', 'sum', 'average_inc_pad', 'average_exc_pad']:
y = pool_2d(input=x,
ds=window_size,
ignore_border=ignore_border,
st=stride,
pad=None,
padding=padding,
mode=mode)
dx = theano.gradient.grad(y.sum(), x)
var_fct = theano.function([x, window_size, stride, padding],
[y, dx])
for ws in (4, 2, 5):
for st in (2, 3):
for pad in (0, 1):
if (pad > st or st > ws or
(pad != 0 and not ignore_border) or
(mode == 'average_exc_pad' and pad != 0)):
continue
y = pool_2d(input=x,
ds=(ws, ws),
ignore_border=ignore_border,
st=(st, st),
pad=(pad, pad),
mode=mode)
dx = theano.gradient.grad(y.sum(), x)
fix_fct = theano.function([x], [y, dx])
var_y, var_dx = var_fct(data, (ws, ws), (st, st),
(pad, pad))
fix_y, fix_dx = fix_fct(data)
utt.assert_allclose(var_y, fix_y)
utt.assert_allclose(var_dx, fix_dx)
def test_old_pool_interface(self): def test_old_pool_interface(self):
if sys.version_info[0] != 3: if sys.version_info[0] != 3:
# Only tested with python 3 because of pickling issues. # Only tested with python 3 because of pickling issues.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论