提交 5165df44 authored 作者: James Bergstra's avatar James Bergstra

tensor.Sum - standardized self.axis to be a sorted tuple

上级 6cbe47a2
......@@ -795,10 +795,14 @@ class CAReduce(Op):
if scalar_op.nin not in [-1, 2] or scalar_op.nout != 1:
raise NotImplementedError("CAReduce only supports binary functions with a single output.")
self.scalar_op = scalar_op
if isinstance(axis, int):
self.axis = [axis]
else:
if axis is None:
self.axis = axis
elif isinstance(axis, int):
self.axis = (axis,)
else:
self.axis = list(set(axis))
self.axis.sort()
self.axis = tuple(self.axis)
self.ufunc = numpy.frompyfunc(scalar_op.impl, 2, 1)
# CAReduce output views input when reducing scalars
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论