提交 686ed878 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Cleanup sort Op dispatchers

上级 1f7a2686
...@@ -9,6 +9,15 @@ from pytensor.tensor.sort import ArgSortOp, SortOp ...@@ -9,6 +9,15 @@ from pytensor.tensor.sort import ArgSortOp, SortOp
@numba_funcify.register(SortOp) @numba_funcify.register(SortOp)
def numba_funcify_SortOp(op, node, **kwargs): def numba_funcify_SortOp(op, node, **kwargs):
if op.kind != "quicksort":
warnings.warn(
(
f'Numba function sort doesn\'t support kind="{op.kind}"'
" switching to `quicksort`."
),
UserWarning,
)
@numba_njit @numba_njit
def sort_f(a, axis): def sort_f(a, axis):
axis = axis.item() axis = axis.item()
...@@ -19,41 +28,11 @@ def numba_funcify_SortOp(op, node, **kwargs): ...@@ -19,41 +28,11 @@ def numba_funcify_SortOp(op, node, **kwargs):
return a_sorted_swapped return a_sorted_swapped
if op.kind != "quicksort":
warnings.warn(
(
f'Numba function sort doesn\'t support kind="{op.kind}"'
" switching to `quicksort`."
),
UserWarning,
)
return sort_f return sort_f
@numba_funcify.register(ArgSortOp) @numba_funcify.register(ArgSortOp)
def numba_funcify_ArgSortOp(op, node, **kwargs): def numba_funcify_ArgSortOp(op, node, **kwargs):
def argsort_f_kind(kind):
@numba_njit
def argort_vec(X, axis):
axis = axis.item()
Y = np.swapaxes(X, axis, 0)
result = np.empty_like(Y, dtype="int64")
indices = list(np.ndindex(Y.shape[1:]))
for idx in indices:
result[(slice(None), *idx)] = np.argsort(
Y[(slice(None), *idx)], kind=kind
)
result = np.swapaxes(result, 0, axis)
return result
return argort_vec
kind = op.kind kind = op.kind
if kind not in ["quicksort", "mergesort"]: if kind not in ["quicksort", "mergesort"]:
...@@ -66,4 +45,19 @@ def numba_funcify_ArgSortOp(op, node, **kwargs): ...@@ -66,4 +45,19 @@ def numba_funcify_ArgSortOp(op, node, **kwargs):
UserWarning, UserWarning,
) )
return argsort_f_kind(kind) @numba_njit
def argort_f(X, axis):
axis = axis.item()
Y = np.swapaxes(X, axis, 0)
result = np.empty_like(Y, dtype="int64")
indices = list(np.ndindex(Y.shape[1:]))
for idx in indices:
result[(slice(None), *idx)] = np.argsort(Y[(slice(None), *idx)], kind=kind)
result = np.swapaxes(result, 0, axis)
return result
return argort_f
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论