提交 90eec589 authored 作者: Frederic Bastien's avatar Frederic Bastien

Rename outdim to ndim at a few places and add a test of the old name.

上级 9853268d
......@@ -5108,28 +5108,28 @@ def flatten(x, ndim=None, outdim=None):
the flattened variable with dimensionality of ndim
"""
if outdim is None and ndim is None:
outdim = 1
ndim = 1
elif outdim is not None and ndim is not None:
raise ValueError("You should only specify outdim or ndim")
elif outdim is None:
outdim = ndim
# Any input variable can be flattened to have outdim of 1,
# even if it's a scalar. Otherwise, outdim must be positive
elif outdim is not None:
ndim = outdim
# Any input variable can be flattened to have ndim of 1,
# even if it's a scalar. Otherwise, ndim must be positive
# and smaller than x.ndim.
if outdim < 1 or (outdim > 1 and outdim > x.ndim):
raise ValueError('outdim %s out of bound [1, %d)'
% (outdim, x.ndim + 1))
if ndim < 1 or (ndim > 1 and ndim > x.ndim):
raise ValueError('ndim %s out of bound [1, %d)'
% (ndim, x.ndim + 1))
if outdim > 1:
dims = tuple(x.shape[:outdim - 1]) + (-1,)
if ndim > 1:
dims = tuple(x.shape[:ndim - 1]) + (-1,)
else:
dims = (-1,)
x_reshaped = x.reshape(dims)
bcast_kept_dims = x.broadcastable[:outdim - 1]
bcast_new_dim = python_all(x.broadcastable[outdim - 1:])
bcast_kept_dims = x.broadcastable[:ndim - 1]
bcast_new_dim = python_all(x.broadcastable[ndim - 1:])
broadcastable = bcast_kept_dims + (bcast_new_dim,)
x_reshaped = theano.tensor.addbroadcast(
x_reshaped, *filter(lambda i: broadcastable[i], range(outdim)))
x_reshaped, *filter(lambda i: broadcastable[i], range(ndim)))
return x_reshaped
......
......@@ -5530,7 +5530,7 @@ def test_flatten_scalar():
# utt.verify_grad(flatten, [a_val]) #TODO: fix verify_grd to work on scalars
def test_flatten_outdim1():
def test_flatten_ndim1():
a = dmatrix()
c = flatten(a, 1)
f = inplace_func([a], c)
......@@ -5543,7 +5543,7 @@ def test_flatten_outdim1():
utt.verify_grad(flatten, [a_val])
def test_flatten_outdim2():
def test_flatten_ndim2():
a = dmatrix()
c = flatten(a, 2)
f = inplace_func([a], c)
......@@ -5552,11 +5552,11 @@ def test_flatten_outdim2():
f = inplace_func([a], c)
assert np.all(f(a_val) == a_val)
flatten_2 = partial(flatten, outdim=2)
flatten_2 = partial(flatten, ndim=2)
utt.verify_grad(flatten_2, [a_val])
def test_flatten_outdim2_of_3():
def test_flatten_ndim2_of_3():
a = TensorType('float64', (False, False, False))()
c = flatten(a, 2)
f = inplace_func([a], c)
......@@ -5567,6 +5567,9 @@ def test_flatten_outdim2_of_3():
f = inplace_func([a], c)
assert np.all(f(a_val) == c_val)
flatten_2 = partial(flatten, ndim=2)
utt.verify_grad(flatten_2, [a_val])
# test outdim parameter name
flatten_2 = partial(flatten, outdim=2)
utt.verify_grad(flatten_2, [a_val])
......@@ -5576,27 +5579,27 @@ def test_flatten_broadcastable():
# that of the input
inp = TensorType('float64', (False, False, False, False))()
out = flatten(inp, outdim=2)
out = flatten(inp, ndim=2)
assert out.broadcastable == (False, False)
inp = TensorType('float64', (False, False, False, True))()
out = flatten(inp, outdim=2)
out = flatten(inp, ndim=2)
assert out.broadcastable == (False, False)
inp = TensorType('float64', (False, True, False, True))()
out = flatten(inp, outdim=2)
out = flatten(inp, ndim=2)
assert out.broadcastable == (False, False)
inp = TensorType('float64', (False, True, True, True))()
out = flatten(inp, outdim=2)
out = flatten(inp, ndim=2)
assert out.broadcastable == (False, True)
inp = TensorType('float64', (True, False, True, True))()
out = flatten(inp, outdim=3)
out = flatten(inp, ndim=3)
assert out.broadcastable == (True, False, True)
def test_flatten_outdim_invalid():
def test_flatten_ndim_invalid():
    # flatten must reject an ndim outside the valid range [1, x.ndim]
    # for a matrix: 3 is too large, 0 is below the minimum of 1.
    mat = dmatrix()
    for bad_ndim in (3, 0):
        assert_raises(ValueError, flatten, mat, bad_ndim)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论