提交 556778f0 authored 作者: Frederic Bastien's avatar Frederic Bastien

Also deprecate/rename outdim to ndim in fct is_flat

上级 adef67e2
...@@ -5073,7 +5073,7 @@ class Flatten(Op): ...@@ -5073,7 +5073,7 @@ class Flatten(Op):
""" % locals() """ % locals()
def is_flat(var, outdim=1): def is_flat(var, ndim=None, outdim=None):
""" """
Verifies the dimensionality of the var is equal to Verifies the dimensionality of the var is equal to
outdim. This method is usually called after flatten method on a outdim. This method is usually called after flatten method on a
...@@ -5096,7 +5096,15 @@ def is_flat(var, outdim=1): ...@@ -5096,7 +5096,15 @@ def is_flat(var, outdim=1):
the comparison result of var's dim the comparison result of var's dim
and the expected outdim. and the expected outdim.
""" """
return var.ndim == outdim if outdim is None and ndim is None:
ndim = 1
elif outdim is not None and ndim is not None:
raise ValueError("You should only specify ndim")
elif outdim is not None:
warnings.warn(
"flatten outdim parameter is deprecated, use ndim instead.")
ndim = outdim
return var.ndim == ndim
def flatten(x, ndim=None, outdim=None): def flatten(x, ndim=None, outdim=None):
......
...@@ -5613,25 +5613,25 @@ def test_is_flat(): ...@@ -5613,25 +5613,25 @@ def test_is_flat():
# Constant variable # Constant variable
assert tensor.is_flat(tensor.as_tensor_variable(np.zeros((10)))) assert tensor.is_flat(tensor.as_tensor_variable(np.zeros((10))))
assert tensor.is_flat(tensor.as_tensor_variable(np.zeros((10, 10, 10))), assert tensor.is_flat(tensor.as_tensor_variable(np.zeros((10, 10, 10))),
outdim=3) ndim=3)
assert not tensor.is_flat( assert not tensor.is_flat(
tensor.as_tensor_variable(np.zeros((10, 10, 10)))) tensor.as_tensor_variable(np.zeros((10, 10, 10))))
# Symbolic variable # Symbolic variable
assert tensor.is_flat(tensor.vector()) assert tensor.is_flat(tensor.vector())
assert tensor.is_flat(tensor.tensor3(), outdim=3) assert tensor.is_flat(tensor.tensor3(), ndim=3)
assert not tensor.is_flat(tensor.tensor3()) assert not tensor.is_flat(tensor.tensor3())
# Reshape with constant shape # Reshape with constant shape
X = tensor.tensor4() X = tensor.tensor4()
assert tensor.is_flat(X.reshape((-1, ))) assert tensor.is_flat(X.reshape((-1, )))
assert tensor.is_flat(X.reshape((10, 10, -1)), outdim=3) assert tensor.is_flat(X.reshape((10, 10, -1)), ndim=3)
assert not tensor.is_flat(X.reshape((10, 10, -1))) assert not tensor.is_flat(X.reshape((10, 10, -1)))
# Reshape with symbolic shape # Reshape with symbolic shape
X = tensor.tensor4() X = tensor.tensor4()
assert tensor.is_flat(X.reshape((tensor.iscalar(), ))) assert tensor.is_flat(X.reshape((tensor.iscalar(), )))
assert tensor.is_flat(X.reshape((tensor.iscalar(), ) * 3), outdim=3) assert tensor.is_flat(X.reshape((tensor.iscalar(), ) * 3), ndim=3)
assert not tensor.is_flat(X.reshape((tensor.iscalar(), ) * 3)) assert not tensor.is_flat(X.reshape((tensor.iscalar(), ) * 3))
......
...@@ -6222,7 +6222,7 @@ def test_local_flatten_lift(): ...@@ -6222,7 +6222,7 @@ def test_local_flatten_lift():
reshape_nodes = [n for n in topo if isinstance(n.op, tensor.Reshape)] reshape_nodes = [n for n in topo if isinstance(n.op, tensor.Reshape)]
assert (len(reshape_nodes) == 1 and assert (len(reshape_nodes) == 1 and
tensor.is_flat(reshape_nodes[0].outputs[0], outdim=i)) tensor.is_flat(reshape_nodes[0].outputs[0], ndim=i))
assert isinstance(topo[-1].op, tensor.Elemwise) assert isinstance(topo[-1].op, tensor.Elemwise)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论