提交 686ed878 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Cleanup sort Op dispatchers

上级 1f7a2686
......@@ -9,6 +9,15 @@ from pytensor.tensor.sort import ArgSortOp, SortOp
@numba_funcify.register(SortOp)
def numba_funcify_SortOp(op, node, **kwargs):
if op.kind != "quicksort":
warnings.warn(
(
f'Numba function sort doesn\'t support kind="{op.kind}"'
" switching to `quicksort`."
),
UserWarning,
)
@numba_njit
def sort_f(a, axis):
axis = axis.item()
......@@ -19,41 +28,11 @@ def numba_funcify_SortOp(op, node, **kwargs):
return a_sorted_swapped
if op.kind != "quicksort":
warnings.warn(
(
f'Numba function sort doesn\'t support kind="{op.kind}"'
" switching to `quicksort`."
),
UserWarning,
)
return sort_f
@numba_funcify.register(ArgSortOp)
def numba_funcify_ArgSortOp(op, node, **kwargs):
def argsort_f_kind(kind):
@numba_njit
def argort_vec(X, axis):
axis = axis.item()
Y = np.swapaxes(X, axis, 0)
result = np.empty_like(Y, dtype="int64")
indices = list(np.ndindex(Y.shape[1:]))
for idx in indices:
result[(slice(None), *idx)] = np.argsort(
Y[(slice(None), *idx)], kind=kind
)
result = np.swapaxes(result, 0, axis)
return result
return argort_vec
kind = op.kind
if kind not in ["quicksort", "mergesort"]:
......@@ -66,4 +45,19 @@ def numba_funcify_ArgSortOp(op, node, **kwargs):
UserWarning,
)
return argsort_f_kind(kind)
@numba_njit
def argort_f(X, axis):
axis = axis.item()
Y = np.swapaxes(X, axis, 0)
result = np.empty_like(Y, dtype="int64")
indices = list(np.ndindex(Y.shape[1:]))
for idx in indices:
result[(slice(None), *idx)] = np.argsort(Y[(slice(None), *idx)], kind=kind)
result = np.swapaxes(result, 0, axis)
return result
return argort_f
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论