提交 4940bb33 authored 作者: Cesar Laurent's avatar Cesar Laurent

Numpy-like interface for stack.

上级 00f184d5
......@@ -583,6 +583,20 @@ dimensions, see :meth:`_tensor_py_operators.dimshuffle`.
:type n_ones: int
:type n_ones: number of dimension to be added to `x`
.. function:: shape_padaxis(t, axis)
Reshape `t` by adding 1 at the dimension `axis`. Note that this new
dimension will be broadcastable. To make it non-broadcastable
see the :func:`unbroadcast`.
:type x: any TensorVariable (or compatible)
:param x: variable to be reshaped
:type axis: int
:param axis: axis where to add the new dimension to `x`
.. autofunction:: unbroadcast(x, *axes)
.. autofunction:: addbroadcast(x, *axes)
......@@ -678,6 +692,26 @@ Creating Tensor
except for the main diagonal, whose values are equal to one. The output
will have same dtype as `x`.
.. function:: stack(tensors, axis=0)
Warning: The interface stack(*tensors) is deprecated!
Return a Tensor representing for the arguments all stacked up into a single Tensor.
(of 1 rank greater).
:param tensors: a list or a tuple of one or more tensors of the same rank.
:param axis: the axis along which the tensors will be stacked.
:returns: A tensor such that rval[0] == tensors[0], rval[1] == tensors[1], etc.
>>> x0 = T.scalar()
>>> x1 = T.scalar()
>>> x2 = T.scalar()
>>> x = T.stack(x0, x1, x2)
>>> x.ndim # x is a vector of length 3.
1
.. function:: stack(*tensors)
Return a Tensor representing for the arguments all stacked up into a single Tensor.
......
......@@ -3962,6 +3962,19 @@ def shape_padright(t, n_ones=1):
return DimShuffle(_t.broadcastable, pattern)(_t)
@constructor
def shape_padaxis(t, axis):
"""Reshape `t` by adding 1 at the dimension `axis`.
See also: `shape_padleft`, `shape_padright` and `Dimshuffle`
"""
_t = as_tensor_variable(t)
pattern = [i for i in xrange(_t.type.ndim)]
pattern.insert(axis, 'x')
return DimShuffle(_t.broadcastable, pattern)(_t)
@constructor
def stack(*tensors):
"""Insert the arguments as slices into a tensor of 1 rank greater.
......@@ -3969,10 +3982,34 @@ def stack(*tensors):
The size in dimension 0 of the result will be equal to the number
of tensors passed.
Note: The interface stack(*tensors) is deprecated, you should use
stack(tensors, axis=0) insted.
:Parameters:
- `tensors` : list or tuple of tensors
A list of tensors to be stacked.
- `axis` : int
The index of the new axis.
"""
if len(tensors) == 0:
raise Exception('theano.tensor.stack(*tensors) must have at least'
' one parameter')
# Remove this when moving to the new interface: stack(tensors, axis=0)
# New numpy-like interface:
if isinstance(tensors[0], (list, tuple)):
if len(tensors) == 1:
axis = 0
else:
axis = tensors[1]
tensors = tensors[0]
# Deprecated interface:
else:
warnings.warn('stack(*tensors) interface is deprecated, use'
' stack(tensors, axis=0) instead.', stacklevel=3)
axis = 0
# If all tensors are scalars of the same type, call make_vector.
# It makes the graph simpler, by not adding DimShuffles and Rebroadcasts
......@@ -3994,7 +4031,7 @@ def stack(*tensors):
tensors = list(map(as_tensor_variable, tensors))
dtype = scal.upcast(*[i.dtype for i in tensors])
return theano.tensor.opt.MakeVector(dtype)(*tensors)
return join(0, *[shape_padleft(t, 1) for t in tensors])
return join(axis, *[shape_padaxis(t, axis) for t in tensors])
@constructor
......
......@@ -3441,6 +3441,23 @@ class T_Join_and_Split(unittest.TestCase):
assert len([n for n in topo if isinstance(n, type(self.join_op))]) == 0
assert f.maker.fgraph.outputs[0].dtype == 'int64'
def test_stack_new_interface(self):
"""Test the new numpy-like interface: stack(tensors, axis=0)."""
a = tensor.imatrix('a')
b = tensor.imatrix('b')
s1 = stack(a, b)
s2 = stack([a, b])
f = function([a, b], [s1, s2], mode=self.mode)
v1, v2 = f([[1, 2]], [[3, 4]])
self.assertTrue(v1.shape == v2.shape)
self.assertTrue(numpy.all(v1 == v2))
s3 = stack([a, b], 1)
f = function([a, b], s3, mode=self.mode)
v3 = f([[1, 2]], [[3, 4]])
v4 = numpy.array([[[1, 2], [3, 4]]])
self.assertTrue(v3.shape == v4.shape)
self.assertTrue(numpy.all(v3 == v4))
def test_stack_hessian(self):
# Test the gradient of stack when used in hessian, see gh-1589
a = tensor.dvector('a')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论