提交 af1e9b37 authored 作者: Cesar Laurent's avatar Cesar Laurent

Added negative axis support.

上级 575c7a97
......@@ -3984,7 +3984,7 @@ def shape_padaxis(t, axis):
def stack(*tensors, **kwargs):
"""Insert the arguments as slices into a tensor of 1 rank greater.
The size in dimension 0 of the result will be equal to the number
The size in dimension `axis` of the result will be equal to the number
of tensors passed.
Note: The interface stack(*tensors) is deprecated, you should use
......@@ -4028,6 +4028,13 @@ def stack(*tensors, **kwargs):
raise Exception('tensors is empty. You should at least provide one'
' tensor to theano.tensor.stack(tensors, axis).')
ndim = tensors[0].ndim + 1
if not -ndim <= axis < ndim:
msg = 'axis {0} is out of bounds [-{1}, {1})'.format(axis, ndim)
raise IndexError(msg)
if axis < 0:
axis += ndim
# If all tensors are scalars of the same type, call make_vector.
# It makes the graph simpler, by not adding DimShuffles and Rebroadcasts
......
......@@ -3443,6 +3443,7 @@ class T_Join_and_Split(unittest.TestCase):
def test_stack_new_interface(self):
"""Test the new numpy-like interface: stack(tensors, axis=0)."""
# Testing against old interface
warnings.simplefilter('always', DeprecationWarning)
a = tensor.imatrix('a')
b = tensor.imatrix('b')
......@@ -3452,12 +3453,28 @@ class T_Join_and_Split(unittest.TestCase):
v1, v2 = f([[1, 2]], [[3, 4]])
self.assertTrue(v1.shape == v2.shape)
self.assertTrue(numpy.all(v1 == v2))
# Testing axis parameter
s3 = stack([a, b], 1)
f = function([a, b], s3, mode=self.mode)
v3 = f([[1, 2]], [[3, 4]])
v4 = numpy.array([[[1, 2], [3, 4]]])
self.assertTrue(v3.shape == v4.shape)
self.assertTrue(numpy.all(v3 == v4))
# Testing negative axis
s = stack([a, b], axis=-1)
f = function([a, b], s, mode=self.mode)
v1 = [[1, 2, 3], [4, 5, 6]]
v2 = [[7, 8, 9], [10, 11, 12]]
v = numpy.zeros((2, 3, 2))
v[:,:,0] = v1
v[:,:,1] = v2
out = f(v1, v2)
self.assertTrue(v.shape == out.shape)
self.assertTrue(numpy.all(v == out))
# Testing out-of-bounds axis
self.assertRaises(IndexError, stack, [a, b], 4)
self.assertRaises(IndexError, stack, [a, b], -4)
# Testing depreciation warning
with warnings.catch_warnings(record=True) as w:
s = stack(a, b)
assert len(w) == 1
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论