提交 6c3e7578 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix numba implementation of CumOp when axis is None

上级 9a5deee0
import warnings
from typing import cast
import numba
import numpy as np
from pytensor import config
from pytensor.graph import Apply
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import get_numba_type, numba_funcify
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor import TensorVariable
from pytensor.tensor.extra_ops import (
Bartlett,
CumOp,
......@@ -30,21 +33,22 @@ def numba_funcify_Bartlett(op, **kwargs):
@numba_funcify.register(CumOp)
def numba_funcify_CumOp(op, node, **kwargs):
def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs):
axis = op.axis
mode = op.mode
ndim = node.outputs[0].ndim
ndim = cast(TensorVariable, node.outputs[0]).ndim
if axis < 0:
axis = ndim + axis
if axis < 0 or axis >= ndim:
raise ValueError(f"Invalid axis {axis} for array with ndim {ndim}")
if axis is not None:
if axis < 0:
axis = ndim + axis
if axis < 0 or axis >= ndim:
raise ValueError(f"Invalid axis {axis} for array with ndim {ndim}")
reaxis_first = (axis,) + tuple(i for i in range(ndim) if i != axis)
reaxis_first_inv = tuple(np.argsort(reaxis_first))
reaxis_first = (axis,) + tuple(i for i in range(ndim) if i != axis)
reaxis_first_inv = tuple(np.argsort(reaxis_first))
if mode == "add":
if ndim == 1:
if axis is None or ndim == 1:
@numba_basic.numba_njit(fastmath=config.numba__fastmath)
def cumop(x):
......@@ -68,7 +72,7 @@ def numba_funcify_CumOp(op, node, **kwargs):
return res.transpose(reaxis_first_inv)
else:
if ndim == 1:
if axis is None or ndim == 1:
@numba_basic.numba_njit(fastmath=config.numba__fastmath)
def cumop(x):
......
from collections.abc import Collection
from typing import Iterable, Set, Tuple, Union
from typing import Iterable, Optional, Set, Tuple, Union
import numpy as np
from numpy.core.multiarray import normalize_axis_index
......@@ -291,7 +291,7 @@ class CumOp(COp):
c_axis=int_t, mode=EnumList(("MODE_ADD", "add"), ("MODE_MUL", "mul"))
)
def __init__(self, axis=None, mode="add"):
def __init__(self, axis: Optional[int] = None, mode="add"):
if mode not in ("add", "mul"):
raise ValueError(f'{type(self).__name__}: Unknown mode "{mode}"')
self.axis = axis
......
......@@ -619,7 +619,7 @@ class _tensor_py_operators:
)
@property
def ndim(self):
def ndim(self) -> int:
"""The rank of this tensor."""
return self.type.ndim
......
......@@ -67,6 +67,13 @@ def test_Bartlett(val):
1,
"add",
),
(
set_test_value(
at.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))
),
None,
"add",
),
(
set_test_value(
at.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))
......@@ -81,6 +88,13 @@ def test_Bartlett(val):
1,
"mul",
),
(
set_test_value(
at.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))
),
None,
"mul",
),
],
)
def test_CumOp(val, axis, mode):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论