提交 2948525f authored 作者: jessegrabowski's avatar jessegrabowski 提交者: Ricardo Vieira

Ignore `axis` argument in numba `CumOp` when input is 1d

上级 ab5037e9
......@@ -49,7 +49,7 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs):
@numba_basic.numba_njit
def cumop(x):
return np.cumsum(x, axis=axis)
return np.cumsum(x)
else:
......@@ -73,7 +73,7 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs):
@numba_basic.numba_njit
def cumop(x):
return np.cumprod(x, axis=axis)
return np.cumprod(x)
else:
......
......@@ -58,6 +58,17 @@ def test_Bartlett(val):
1,
"mul",
),
# Regression tests for https://github.com/pymc-devs/pytensor/issues/1689
(
(pt.vector(), np.arange(6, dtype=config.floatX)),
0,
"add",
),
(
(pt.vector(), np.arange(6, dtype=config.floatX)),
0,
"mul",
),
],
)
def test_CumOp(val, axis, mode):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论