提交 429ba6c9 作者: Adrian Seyboldt 提交者: Adrian Seyboldt

Fix numba impl of CumOp

上级 98be9c5f
...@@ -36,31 +36,57 @@ def numba_funcify_CumOp(op, node, **kwargs): ...@@ -36,31 +36,57 @@ def numba_funcify_CumOp(op, node, **kwargs):
mode = op.mode mode = op.mode
ndim = node.outputs[0].ndim ndim = node.outputs[0].ndim
if axis < 0:
axis = ndim + axis
if axis < 0 or axis >= ndim:
raise ValueError(f"Invalid axis {axis} for array with ndim {ndim}")
reaxis_first = (axis,) + tuple(i for i in range(ndim) if i != axis) reaxis_first = (axis,) + tuple(i for i in range(ndim) if i != axis)
if mode == "add": if mode == "add":
np_func = np.add
identity = 0 if ndim == 1:
@numba_basic.numba_njit(fastmath=config.numba__fastmath)
def cumop(x):
return np.cumsum(x)
else:
@numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
def cumop(x):
out_dtype = x.dtype
if x.shape[axis] < 2:
return x.astype(out_dtype)
x_axis_first = x.transpose(reaxis_first)
res = np.empty(x_axis_first.shape, dtype=out_dtype)
res[0] = x_axis_first[0]
for m in range(1, x.shape[axis]):
res[m] = res[m - 1] + x_axis_first[m]
return res.transpose(reaxis_first)
else: else:
np_func = np.multiply if ndim == 1:
identity = 1 @numba_basic.numba_njit(fastmath=config.numba__fastmath)
def cumop(x):
return np.cumprod(x)
@numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath) else:
def cumop(x): @numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
out_dtype = x.dtype def cumop(x):
if x.shape[axis] < 2: out_dtype = x.dtype
return x.astype(out_dtype) if x.shape[axis] < 2:
return x.astype(out_dtype)
x_axis_first = x.transpose(reaxis_first) x_axis_first = x.transpose(reaxis_first)
res = np.empty(x_axis_first.shape, dtype=out_dtype) res = np.empty(x_axis_first.shape, dtype=out_dtype)
for m in numba.prange(x.shape[axis]): res[0] = x_axis_first[0]
if m == 0: for m in range(1, x.shape[axis]):
np_func(identity, x_axis_first[m], res[m]) res[m] = res[m - 1] * x_axis_first[m]
else:
np_func(res[m - 1], x_axis_first[m], res[m])
return res.transpose(reaxis_first) return res.transpose(reaxis_first)
return cumop return cumop
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论