提交 30b78913 authored 作者: abergeron's avatar abergeron 提交者: GitHub

Merge pull request #5974 from nouiz/outdim

outdim -> ndim leftover and the same to is_flat
......@@ -629,23 +629,23 @@ dimensions, see :meth:`_tensor_py_operators.dimshuffle`.
.. autofunction:: patternbroadcast(x, broadcastable)
.. function:: flatten(x, outdim=1)
.. function:: flatten(x, ndim=1)
Similar to :func:`reshape`, but the shape is inferred from the shape of `x`.
:param x: variable to be flattened
:type x: any TensorVariable (or compatible)
:type outdim: int
:param outdim: the number of dimensions in the returned variable
:type ndim: int
:param ndim: the number of dimensions in the returned variable
:rtype: variable with same dtype as `x` and `outdim` dimensions
:returns: variable with the same shape as `x` in the leading `outdim-1`
:rtype: variable with same dtype as `x` and `ndim` dimensions
:returns: variable with the same shape as `x` in the leading `ndim-1`
dimensions, but with all remaining dimensions of `x` collapsed into
the last dimension.
For example, if we flatten a tensor of shape (2, 3, 4, 5) with flatten(x,
outdim=2), then we'll have the same (2-1=1) leading dimensions (2,), and the
ndim=2), then we'll have the same (2-1=1) leading dimensions (2,), and the
remaining dimensions are collapsed. So the output in this example would
have shape (2, 60).
......
......@@ -749,7 +749,7 @@ class PushOutScanOutput(gof.Optimizer):
# dot is usually faster on two large matrices than
# a bunch of small ones
outer_dot_inputs[0] = theano.tensor.flatten(
outer_dot_inputs[0].dimshuffle(1, 0, 2), outdim=2)
outer_dot_inputs[0].dimshuffle(1, 0, 2), ndim=2)
shape_input1 = theano.tensor.shape(outer_dot_inputs[1])
outer_dot_inputs[1] =\
......
......@@ -5073,7 +5073,7 @@ class Flatten(Op):
""" % locals()
def is_flat(var, outdim=1):
def is_flat(var, ndim=None, outdim=None):
"""
Verifies the dimensionality of the var is equal to
outdim. This method is usually called after flatten method on a
......@@ -5096,7 +5096,15 @@ def is_flat(var, outdim=1):
the comparison result of var's dim
and the expected outdim.
"""
return var.ndim == outdim
if outdim is None and ndim is None:
ndim = 1
elif outdim is not None and ndim is not None:
raise ValueError("You should only specify ndim")
elif outdim is not None:
warnings.warn(
"flatten outdim parameter is deprecated, use ndim instead.")
ndim = outdim
return var.ndim == ndim
def flatten(x, ndim=None, outdim=None):
......
......@@ -105,8 +105,8 @@ def conv2d(input, filters, image_shape=None, filter_shape=None,
" warn.signal_conv2d_interface to False",
stacklevel=3)
output = tensor.flatten(output.T, outdim=2).T
output = tensor.flatten(output.T, ndim=2).T
elif input.ndim == 2 or filters.ndim == 2:
output = tensor.flatten(output.T, outdim=3).T
output = tensor.flatten(output.T, ndim=3).T
return output
......@@ -5613,25 +5613,25 @@ def test_is_flat():
# Constant variable
assert tensor.is_flat(tensor.as_tensor_variable(np.zeros((10))))
assert tensor.is_flat(tensor.as_tensor_variable(np.zeros((10, 10, 10))),
outdim=3)
ndim=3)
assert not tensor.is_flat(
tensor.as_tensor_variable(np.zeros((10, 10, 10))))
# Symbolic variable
assert tensor.is_flat(tensor.vector())
assert tensor.is_flat(tensor.tensor3(), outdim=3)
assert tensor.is_flat(tensor.tensor3(), ndim=3)
assert not tensor.is_flat(tensor.tensor3())
# Reshape with constant shape
X = tensor.tensor4()
assert tensor.is_flat(X.reshape((-1, )))
assert tensor.is_flat(X.reshape((10, 10, -1)), outdim=3)
assert tensor.is_flat(X.reshape((10, 10, -1)), ndim=3)
assert not tensor.is_flat(X.reshape((10, 10, -1)))
# Reshape with symbolic shape
X = tensor.tensor4()
assert tensor.is_flat(X.reshape((tensor.iscalar(), )))
assert tensor.is_flat(X.reshape((tensor.iscalar(), ) * 3), outdim=3)
assert tensor.is_flat(X.reshape((tensor.iscalar(), ) * 3), ndim=3)
assert not tensor.is_flat(X.reshape((tensor.iscalar(), ) * 3))
......
......@@ -6222,7 +6222,7 @@ def test_local_flatten_lift():
reshape_nodes = [n for n in topo if isinstance(n.op, tensor.Reshape)]
assert (len(reshape_nodes) == 1 and
tensor.is_flat(reshape_nodes[0].outputs[0], outdim=i))
tensor.is_flat(reshape_nodes[0].outputs[0], ndim=i))
assert isinstance(topo[-1].op, tensor.Elemwise)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论