提交 61cd924c authored 作者: Chiheb Trabelsi's avatar Chiheb Trabelsi

changes of pool.py and test_pool.py have been made based on the Pull request #5138.

上级 ee75b95b
...@@ -71,29 +71,60 @@ def pool_2d(input, ws=None, ignore_border=None, stride=None, pad=(0, 0), ...@@ -71,29 +71,60 @@ def pool_2d(input, ws=None, ignore_border=None, stride=None, pad=(0, 0),
ds ds
*deprecated*, use parameter ws instead. *deprecated*, use parameter ws instead.
st st
*deprecated*, use parameter st instead. *deprecated*, use parameter stride instead.
padding padding
*deprecated*, use parameter pad instead. *deprecated*, use parameter pad instead.
""" """
# check for deprecated parameter names # check for deprecated parameter names
if ds is not None: if ds is not None:
warnings.warn( if ws is not None:
"pool_2d() ds parameter is deprecated, please use ws", raise Exception(
stacklevel=2) "You can't provide a tuple value to both 'ws' and 'ds'."
ws = ds " Please provide a value only to 'ws'."
)
else:
warnings.warn(
"DEPRECATION: the 'ds' parameter is not going to exist"
" anymore as it is going to be replaced by the parameter"
" 'ws'.",
stacklevel=2
)
ws = ds
elif ds is None and ws is None:
raise Exception(
"You must provide a tuple value for the window size."
)
if st is not None: if st is not None:
warnings.warn( if stride is not None:
"pool_2d() st parameter is deprecated, please use stride", raise Exception(
stacklevel=2) "You can't provide a tuple value to both 'st and 'stride'."
stride = st " Please provide a value only to 'stride'."
)
else:
warnings.warn(
"DEPRECATION: the 'st' parameter is not going to exist"
" anymore as it is going to be replaced by the parameter"
" 'stride'.",
stacklevel=2
)
stride = st
if padding is not None: if padding is not None:
warnings.warn( if pad not in {None, (0, 0)}:
"pool_2d() padding parameter is deprecated, please use pad", raise Exception(
stacklevel=2) "You can't provide a tuple value to both 'padding' and pad."
pad = padding " Please provide a value only to pad."
if ws is None: )
raise ValueError('pool_2d() ws parameter can not be None') else:
warnings.warn(
"DEPRECATION: the 'padding' parameter is not going to exist"
" anymore as it is going to be replaced by the parameter"
" 'pad'.",
stacklevel=2
)
pad = padding
if input.ndim < 2: if input.ndim < 2:
raise NotImplementedError('pool_2d requires a dimension >= 2') raise NotImplementedError('pool_2d requires a dimension >= 2')
...@@ -156,22 +187,53 @@ def pool_3d(input, ws=None, ignore_border=None, stride=None, pad=(0, 0, 0), ...@@ -156,22 +187,53 @@ def pool_3d(input, ws=None, ignore_border=None, stride=None, pad=(0, 0, 0),
""" """
# check for deprecated parameter names # check for deprecated parameter names
if ds is not None: if ds is not None:
warnings.warn( if ws is not None:
"pool_3d() ds parameter is deprecated, please use ws", raise Exception(
stacklevel=2) "You can't provide a tuple value to both 'ws' and 'ds'."
ws = ds " Please provide a value only to 'ws'."
)
else:
warnings.warn(
"DEPRECATION: the 'ds' parameter is not going to exist"
" anymore as it is going to be replaced by the parameter"
" 'ws'.",
stacklevel=2
)
ws = ds
elif ds is None and ws is None:
raise Exception(
"You must provide a tuple value for the window size."
)
if st is not None: if st is not None:
warnings.warn( if stride is not None:
"pool_3d() st parameter is deprecated, please use stride", raise Exception(
stacklevel=2) "You can't provide a tuple value to both 'st and 'stride'."
stride = st " Please provide a value only to 'stride'."
)
else:
warnings.warn(
"DEPRECATION: the 'st' parameter is not going to exist"
" anymore as it is going to be replaced by the parameter"
" 'stride'.",
stacklevel=2
)
stride = st
if padding is not None: if padding is not None:
warnings.warn( if pad not in {None, (0, 0, 0)}:
"pool_3d() padding parameter is deprecated, please use pad", raise Exception(
stacklevel=2) "You can't provide a tuple value to both 'padding' and pad."
pad = padding " Please provide a value only to pad."
if ws is None: )
raise ValueError('pool_3d() ws parameter can not be None') else:
warnings.warn(
"DEPRECATION: the 'padding' parameter is not going to exist"
" anymore as it is going to be replaced by the parameter"
" 'pad'.",
stacklevel=2
)
pad = padding
if input.ndim < 3: if input.ndim < 3:
raise NotImplementedError('pool_3d requires a dimension >= 3') raise NotImplementedError('pool_3d requires a dimension >= 3')
...@@ -220,14 +282,21 @@ class Pool(OpenMPOp): ...@@ -220,14 +282,21 @@ class Pool(OpenMPOp):
ndim : int ndim : int
The number of pooling dimensions N. The number of pooling dimensions N.
The default is 2. The default is 2.
ds
*deprecated*, use parameter ws instead.
st
*deprecated*, use parameter st instead.
padding
*deprecated*, use parameter pad instead.
""" """
__props__ = ('ignore_border', 'mode', 'ndim') __props__ = ('ignore_border', 'mode', 'ndim')
@staticmethod @staticmethod
def out_shape(imgshape, ws=None, ignore_border=False, stride=None, pad=None, ndim=2, def out_shape(imgshape, ws=None, ignore_border=False, stride=None, pad=None,
ds=None, st=None, padding=None): ndim=2, ds=None, st=None, padding=None):
""" """
Return the shape of the output from this op, for input of given Return the shape of the output from this op, for input of given
shape and flags. shape and flags.
...@@ -255,6 +324,12 @@ class Pool(OpenMPOp): ...@@ -255,6 +324,12 @@ class Pool(OpenMPOp):
ndim : int ndim : int
The number of pooling dimensions N. The number of pooling dimensions N.
The default is 2. The default is 2.
ds
*deprecated*, use parameter ws instead.
st
*deprecated*, use parameter st instead.
padding
*deprecated*, use parameter pad instead.
Returns Returns
------- -------
...@@ -266,22 +341,53 @@ class Pool(OpenMPOp): ...@@ -266,22 +341,53 @@ class Pool(OpenMPOp):
""" """
# check for deprecated parameter names # check for deprecated parameter names
if ds is not None: if ds is not None:
warnings.warn( if ws is not None:
"Pool ds parameter is deprecated, please use ws", raise Exception(
stacklevel=2) "You can't provide a tuple value to both 'ws' and 'ds'."
ws = ds " Please provide a value only to 'ws'."
)
else:
warnings.warn(
"DEPRECATION: the 'ds' parameter is not going to exist"
" anymore as it is going to be replaced by the parameter"
" 'ws'.",
stacklevel=2
)
ws = ds
elif ds is None and ws is None:
raise Exception(
"You must provide a tuple value for the window size."
)
if st is not None: if st is not None:
warnings.warn( if stride is not None:
"Pool st parameter is deprecated, please use stride", raise Exception(
stacklevel=2) "You can't provide a tuple value to both 'st and 'stride'."
stride = st " Please provide a value only to 'stride'."
)
else:
warnings.warn(
"DEPRECATION: the 'st' parameter is not going to exist"
" anymore as it is going to be replaced by the parameter"
" 'stride'.",
stacklevel=2
)
stride = st
if padding is not None: if padding is not None:
warnings.warn( if pad is not None:
"Pool padding parameter is deprecated, please use pad", raise Exception(
stacklevel=2) "You can't provide a tuple value to both 'padding' and pad."
pad = padding " Please provide a value only to pad."
if ws is None: )
raise ValueError('Pool ws parameter can not be None') else:
warnings.warn(
"DEPRECATION: the 'padding' parameter is not going to"
" exist anymore as it is going to be replaced by the"
" parameter 'pad'.",
stacklevel=2
)
pad = padding
if ndim is None: if ndim is None:
ndim = 2 ndim = 2
...@@ -812,6 +918,12 @@ class PoolGrad(OpenMPOp): ...@@ -812,6 +918,12 @@ class PoolGrad(OpenMPOp):
ndim : int ndim : int
The number of pooling dimensions N. The number of pooling dimensions N.
The default is 2. The default is 2.
ds
*deprecated*, use parameter ws instead.
st
*deprecated*, use parameter st instead.
padding
*deprecated*, use parameter pad instead.
Returns Returns
------- -------
...@@ -824,22 +936,53 @@ class PoolGrad(OpenMPOp): ...@@ -824,22 +936,53 @@ class PoolGrad(OpenMPOp):
""" """
# check for deprecated parameter names # check for deprecated parameter names
if ds is not None: if ds is not None:
warnings.warn( if ws is not None:
"PoolGrad ds parameter is deprecated, please use ws", raise Exception(
stacklevel=2) "You can't provide a tuple value to both 'ws' and 'ds'."
ws = ds " Please provide a value only to 'ws'."
)
else:
warnings.warn(
"DEPRECATION: the 'ds' parameter in PoolGrad is not going"
" to exist anymore as it is going to be replaced by the"
" parameter 'ws'.",
stacklevel=2
)
ws = ds
elif ds is None and ws is None:
raise Exception(
"You must provide a tuple value for the window size."
)
if st is not None: if st is not None:
warnings.warn( if stride is not None:
"PoolGrad st parameter is deprecated, please use stride", raise Exception(
stacklevel=2) "You can't provide a tuple value to both 'st and 'stride'."
stride = st " Please provide a value only to 'stride'."
)
else:
warnings.warn(
"DEPRECATION: the 'st' parameter in PoolGrad is not going"
" to exist anymore as it is going to be replaced by the"
" parameter 'stride'.",
stacklevel=2
)
stride = st
if padding is not None: if padding is not None:
warnings.warn( if pad is not (0, 0):
"PoolGrad padding parameter is deprecated, please use pad", raise Exception(
stacklevel=2) "You can't provide a tuple value to both 'padding' and pad."
pad = padding " Please provide a value only to pad."
if ws is None: )
raise ValueError('PoolGrad ws parameter can not be None') else:
warnings.warn(
"DEPRECATION: the 'padding' parameter in PoolGrad is not"
" going to exist anymore as it is going to be replaced"
" by the parameter 'pad'.",
stacklevel=2
)
pad = padding
if len(imgshape) < ndim: if len(imgshape) < ndim:
raise TypeError('imgshape must have at least {} dimensions'.format(ndim)) raise TypeError('imgshape must have at least {} dimensions'.format(ndim))
......
...@@ -915,6 +915,34 @@ class TestDownsampleFactorMax(utt.InferShapeTester): ...@@ -915,6 +915,34 @@ class TestDownsampleFactorMax(utt.InferShapeTester):
mode=mode) mode=mode)
utt.verify_grad(mp, [imval], rng=rng) utt.verify_grad(mp, [imval], rng=rng)
def test_max_pool_3d_3D_deprecated_interface(self):
rng = numpy.random.RandomState(utt.fetch_seed())
maxpoolshps = ((1, 1, 1), (3, 2, 1))
imval = rng.rand(4, 5, 6)
images = tensor.dtensor3()
for maxpoolshp, ignore_border, mode in product(maxpoolshps,
[True, False],
['max', 'sum',
'average_inc_pad',
'average_exc_pad']):
# print 'maxpoolshp =', maxpoolshp
# print 'ignore_border =', ignore_border
numpy_output_val = self.numpy_max_pool_nd(imval, maxpoolshp,
ignore_border,
mode=mode)
output = pool_3d(input=images,
ds=maxpoolshp,
ignore_border=ignore_border,
mode=mode)
output_val = function([images], output)(imval)
utt.assert_allclose(output_val, numpy_output_val)
def mp(input):
return pool_3d(input, maxpoolshp, ignore_border,
mode=mode)
utt.verify_grad(mp, [imval], rng=rng)
def test_max_pool_2d_2D_same_size(self): def test_max_pool_2d_2D_same_size(self):
rng = numpy.random.RandomState(utt.fetch_seed()) rng = numpy.random.RandomState(utt.fetch_seed())
test_input_array = numpy.array([[[ test_input_array = numpy.array([[[
...@@ -1076,6 +1104,47 @@ class TestDownsampleFactorMax(utt.InferShapeTester): ...@@ -1076,6 +1104,47 @@ class TestDownsampleFactorMax(utt.InferShapeTester):
utt.assert_allclose(var_y, fix_y) utt.assert_allclose(var_y, fix_y)
utt.assert_allclose(var_dx, fix_dx) utt.assert_allclose(var_dx, fix_dx)
def test_pooling_with_tensor_vars_deprecated_interface(self):
x = tensor.ftensor4()
window_size = tensor.ivector()
stride = tensor.ivector()
padding = tensor.ivector()
data = numpy.random.normal(0, 1, (1, 1, 5, 5)).astype('float32')
# checking variable params vs fixed params
for ignore_border in [True, False]:
for mode in ['max', 'sum', 'average_inc_pad', 'average_exc_pad']:
y = pool_2d(input=x,
ds=window_size,
ignore_border=ignore_border,
st=stride,
pad=None,
padding=padding,
mode=mode)
dx = theano.gradient.grad(y.sum(), x)
var_fct = theano.function([x, window_size, stride, padding],
[y, dx])
for ws in (4, 2, 5):
for st in (2, 3):
for pad in (0, 1):
if (pad > st or st > ws or
(pad != 0 and not ignore_border) or
(mode == 'average_exc_pad' and pad != 0)):
continue
y = pool_2d(input=x,
ds=(ws, ws),
ignore_border=ignore_border,
st=(st, st),
pad=(pad, pad),
mode=mode)
dx = theano.gradient.grad(y.sum(), x)
fix_fct = theano.function([x], [y, dx])
var_y, var_dx = var_fct(data, (ws, ws), (st, st),
(pad, pad))
fix_y, fix_dx = fix_fct(data)
utt.assert_allclose(var_y, fix_y)
utt.assert_allclose(var_dx, fix_dx)
def test_old_pool_interface(self): def test_old_pool_interface(self):
if sys.version_info[0] != 3: if sys.version_info[0] != 3:
# Only tested with python 3 because of pickling issues. # Only tested with python 3 because of pickling issues.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论