提交 ed86ecc8 authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: Adrian Seyboldt

Fix bug in numba impl of cumsum

上级 c89a48db
...@@ -43,6 +43,7 @@ def numba_funcify_CumOp(op, node, **kwargs): ...@@ -43,6 +43,7 @@ def numba_funcify_CumOp(op, node, **kwargs):
raise ValueError(f"Invalid axis {axis} for array with ndim {ndim}") raise ValueError(f"Invalid axis {axis} for array with ndim {ndim}")
reaxis_first = (axis,) + tuple(i for i in range(ndim) if i != axis) reaxis_first = (axis,) + tuple(i for i in range(ndim) if i != axis)
reaxis_first_inv = tuple(np.argsort(reaxis_first))
if mode == "add": if mode == "add":
...@@ -65,7 +66,7 @@ def numba_funcify_CumOp(op, node, **kwargs): ...@@ -65,7 +66,7 @@ def numba_funcify_CumOp(op, node, **kwargs):
for m in range(1, x.shape[axis]): for m in range(1, x.shape[axis]):
res[m] = res[m - 1] + x_axis_first[m] res[m] = res[m - 1] + x_axis_first[m]
return res.transpose(reaxis_first) return res.transpose(reaxis_first_inv)
else: else:
if ndim == 1: if ndim == 1:
......
...@@ -80,6 +80,13 @@ def test_BroadcastTo(x, shape): ...@@ -80,6 +80,13 @@ def test_BroadcastTo(x, shape):
1, 1,
"add", "add",
), ),
(
set_test_value(
at.dtensor3(), np.arange(30, dtype=config.floatX).reshape((2, 3, 5))
),
-1,
"add",
),
( (
set_test_value( set_test_value(
at.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2)) at.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论