提交 2948525f authored 作者: jessegrabowski's avatar jessegrabowski 提交者: Ricardo Vieira

Ignore `axis` argument in numba `CumOp` when input is 1d

上级 ab5037e9
...@@ -49,7 +49,7 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs): ...@@ -49,7 +49,7 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs):
@numba_basic.numba_njit @numba_basic.numba_njit
def cumop(x): def cumop(x):
return np.cumsum(x, axis=axis) return np.cumsum(x)
else: else:
...@@ -73,7 +73,7 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs): ...@@ -73,7 +73,7 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs):
@numba_basic.numba_njit @numba_basic.numba_njit
def cumop(x): def cumop(x):
return np.cumprod(x, axis=axis) return np.cumprod(x)
else: else:
......
...@@ -58,6 +58,17 @@ def test_Bartlett(val): ...@@ -58,6 +58,17 @@ def test_Bartlett(val):
1, 1,
"mul", "mul",
), ),
# Regression tests for https://github.com/pymc-devs/pytensor/issues/1689
(
(pt.vector(), np.arange(6, dtype=config.floatX)),
0,
"add",
),
(
(pt.vector(), np.arange(6, dtype=config.floatX)),
0,
"mul",
),
], ],
) )
def test_CumOp(val, axis, mode): def test_CumOp(val, axis, mode):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论