提交 686ed878 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Cleanup sort Op dispatchers

上级 1f7a2686
......@@ -9,6 +9,15 @@ from pytensor.tensor.sort import ArgSortOp, SortOp
@numba_funcify.register(SortOp)
def numba_funcify_SortOp(op, node, **kwargs):
if op.kind != "quicksort":
warnings.warn(
(
f'Numba function sort doesn\'t support kind="{op.kind}"'
" switching to `quicksort`."
),
UserWarning,
)
@numba_njit
def sort_f(a, axis):
axis = axis.item()
......@@ -19,23 +28,25 @@ def numba_funcify_SortOp(op, node, **kwargs):
return a_sorted_swapped
if op.kind != "quicksort":
return sort_f
@numba_funcify.register(ArgSortOp)
def numba_funcify_ArgSortOp(op, node, **kwargs):
kind = op.kind
if kind not in ["quicksort", "mergesort"]:
kind = "quicksort"
warnings.warn(
(
f'Numba function sort doesn\'t support kind="{op.kind}"'
f'Numba function argsort doesn\'t support kind="{op.kind}"'
" switching to `quicksort`."
),
UserWarning,
)
return sort_f
@numba_funcify.register(ArgSortOp)
def numba_funcify_ArgSortOp(op, node, **kwargs):
def argsort_f_kind(kind):
@numba_njit
def argort_vec(X, axis):
def argort_f(X, axis):
axis = axis.item()
Y = np.swapaxes(X, axis, 0)
......@@ -44,26 +55,9 @@ def numba_funcify_ArgSortOp(op, node, **kwargs):
indices = list(np.ndindex(Y.shape[1:]))
for idx in indices:
result[(slice(None), *idx)] = np.argsort(
Y[(slice(None), *idx)], kind=kind
)
result[(slice(None), *idx)] = np.argsort(Y[(slice(None), *idx)], kind=kind)
result = np.swapaxes(result, 0, axis)
return result
return argort_vec
kind = op.kind
if kind not in ["quicksort", "mergesort"]:
kind = "quicksort"
warnings.warn(
(
f'Numba function argsort doesn\'t support kind="{op.kind}"'
" switching to `quicksort`."
),
UserWarning,
)
return argsort_f_kind(kind)
return argort_f
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论