提交 d6c185de authored 作者: abalkin's avatar abalkin

Fixed output rank for axis is not None case.

上级 0daa4491
......@@ -6833,7 +6833,14 @@ class Take(Op):
def make_node(self, a, indices):
a = as_tensor_variable(a)
indices = as_tensor_variable(indices)
return gof.Apply(self, (a, indices), [a.type()])
if self.axis is None:
broadcastable = [False]
else:
broadcastable = (a.broadcastable[:self.axis] +
indices.broadcastable +
a.broadcastable[self.axis+1:])
return gof.Apply(self, (a, indices),
[TensorType(a.dtype, broadcastable)()])
def perform(self, node, inputs, outputs):
a, indices = inputs
......@@ -6854,14 +6861,9 @@ def take(a, indices, axis=None, mode='raise'):
a = as_tensor_variable(a)
indices = as_tensor_variable(indices)
# Reuse advanced indexing in supported cases.
if axis is None:
if axis is None and mode == 'raise':
if indices.ndim == 1:
return a.flatten()[indices]
else:
if indices.ndim == 0:
item = [slice(None)] * a.ndim
item[axis] = indices
return a[tuple(item)]
return Take(axis, mode)(a, indices)
#########################
......
......@@ -7170,8 +7170,13 @@ class TestTensorInstanceMethods(unittest.TestCase):
def test_take(self):
X, _ = self.vars
x, _ = self.vals
assert_array_equal(X.take([1,0,3]).eval({X: x}), x.take([1,0,3]))
indices = [1,0,3]
assert_array_equal(X.take(indices).eval({X: x}), x.take(indices))
indices = [1,0,1]
assert_array_equal(X.take(indices, 1).eval({X: x}), x.take(indices, 1))
indices = [[1,0,1], [0,1,1]]
assert_array_equal(X.take(indices, 1).eval({X: x}), x.take(indices, 1))
if __name__ == '__main__':
t = TestInferShape('setUp')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论