提交 e8b1794f authored 作者: Frederic Bastien's avatar Frederic Bastien

make CAReduce work with negative axis.

上级 5a0d2fb5
......@@ -888,9 +888,22 @@ class CAReduce(Op):
axis = self.axis
if axis is None:
axis = range(len(input.type.broadcastable))
if any([a<0 for a in axis]):
axis2=[]
for a in self.axis:
if a<0:
axis2.append(a+input.type.ndim)
else:
axis2.append(a)
assert len(axis)==len(axis2)
axis = tuple(axis2)
op = self.__class__(self.scalar_op, axis)
else:
op = self
output = TensorType(dtype = self._output_dtype(input.type.dtype),
broadcastable = [x for i, x in enumerate(input.type.broadcastable) if i not in axis])()
return Apply(self, [input], [output])
return Apply(op, [input], [output])
def __getstate__(self):
d = copy(self.__dict__)
......
......@@ -159,8 +159,11 @@ class test_CAReduce(unittest.TestCase):
((5, 6), (0, 1)),
((5, 6), (0, )),
((5, 6), (1, )),
((5, 6), (-1, )),
((5, 6), (-2, )),
((5, 6), ()),
((2, 3, 4, 5), (0, 1, 3)),
((2, 3, 4, 5), (-2, -3)),
((5, 0), (0, )),
((5, 0), (1, )),
((), ())]:
......@@ -171,6 +174,15 @@ class test_CAReduce(unittest.TestCase):
xv = numpy.asarray(numpy.random.rand(*xsh))
zv = xv
numpy_raised = False
if len(tosum)>1 and any([a<0 for a in tosum]):
#In that case, we need to use the good order of axis in the reduction.
axis2 = []
for a in tosum:
if a<0: axis2.append(a+len(xsh))
else: axis2.append(a)
assert len(axis2)==len(tosum)
tosum = tuple(axis2)
if scalar_op == add:
for axis in reversed(sorted(tosum)):
zv = numpy.add.reduce(zv, axis)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论