提交 556778f0 authored 作者: Frederic Bastien's avatar Frederic Bastien

Also deprecate/rename outdim to ndim in fct is_flat

上级 adef67e2
......@@ -5073,7 +5073,7 @@ class Flatten(Op):
""" % locals()
def is_flat(var, outdim=1):
def is_flat(var, ndim=None, outdim=None):
"""
Verifies the dimensionality of the var is equal to
outdim. This method is usually called after flatten method on a
......@@ -5096,7 +5096,15 @@ def is_flat(var, outdim=1):
the comparison result of var's dim
and the expected outdim.
"""
return var.ndim == outdim
if outdim is None and ndim is None:
ndim = 1
elif outdim is not None and ndim is not None:
raise ValueError("You should only specify ndim")
elif outdim is not None:
warnings.warn(
"flatten outdim parameter is deprecated, use ndim instead.")
ndim = outdim
return var.ndim == ndim
def flatten(x, ndim=None, outdim=None):
......
......@@ -5613,25 +5613,25 @@ def test_is_flat():
# Constant variable
assert tensor.is_flat(tensor.as_tensor_variable(np.zeros((10))))
assert tensor.is_flat(tensor.as_tensor_variable(np.zeros((10, 10, 10))),
outdim=3)
ndim=3)
assert not tensor.is_flat(
tensor.as_tensor_variable(np.zeros((10, 10, 10))))
# Symbolic variable
assert tensor.is_flat(tensor.vector())
assert tensor.is_flat(tensor.tensor3(), outdim=3)
assert tensor.is_flat(tensor.tensor3(), ndim=3)
assert not tensor.is_flat(tensor.tensor3())
# Reshape with constant shape
X = tensor.tensor4()
assert tensor.is_flat(X.reshape((-1, )))
assert tensor.is_flat(X.reshape((10, 10, -1)), outdim=3)
assert tensor.is_flat(X.reshape((10, 10, -1)), ndim=3)
assert not tensor.is_flat(X.reshape((10, 10, -1)))
# Reshape with symbolic shape
X = tensor.tensor4()
assert tensor.is_flat(X.reshape((tensor.iscalar(), )))
assert tensor.is_flat(X.reshape((tensor.iscalar(), ) * 3), outdim=3)
assert tensor.is_flat(X.reshape((tensor.iscalar(), ) * 3), ndim=3)
assert not tensor.is_flat(X.reshape((tensor.iscalar(), ) * 3))
......
......@@ -6222,7 +6222,7 @@ def test_local_flatten_lift():
reshape_nodes = [n for n in topo if isinstance(n.op, tensor.Reshape)]
assert (len(reshape_nodes) == 1 and
tensor.is_flat(reshape_nodes[0].outputs[0], outdim=i))
tensor.is_flat(reshape_nodes[0].outputs[0], ndim=i))
assert isinstance(topo[-1].op, tensor.Elemwise)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论