提交 3a264877 authored 作者: Sina Honari's avatar Sina Honari

improving is_flat, adding a test and docstring

上级 072b031f
......@@ -3341,6 +3341,14 @@ class GpuFlatten(gof.HideC, tensor.Reshape, GpuOp):
def gpu_flatten(x, outdim=1):
"""
Implement flatten on the gpu.
:param x: the variable that should be reshaped.
:type x: theano.tensor.var.TensorVariable
:param outdim: the number of dimensions of the returned variable
:type outdim: int
:returns: the flattened variable with dimensionality of outdim
"""
x = as_cuda_ndarray_variable(x)
if outdim > 1:
......
......@@ -4659,32 +4659,34 @@ class Flatten(Op):
def is_flat(node, outdim=1):
    """
    Verifies whether the dimensionality of the variable `node`
    equals the expected `outdim`.

    This replaces the previous op-based check (inspecting whether
    the node was produced by a `Reshape` with a `MakeVector` or
    `TensorConstant` shape), which also ended in a broken
    ``raise NotImplemented()`` — `NotImplemented` is a constant,
    not an exception class, so that path raised a `TypeError`.
    Checking ``node.ndim`` directly covers constant, symbolic and
    reshaped variables alike.

    :param node: the theano variable on which the dimensionality
                 is checked.
    :type node: theano.tensor.var.TensorVariable
    :param outdim: the expected dimensionality of node.
    :type outdim: int
    :returns: the comparison result of node's dim
              and the expected outdim.
    """
    return node.ndim == outdim
def flatten(x, outdim=1):
"""
Reshapes the variable x by keeping
the first outdim-1 dimension(s) of x the same
the first outdim-1 dimension(s) of x the same,
and making the last dimension of x equal to
the multiplication of its remaining dimensions.
:param x: the theano variable that should be reshaped.
:type x: theano.tensor.var.TensorVariable
:param outdim: the number of dimensions of the returned variable
:type outdim: int
:returns: the flattened variable with dimensionality of outdim
"""
outdim = int(outdim)
# Any input variable can be flattened to have outdim of 1,
......
......@@ -5254,6 +5254,37 @@ def test_flatten_outdim_invalid():
pass
def test_is_flat():
    """
    Exercises is_flat on constant and symbolic variables, and on
    variables reshaped with both constant and symbolic target
    shapes, checking the given outdim.
    """
    # Constant variables: flat exactly when ndim matches outdim.
    const_1d = tensor.as_tensor_variable(numpy.zeros((10)))
    const_3d = tensor.as_tensor_variable(numpy.zeros((10, 10, 10)))
    assert tensor.is_flat(const_1d)
    assert tensor.is_flat(const_3d, outdim=3)
    assert not tensor.is_flat(const_3d)

    # Purely symbolic variables.
    assert tensor.is_flat(tensor.vector())
    assert tensor.is_flat(tensor.tensor3(), outdim=3)
    assert not tensor.is_flat(tensor.tensor3())

    # Variables reshaped with a constant target shape.
    var = tensor.tensor4()
    assert tensor.is_flat(var.reshape((-1, )))
    assert tensor.is_flat(var.reshape((10, 10, -1)), outdim=3)
    assert not tensor.is_flat(var.reshape((10, 10, -1)))

    # Variables reshaped with a symbolic target shape.
    var = tensor.tensor4()
    assert tensor.is_flat(var.reshape((tensor.iscalar(), )))
    assert tensor.is_flat(var.reshape((tensor.iscalar(), ) * 3), outdim=3)
    assert not tensor.is_flat(var.reshape((tensor.iscalar(), ) * 3))
def test_tile():
def run_tile(x, x_, reps, use_symbolic_reps):
if use_symbolic_reps:
......
......@@ -5890,7 +5890,7 @@ def test_local_flatten_lift():
topo = f.maker.fgraph.toposort()
shape_out_np = tuple(x_np.shape[:i-1])+(numpy.prod(x_np.shape[i-1:]),)
assert shape_out_np == out_np.shape
assert tensor.is_flat(topo[-2], outdim=i)
assert tensor.is_flat(topo[-2].outputs[0], outdim=i)
assert isinstance(topo[-1].op, tensor.Elemwise)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论