提交 edb3d5d4 authored 作者: Cesar Laurent's avatar Cesar Laurent

Corrected axis verification.

上级 af1e9b37
...@@ -3975,6 +3975,13 @@ def shape_padaxis(t, axis): ...@@ -3975,6 +3975,13 @@ def shape_padaxis(t, axis):
""" """
_t = as_tensor_variable(t) _t = as_tensor_variable(t)
ndim = _t.ndim + 1
if not -ndim <= axis < ndim:
msg = 'axis {0} is out of bounds [-{1}, {1})'.format(axis, ndim)
raise IndexError(msg)
if axis < 0:
axis += ndim
pattern = [i for i in xrange(_t.type.ndim)] pattern = [i for i in xrange(_t.type.ndim)]
pattern.insert(axis, 'x') pattern.insert(axis, 'x')
return DimShuffle(_t.broadcastable, pattern)(_t) return DimShuffle(_t.broadcastable, pattern)(_t)
...@@ -4028,13 +4035,6 @@ def stack(*tensors, **kwargs): ...@@ -4028,13 +4035,6 @@ def stack(*tensors, **kwargs):
raise Exception('tensors is empty. You should at least provide one' raise Exception('tensors is empty. You should at least provide one'
' tensor to theano.tensor.stack(tensors, axis).') ' tensor to theano.tensor.stack(tensors, axis).')
ndim = tensors[0].ndim + 1
if not -ndim <= axis < ndim:
msg = 'axis {0} is out of bounds [-{1}, {1})'.format(axis, ndim)
raise IndexError(msg)
if axis < 0:
axis += ndim
# If all tensors are scalars of the same type, call make_vector. # If all tensors are scalars of the same type, call make_vector.
# It makes the graph simpler, by not adding DimShuffles and Rebroadcasts # It makes the graph simpler, by not adding DimShuffles and Rebroadcasts
......
...@@ -3461,13 +3461,13 @@ class T_Join_and_Split(unittest.TestCase): ...@@ -3461,13 +3461,13 @@ class T_Join_and_Split(unittest.TestCase):
self.assertTrue(v3.shape == v4.shape) self.assertTrue(v3.shape == v4.shape)
self.assertTrue(numpy.all(v3 == v4)) self.assertTrue(numpy.all(v3 == v4))
# Testing negative axis # Testing negative axis
s = stack([a, b], axis=-1) s = stack([a, b], axis=-2)
f = function([a, b], s, mode=self.mode) f = function([a, b], s, mode=self.mode)
v1 = [[1, 2, 3], [4, 5, 6]] v1 = [[1, 2, 3], [4, 5, 6]]
v2 = [[7, 8, 9], [10, 11, 12]] v2 = [[7, 8, 9], [10, 11, 12]]
v = numpy.zeros((2, 3, 2)) v = numpy.zeros((2, 2, 3))
v[:,:,0] = v1 v[:,0,:] = v1
v[:,:,1] = v2 v[:,1,:] = v2
out = f(v1, v2) out = f(v1, v2)
self.assertTrue(v.shape == out.shape) self.assertTrue(v.shape == out.shape)
self.assertTrue(numpy.all(v == out)) self.assertTrue(numpy.all(v == out))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论