提交 3df1e8a3 authored 作者: Pascal Lamblin's avatar Pascal Lamblin 提交者: GitHub

Merge pull request #5873 from nouiz/gh-5836

[Interface change] be more consistent, use ndim instead of outdim
......@@ -5099,7 +5099,7 @@ def is_flat(var, outdim=1):
return var.ndim == outdim
def flatten(x, outdim=1):
def flatten(x, ndim=None, outdim=None):
"""
Reshapes the variable x by keeping
the first outdim-1 dimension size(s) of x the same,
......@@ -5111,31 +5111,42 @@ def flatten(x, outdim=1):
x : theano.tensor.var.TensorVariable
the variable that should be reshaped.
outdim : int
ndim : int
the number of dimensions of the returned variable
Default 1.
outdim : int
DEPRECATED synonym for ndim
Returns
-------
theano.tensor.var.TensorVariable
the flattend variable with dimensionality of outdim
"""
# Any input variable can be flattened to have outdim of 1,
# even if it's a scalar. Otherwise, outdim must be positive
if outdim is None and ndim is None:
ndim = 1
elif outdim is not None and ndim is not None:
raise ValueError("You should only specify ndim")
elif outdim is not None:
warnings.warn(
"flatten outdim parameter is deprecated, use ndim instead.")
ndim = outdim
# Any input variable can be flattened to have ndim of 1,
# even if it's a scalar. Otherwise, ndim must be positive
# and smaller than x.ndim.
if outdim < 1 or (outdim > 1 and outdim > x.ndim):
raise ValueError('outdim %s out of bound [1, %d)'
% (outdim, x.ndim + 1))
if ndim < 1 or (ndim > 1 and ndim > x.ndim):
raise ValueError('ndim %s out of bound [1, %d)'
% (ndim, x.ndim + 1))
if outdim > 1:
dims = tuple(x.shape[:outdim - 1]) + (-1,)
if ndim > 1:
dims = tuple(x.shape[:ndim - 1]) + (-1,)
else:
dims = (-1,)
x_reshaped = x.reshape(dims)
bcast_kept_dims = x.broadcastable[:outdim - 1]
bcast_new_dim = python_all(x.broadcastable[outdim - 1:])
bcast_kept_dims = x.broadcastable[:ndim - 1]
bcast_new_dim = python_all(x.broadcastable[ndim - 1:])
broadcastable = bcast_kept_dims + (bcast_new_dim,)
x_reshaped = theano.tensor.addbroadcast(
x_reshaped, *filter(lambda i: broadcastable[i], range(outdim)))
x_reshaped, *filter(lambda i: broadcastable[i], range(ndim)))
return x_reshaped
......
......@@ -5530,7 +5530,7 @@ def test_flatten_scalar():
# utt.verify_grad(flatten, [a_val]) #TODO: fix verify_grd to work on scalars
def test_flatten_outdim1():
def test_flatten_ndim1():
a = dmatrix()
c = flatten(a, 1)
f = inplace_func([a], c)
......@@ -5543,7 +5543,7 @@ def test_flatten_outdim1():
utt.verify_grad(flatten, [a_val])
def test_flatten_outdim2():
def test_flatten_ndim2():
a = dmatrix()
c = flatten(a, 2)
f = inplace_func([a], c)
......@@ -5552,11 +5552,11 @@ def test_flatten_outdim2():
f = inplace_func([a], c)
assert np.all(f(a_val) == a_val)
flatten_2 = partial(flatten, outdim=2)
flatten_2 = partial(flatten, ndim=2)
utt.verify_grad(flatten_2, [a_val])
def test_flatten_outdim2_of_3():
def test_flatten_ndim2_of_3():
a = TensorType('float64', (False, False, False))()
c = flatten(a, 2)
f = inplace_func([a], c)
......@@ -5567,6 +5567,9 @@ def test_flatten_outdim2_of_3():
f = inplace_func([a], c)
assert np.all(f(a_val) == c_val)
flatten_2 = partial(flatten, ndim=2)
utt.verify_grad(flatten_2, [a_val])
# test outdim parameter name
flatten_2 = partial(flatten, outdim=2)
utt.verify_grad(flatten_2, [a_val])
......@@ -5576,27 +5579,27 @@ def test_flatten_broadcastable():
# that of the input
inp = TensorType('float64', (False, False, False, False))()
out = flatten(inp, outdim=2)
out = flatten(inp, ndim=2)
assert out.broadcastable == (False, False)
inp = TensorType('float64', (False, False, False, True))()
out = flatten(inp, outdim=2)
out = flatten(inp, ndim=2)
assert out.broadcastable == (False, False)
inp = TensorType('float64', (False, True, False, True))()
out = flatten(inp, outdim=2)
out = flatten(inp, ndim=2)
assert out.broadcastable == (False, False)
inp = TensorType('float64', (False, True, True, True))()
out = flatten(inp, outdim=2)
out = flatten(inp, ndim=2)
assert out.broadcastable == (False, True)
inp = TensorType('float64', (True, False, True, True))()
out = flatten(inp, outdim=3)
out = flatten(inp, ndim=3)
assert out.broadcastable == (True, False, True)
def test_flatten_outdim_invalid():
def test_flatten_ndim_invalid():
a = dmatrix()
assert_raises(ValueError, flatten, a, 3)
assert_raises(ValueError, flatten, a, 0)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论