提交 edb3d5d4 authored 作者: Cesar Laurent's avatar Cesar Laurent

Corrected axis verification.

上级 af1e9b37
......@@ -3975,6 +3975,13 @@ def shape_padaxis(t, axis):
"""
_t = as_tensor_variable(t)
ndim = _t.ndim + 1
if not -ndim <= axis < ndim:
msg = 'axis {0} is out of bounds [-{1}, {1})'.format(axis, ndim)
raise IndexError(msg)
if axis < 0:
axis += ndim
pattern = [i for i in xrange(_t.type.ndim)]
pattern.insert(axis, 'x')
return DimShuffle(_t.broadcastable, pattern)(_t)
......@@ -4028,13 +4035,6 @@ def stack(*tensors, **kwargs):
raise Exception('tensors is empty. You should at least provide one'
' tensor to theano.tensor.stack(tensors, axis).')
ndim = tensors[0].ndim + 1
if not -ndim <= axis < ndim:
msg = 'axis {0} is out of bounds [-{1}, {1})'.format(axis, ndim)
raise IndexError(msg)
if axis < 0:
axis += ndim
# If all tensors are scalars of the same type, call make_vector.
# It makes the graph simpler, by not adding DimShuffles and Rebroadcasts
......
......@@ -3461,13 +3461,13 @@ class T_Join_and_Split(unittest.TestCase):
self.assertTrue(v3.shape == v4.shape)
self.assertTrue(numpy.all(v3 == v4))
# Testing negative axis
s = stack([a, b], axis=-1)
s = stack([a, b], axis=-2)
f = function([a, b], s, mode=self.mode)
v1 = [[1, 2, 3], [4, 5, 6]]
v2 = [[7, 8, 9], [10, 11, 12]]
v = numpy.zeros((2, 3, 2))
v[:,:,0] = v1
v[:,:,1] = v2
v = numpy.zeros((2, 2, 3))
v[:,0,:] = v1
v[:,1,:] = v2
out = f(v1, v2)
self.assertTrue(v.shape == out.shape)
self.assertTrue(numpy.all(v == out))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论