提交 30b78913 authored 作者: abergeron's avatar abergeron 提交者: GitHub

Merge pull request #5974 from nouiz/outdim

outdim -> ndim leftovers, and the same for is_flat
...@@ -629,23 +629,23 @@ dimensions, see :meth:`_tensor_py_operators.dimshuffle`. ...@@ -629,23 +629,23 @@ dimensions, see :meth:`_tensor_py_operators.dimshuffle`.
.. autofunction:: patternbroadcast(x, broadcastable) .. autofunction:: patternbroadcast(x, broadcastable)
.. function:: flatten(x, ndim=1)

    Similar to :func:`reshape`, but the shape is inferred from the shape of `x`.

    :param x: variable to be flattened
    :type x: any TensorVariable (or compatible)
    :type ndim: int
    :param ndim: the number of dimensions in the returned variable
    :rtype: variable with same dtype as `x` and `ndim` dimensions
    :returns: variable with the same shape as `x` in the leading `ndim-1`
        dimensions, but with all remaining dimensions of `x` collapsed into
        the last dimension.

    For example, if we flatten a tensor of shape (2, 3, 4, 5) with flatten(x,
    ndim=2), then we'll have the same (2-1=1) leading dimensions (2,), and the
    remaining dimensions are collapsed. So the output in this example would
    have shape (2, 60).
......
...@@ -749,7 +749,7 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -749,7 +749,7 @@ class PushOutScanOutput(gof.Optimizer):
# dot is usually faster on two large matrices than # dot is usually faster on two large matrices than
# a bunch of small ones # a bunch of small ones
outer_dot_inputs[0] = theano.tensor.flatten( outer_dot_inputs[0] = theano.tensor.flatten(
outer_dot_inputs[0].dimshuffle(1, 0, 2), outdim=2) outer_dot_inputs[0].dimshuffle(1, 0, 2), ndim=2)
shape_input1 = theano.tensor.shape(outer_dot_inputs[1]) shape_input1 = theano.tensor.shape(outer_dot_inputs[1])
outer_dot_inputs[1] =\ outer_dot_inputs[1] =\
......
...@@ -5073,7 +5073,7 @@ class Flatten(Op): ...@@ -5073,7 +5073,7 @@ class Flatten(Op):
""" % locals() """ % locals()
def is_flat(var, ndim=None, outdim=None):
    """
    Verifies the dimensionality of the var is equal to the expected
    number of dimensions. This method is usually called after the
    flatten method on a variable, to check that the result has the
    requested rank.

    Parameters
    ----------
    var : variable with an ``ndim`` attribute (e.g. TensorVariable)
        The variable whose dimensionality is checked.
    ndim : int, optional
        The expected number of dimensions. Defaults to 1 when neither
        `ndim` nor `outdim` is given.
    outdim : int, optional
        DEPRECATED alias of `ndim`; kept for backward compatibility.
        Specifying both `ndim` and `outdim` raises ``ValueError``.

    Returns
    -------
    bool
        The comparison result of var's number of dimensions
        and the expected ndim.
    """
    if outdim is None and ndim is None:
        # Neither given: default to checking for a flat (1-D) variable.
        ndim = 1
    elif outdim is not None and ndim is not None:
        raise ValueError("You should only specify ndim")
    elif outdim is not None:
        # Legacy keyword: warn, then treat it as ndim.
        warnings.warn(
            "flatten outdim parameter is deprecated, use ndim instead.")
        ndim = outdim
    return var.ndim == ndim
def flatten(x, ndim=None, outdim=None): def flatten(x, ndim=None, outdim=None):
......
...@@ -105,8 +105,8 @@ def conv2d(input, filters, image_shape=None, filter_shape=None, ...@@ -105,8 +105,8 @@ def conv2d(input, filters, image_shape=None, filter_shape=None,
" warn.signal_conv2d_interface to False", " warn.signal_conv2d_interface to False",
stacklevel=3) stacklevel=3)
output = tensor.flatten(output.T, outdim=2).T output = tensor.flatten(output.T, ndim=2).T
elif input.ndim == 2 or filters.ndim == 2: elif input.ndim == 2 or filters.ndim == 2:
output = tensor.flatten(output.T, outdim=3).T output = tensor.flatten(output.T, ndim=3).T
return output return output
...@@ -5613,25 +5613,25 @@ def test_is_flat(): ...@@ -5613,25 +5613,25 @@ def test_is_flat():
# Constant variable # Constant variable
assert tensor.is_flat(tensor.as_tensor_variable(np.zeros((10)))) assert tensor.is_flat(tensor.as_tensor_variable(np.zeros((10))))
assert tensor.is_flat(tensor.as_tensor_variable(np.zeros((10, 10, 10))), assert tensor.is_flat(tensor.as_tensor_variable(np.zeros((10, 10, 10))),
outdim=3) ndim=3)
assert not tensor.is_flat( assert not tensor.is_flat(
tensor.as_tensor_variable(np.zeros((10, 10, 10)))) tensor.as_tensor_variable(np.zeros((10, 10, 10))))
# Symbolic variable # Symbolic variable
assert tensor.is_flat(tensor.vector()) assert tensor.is_flat(tensor.vector())
assert tensor.is_flat(tensor.tensor3(), outdim=3) assert tensor.is_flat(tensor.tensor3(), ndim=3)
assert not tensor.is_flat(tensor.tensor3()) assert not tensor.is_flat(tensor.tensor3())
# Reshape with constant shape # Reshape with constant shape
X = tensor.tensor4() X = tensor.tensor4()
assert tensor.is_flat(X.reshape((-1, ))) assert tensor.is_flat(X.reshape((-1, )))
assert tensor.is_flat(X.reshape((10, 10, -1)), outdim=3) assert tensor.is_flat(X.reshape((10, 10, -1)), ndim=3)
assert not tensor.is_flat(X.reshape((10, 10, -1))) assert not tensor.is_flat(X.reshape((10, 10, -1)))
# Reshape with symbolic shape # Reshape with symbolic shape
X = tensor.tensor4() X = tensor.tensor4()
assert tensor.is_flat(X.reshape((tensor.iscalar(), ))) assert tensor.is_flat(X.reshape((tensor.iscalar(), )))
assert tensor.is_flat(X.reshape((tensor.iscalar(), ) * 3), outdim=3) assert tensor.is_flat(X.reshape((tensor.iscalar(), ) * 3), ndim=3)
assert not tensor.is_flat(X.reshape((tensor.iscalar(), ) * 3)) assert not tensor.is_flat(X.reshape((tensor.iscalar(), ) * 3))
......
...@@ -6222,7 +6222,7 @@ def test_local_flatten_lift(): ...@@ -6222,7 +6222,7 @@ def test_local_flatten_lift():
reshape_nodes = [n for n in topo if isinstance(n.op, tensor.Reshape)] reshape_nodes = [n for n in topo if isinstance(n.op, tensor.Reshape)]
assert (len(reshape_nodes) == 1 and assert (len(reshape_nodes) == 1 and
tensor.is_flat(reshape_nodes[0].outputs[0], outdim=i)) tensor.is_flat(reshape_nodes[0].outputs[0], ndim=i))
assert isinstance(topo[-1].op, tensor.Elemwise) assert isinstance(topo[-1].op, tensor.Elemwise)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论