Unverified 提交 ee4d4f71 authored 作者: Diego Sandoval's avatar Diego Sandoval 提交者: GitHub

Implemented Sort/Argsort Ops in PyTorch (#897)

上级 a99d0670
......@@ -5,4 +5,5 @@ from pytensor.link.pytorch.dispatch.basic import pytorch_funcify, pytorch_typify
import pytensor.link.pytorch.dispatch.scalar
import pytensor.link.pytorch.dispatch.elemwise
import pytensor.link.pytorch.dispatch.extra_ops
import pytensor.link.pytorch.dispatch.sort
# isort: on
import torch
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.sort import ArgSortOp, SortOp
@pytorch_funcify.register(SortOp)
def pytorch_funcify_Sort(op, **kwargs):
stable = op.kind == "stable"
def sort(arr, axis):
sorted, _ = torch.sort(arr, dim=axis, stable=stable)
return sorted
return sort
@pytorch_funcify.register(ArgSortOp)
def pytorch_funcify_ArgSort(op, **kwargs):
stable = op.kind == "stable"
def argsort(arr, axis):
return torch.argsort(arr, dim=axis, stable=stable)
return argsort
import numpy as np
import pytest
from pytensor.graph import FunctionGraph
from pytensor.tensor import matrix
from pytensor.tensor.sort import argsort, sort
from tests.link.pytorch.test_basic import compare_pytorch_and_py
@pytest.mark.parametrize("func", (sort, argsort))
@pytest.mark.parametrize(
"axis",
[
pytest.param(0),
pytest.param(1),
pytest.param(
None, marks=pytest.mark.xfail(reason="Reshape Op not implemented")
),
],
)
def test_sort(func, axis):
x = matrix("x", shape=(2, 2), dtype="float64")
out = func(x, axis=axis)
fgraph = FunctionGraph([x], [out])
arr = np.array([[1.0, 4.0], [5.0, 2.0]])
compare_pytorch_and_py(fgraph, [arr])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论