提交 c03cf9a6 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Numba Unique: align with Python implementation

上级 ea267946
...@@ -13,6 +13,7 @@ from pytensor.link.numba.dispatch.basic import ( ...@@ -13,6 +13,7 @@ from pytensor.link.numba.dispatch.basic import (
register_funcify_and_cache_key, register_funcify_and_cache_key,
register_funcify_default_op_cache_key, register_funcify_default_op_cache_key,
) )
from pytensor.npy_2_compat import old_np_unique
from pytensor.tensor import TensorVariable from pytensor.tensor import TensorVariable
from pytensor.tensor.extra_ops import ( from pytensor.tensor.extra_ops import (
Bartlett, Bartlett,
...@@ -241,10 +242,17 @@ def numba_funcify_Unique(op, node, **kwargs): ...@@ -241,10 +242,17 @@ def numba_funcify_Unique(op, node, **kwargs):
@numba_basic.numba_njit @numba_basic.numba_njit
def unique(x): def unique(x):
with numba.objmode(ret=ret_sig): with numba.objmode(ret=ret_sig):
ret = np.unique(x, return_index, return_inverse, return_counts, axis) ret = old_np_unique(
x,
return_index=return_index,
return_inverse=return_inverse,
return_counts=return_counts,
axis=axis,
)
return ret return ret
return unique cache_version = 1
return unique, cache_version
@register_funcify_and_cache_key(UnravelIndex) @register_funcify_and_cache_key(UnravelIndex)
......
...@@ -296,6 +296,14 @@ def test_Repeat(x, repeats, axis, exc): ...@@ -296,6 +296,14 @@ def test_Repeat(x, repeats, axis, exc):
True, True,
UserWarning, UserWarning,
), ),
(
(pt.lmatrix(), np.array([[1, 1], [1, 1], [2, 2]], dtype="int64")),
None,
True,
True,
True,
UserWarning,
),
], ],
) )
def test_Unique(x, axis, return_index, return_inverse, return_counts, exc): def test_Unique(x, axis, return_index, return_inverse, return_counts, exc):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论