提交 12a92ee4 authored 作者: Chiheb Trabelsi's avatar Chiheb Trabelsi

pool.py and test_pool.py respect the flake8 norms. corrections have for the…

pool.py and test_pool.py respect the flake8 norms. corrections have for the test_pool have been made in order to include more deprecation examples.
上级 0439ec21
......@@ -79,7 +79,7 @@ def pool_2d(input, ws=None, ignore_border=None, stride=None, pad=(0, 0),
# check for deprecated parameter names
if ds is not None:
if ws is not None:
raise Exception(
raise ValueError(
"You can't provide a tuple value to both 'ws' and 'ds'."
" Please provide a value only to 'ws'."
)
......@@ -92,13 +92,13 @@ def pool_2d(input, ws=None, ignore_border=None, stride=None, pad=(0, 0),
)
ws = ds
elif ds is None and ws is None:
raise Exception(
raise ValueError(
"You must provide a tuple value for the window size."
)
if st is not None:
if stride is not None:
raise Exception(
raise ValueError(
"You can't provide a tuple value to both 'st and 'stride'."
" Please provide a value only to 'stride'."
)
......@@ -113,7 +113,7 @@ def pool_2d(input, ws=None, ignore_border=None, stride=None, pad=(0, 0),
if padding is not None:
if pad not in {None, (0, 0)}:
raise Exception(
raise ValueError(
"You can't provide a tuple value to both 'padding' and pad."
" Please provide a value only to pad."
)
......@@ -188,7 +188,7 @@ def pool_3d(input, ws=None, ignore_border=None, stride=None, pad=(0, 0, 0),
# check for deprecated parameter names
if ds is not None:
if ws is not None:
raise Exception(
raise ValueError(
"You can't provide a tuple value to both 'ws' and 'ds'."
" Please provide a value only to 'ws'."
)
......@@ -201,13 +201,13 @@ def pool_3d(input, ws=None, ignore_border=None, stride=None, pad=(0, 0, 0),
)
ws = ds
elif ds is None and ws is None:
raise Exception(
raise ValueError(
"You must provide a tuple value for the window size."
)
if st is not None:
if stride is not None:
raise Exception(
raise ValueError(
"You can't provide a tuple value to both 'st and 'stride'."
" Please provide a value only to 'stride'."
)
......@@ -222,7 +222,7 @@ def pool_3d(input, ws=None, ignore_border=None, stride=None, pad=(0, 0, 0),
if padding is not None:
if pad not in {None, (0, 0, 0)}:
raise Exception(
raise ValueError(
"You can't provide a tuple value to both 'padding' and pad."
" Please provide a value only to pad."
)
......@@ -342,7 +342,7 @@ class Pool(OpenMPOp):
# check for deprecated parameter names
if ds is not None:
if ws is not None:
raise Exception(
raise ValueError(
"You can't provide a tuple value to both 'ws' and 'ds'."
" Please provide a value only to 'ws'."
)
......@@ -355,13 +355,13 @@ class Pool(OpenMPOp):
)
ws = ds
elif ds is None and ws is None:
raise Exception(
raise ValueError(
"You must provide a tuple value for the window size."
)
if st is not None:
if stride is not None:
raise Exception(
raise ValueError(
"You can't provide a tuple value to both 'st and 'stride'."
" Please provide a value only to 'stride'."
)
......@@ -375,11 +375,12 @@ class Pool(OpenMPOp):
stride = st
if padding is not None:
if pad is not None:
raise Exception(
zero_pad = (0,) * ndim
if pad not in {None, zero_pad}:
raise ValueError(
"You can't provide a tuple value to both 'padding' and pad."
" Please provide a value only to pad."
)
)
else:
warnings.warn(
"DEPRECATION: the 'padding' parameter is not going to"
......@@ -937,7 +938,7 @@ class PoolGrad(OpenMPOp):
# check for deprecated parameter names
if ds is not None:
if ws is not None:
raise Exception(
raise ValueError(
"You can't provide a tuple value to both 'ws' and 'ds'."
" Please provide a value only to 'ws'."
)
......@@ -950,13 +951,13 @@ class PoolGrad(OpenMPOp):
)
ws = ds
elif ds is None and ws is None:
raise Exception(
raise ValueError(
"You must provide a tuple value for the window size."
)
if st is not None:
if stride is not None:
raise Exception(
raise ValueError(
"You can't provide a tuple value to both 'st and 'stride'."
" Please provide a value only to 'stride'."
)
......@@ -971,7 +972,7 @@ class PoolGrad(OpenMPOp):
if padding is not None:
if pad is not None:
raise Exception(
raise ValueError(
"You can't provide a tuple value to both 'padding' and pad."
" Please provide a value only to pad."
)
......@@ -2004,4 +2005,4 @@ class DownsampleFactorMaxGradGrad(OpenMPOp):
return ccode % locals()
def c_code_cache_version(self):
return (0, 4, self.openmp)
\ No newline at end of file
return (0, 4, self.openmp)
......@@ -934,6 +934,8 @@ class TestDownsampleFactorMax(utt.InferShapeTester):
output = pool_3d(input=images,
ds=maxpoolshp,
ignore_border=ignore_border,
st=maxpoolshp,
padding=(0, 0, 0),
mode=mode)
output_val = function([images], output)(imval)
utt.assert_allclose(output_val, numpy_output_val)
......@@ -941,7 +943,6 @@ class TestDownsampleFactorMax(utt.InferShapeTester):
def mp(input):
return pool_3d(input, maxpoolshp, ignore_border,
mode=mode)
utt.verify_grad(mp, [imval], rng=rng)
def test_max_pool_2d_2D_same_size(self):
rng = numpy.random.RandomState(utt.fetch_seed())
......@@ -1118,32 +1119,31 @@ class TestDownsampleFactorMax(utt.InferShapeTester):
ds=window_size,
ignore_border=ignore_border,
st=stride,
pad=None,
padding=padding,
mode=mode)
dx = theano.gradient.grad(y.sum(), x)
var_fct = theano.function([x, window_size, stride, padding],
[y, dx])
for ws in (4, 2, 5):
for st in (2, 3):
for pad in (0, 1):
if (pad > st or st > ws or
(pad != 0 and not ignore_border) or
(mode == 'average_exc_pad' and pad != 0)):
continue
y = pool_2d(input=x,
ds=(ws, ws),
ignore_border=ignore_border,
st=(st, st),
pad=(pad, pad),
mode=mode)
dx = theano.gradient.grad(y.sum(), x)
fix_fct = theano.function([x], [y, dx])
var_y, var_dx = var_fct(data, (ws, ws), (st, st),
(pad, pad))
fix_y, fix_dx = fix_fct(data)
utt.assert_allclose(var_y, fix_y)
utt.assert_allclose(var_dx, fix_dx)
ws = 5
st = 3
pad = 1
if (pad > st or st > ws or
(pad != 0 and not ignore_border) or
(mode == 'average_exc_pad' and pad != 0)):
continue
y = pool_2d(input=x,
ds=(ws, ws),
ignore_border=ignore_border,
st=(st, st),
padding=(pad, pad),
mode=mode)
dx = theano.gradient.grad(y.sum(), x)
fix_fct = theano.function([x], [y, dx])
var_y, var_dx = var_fct(data, (ws, ws), (st, st),
(pad, pad))
fix_y, fix_dx = fix_fct(data)
utt.assert_allclose(var_y, fix_y)
utt.assert_allclose(var_dx, fix_dx)
def test_old_pool_interface(self):
if sys.version_info[0] != 3:
......@@ -1181,4 +1181,4 @@ class TestDownsampleFactorMax(utt.InferShapeTester):
if __name__ == '__main__':
unittest.main()
\ No newline at end of file
unittest.main()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论