提交 3df1e8a3 authored 作者: Pascal Lamblin's avatar Pascal Lamblin 提交者: GitHub

Merge pull request #5873 from nouiz/gh-5836

[Interface change] be more consistent, use ndim instead of outdim
...@@ -5099,7 +5099,7 @@ def is_flat(var, outdim=1): ...@@ -5099,7 +5099,7 @@ def is_flat(var, outdim=1):
return var.ndim == outdim return var.ndim == outdim
def flatten(x, outdim=1): def flatten(x, ndim=None, outdim=None):
""" """
Reshapes the variable x by keeping Reshapes the variable x by keeping
the first outdim-1 dimension size(s) of x the same, the first outdim-1 dimension size(s) of x the same,
...@@ -5111,31 +5111,42 @@ def flatten(x, outdim=1): ...@@ -5111,31 +5111,42 @@ def flatten(x, outdim=1):
x : theano.tensor.var.TensorVariable x : theano.tensor.var.TensorVariable
the variable that should be reshaped. the variable that should be reshaped.
outdim : int ndim : int
the number of dimensions of the returned variable the number of dimensions of the returned variable
Default 1.
outdim : int
DEPRECATED synonym for ndim
Returns Returns
------- -------
theano.tensor.var.TensorVariable theano.tensor.var.TensorVariable
the flattend variable with dimensionality of outdim the flattend variable with dimensionality of outdim
""" """
# Any input variable can be flattened to have outdim of 1, if outdim is None and ndim is None:
# even if it's a scalar. Otherwise, outdim must be positive ndim = 1
elif outdim is not None and ndim is not None:
raise ValueError("You should only specify ndim")
elif outdim is not None:
warnings.warn(
"flatten outdim parameter is deprecated, use ndim instead.")
ndim = outdim
# Any input variable can be flattened to have ndim of 1,
# even if it's a scalar. Otherwise, ndim must be positive
# and smaller than x.ndim. # and smaller than x.ndim.
if outdim < 1 or (outdim > 1 and outdim > x.ndim): if ndim < 1 or (ndim > 1 and ndim > x.ndim):
raise ValueError('outdim %s out of bound [1, %d)' raise ValueError('ndim %s out of bound [1, %d)'
% (outdim, x.ndim + 1)) % (ndim, x.ndim + 1))
if outdim > 1: if ndim > 1:
dims = tuple(x.shape[:outdim - 1]) + (-1,) dims = tuple(x.shape[:ndim - 1]) + (-1,)
else: else:
dims = (-1,) dims = (-1,)
x_reshaped = x.reshape(dims) x_reshaped = x.reshape(dims)
bcast_kept_dims = x.broadcastable[:outdim - 1] bcast_kept_dims = x.broadcastable[:ndim - 1]
bcast_new_dim = python_all(x.broadcastable[outdim - 1:]) bcast_new_dim = python_all(x.broadcastable[ndim - 1:])
broadcastable = bcast_kept_dims + (bcast_new_dim,) broadcastable = bcast_kept_dims + (bcast_new_dim,)
x_reshaped = theano.tensor.addbroadcast( x_reshaped = theano.tensor.addbroadcast(
x_reshaped, *filter(lambda i: broadcastable[i], range(outdim))) x_reshaped, *filter(lambda i: broadcastable[i], range(ndim)))
return x_reshaped return x_reshaped
......
...@@ -5530,7 +5530,7 @@ def test_flatten_scalar(): ...@@ -5530,7 +5530,7 @@ def test_flatten_scalar():
# utt.verify_grad(flatten, [a_val]) #TODO: fix verify_grd to work on scalars # utt.verify_grad(flatten, [a_val]) #TODO: fix verify_grd to work on scalars
def test_flatten_outdim1(): def test_flatten_ndim1():
a = dmatrix() a = dmatrix()
c = flatten(a, 1) c = flatten(a, 1)
f = inplace_func([a], c) f = inplace_func([a], c)
...@@ -5543,7 +5543,7 @@ def test_flatten_outdim1(): ...@@ -5543,7 +5543,7 @@ def test_flatten_outdim1():
utt.verify_grad(flatten, [a_val]) utt.verify_grad(flatten, [a_val])
def test_flatten_outdim2(): def test_flatten_ndim2():
a = dmatrix() a = dmatrix()
c = flatten(a, 2) c = flatten(a, 2)
f = inplace_func([a], c) f = inplace_func([a], c)
...@@ -5552,11 +5552,11 @@ def test_flatten_outdim2(): ...@@ -5552,11 +5552,11 @@ def test_flatten_outdim2():
f = inplace_func([a], c) f = inplace_func([a], c)
assert np.all(f(a_val) == a_val) assert np.all(f(a_val) == a_val)
flatten_2 = partial(flatten, outdim=2) flatten_2 = partial(flatten, ndim=2)
utt.verify_grad(flatten_2, [a_val]) utt.verify_grad(flatten_2, [a_val])
def test_flatten_outdim2_of_3(): def test_flatten_ndim2_of_3():
a = TensorType('float64', (False, False, False))() a = TensorType('float64', (False, False, False))()
c = flatten(a, 2) c = flatten(a, 2)
f = inplace_func([a], c) f = inplace_func([a], c)
...@@ -5567,6 +5567,9 @@ def test_flatten_outdim2_of_3(): ...@@ -5567,6 +5567,9 @@ def test_flatten_outdim2_of_3():
f = inplace_func([a], c) f = inplace_func([a], c)
assert np.all(f(a_val) == c_val) assert np.all(f(a_val) == c_val)
flatten_2 = partial(flatten, ndim=2)
utt.verify_grad(flatten_2, [a_val])
# test outdim parameter name
flatten_2 = partial(flatten, outdim=2) flatten_2 = partial(flatten, outdim=2)
utt.verify_grad(flatten_2, [a_val]) utt.verify_grad(flatten_2, [a_val])
...@@ -5576,27 +5579,27 @@ def test_flatten_broadcastable(): ...@@ -5576,27 +5579,27 @@ def test_flatten_broadcastable():
# that of the input # that of the input
inp = TensorType('float64', (False, False, False, False))() inp = TensorType('float64', (False, False, False, False))()
out = flatten(inp, outdim=2) out = flatten(inp, ndim=2)
assert out.broadcastable == (False, False) assert out.broadcastable == (False, False)
inp = TensorType('float64', (False, False, False, True))() inp = TensorType('float64', (False, False, False, True))()
out = flatten(inp, outdim=2) out = flatten(inp, ndim=2)
assert out.broadcastable == (False, False) assert out.broadcastable == (False, False)
inp = TensorType('float64', (False, True, False, True))() inp = TensorType('float64', (False, True, False, True))()
out = flatten(inp, outdim=2) out = flatten(inp, ndim=2)
assert out.broadcastable == (False, False) assert out.broadcastable == (False, False)
inp = TensorType('float64', (False, True, True, True))() inp = TensorType('float64', (False, True, True, True))()
out = flatten(inp, outdim=2) out = flatten(inp, ndim=2)
assert out.broadcastable == (False, True) assert out.broadcastable == (False, True)
inp = TensorType('float64', (True, False, True, True))() inp = TensorType('float64', (True, False, True, True))()
out = flatten(inp, outdim=3) out = flatten(inp, ndim=3)
assert out.broadcastable == (True, False, True) assert out.broadcastable == (True, False, True)
def test_flatten_outdim_invalid(): def test_flatten_ndim_invalid():
a = dmatrix() a = dmatrix()
assert_raises(ValueError, flatten, a, 3) assert_raises(ValueError, flatten, a, 3)
assert_raises(ValueError, flatten, a, 0) assert_raises(ValueError, flatten, a, 0)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论