Unverified commit 4c784086, authored by Diego Sandoval, committed by GitHub

Implemented Repeat and Unique Ops in PyTorch (#890)

Parent commit: a6b9585e
import torch
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.extra_ops import CumOp
from pytensor.tensor.extra_ops import CumOp, Repeat, Unique
@pytorch_funcify.register(CumOp)
......@@ -21,3 +21,38 @@ def pytorch_funcify_Cumop(op, **kwargs):
return torch.cumprod(x, dim=dim)
return cumop
@pytorch_funcify.register(Repeat)
def pytorch_funcify_Repeat(op, **kwargs):
    """Build the PyTorch implementation of the ``Repeat`` op.

    Returns a callable mapping ``(x, repeats)`` to ``x`` with its elements
    repeated along ``op.axis``, mirroring ``np.repeat`` semantics.
    """
    dim = op.axis

    def repeat(x, repeats):
        # torch.repeat_interleave is the PyTorch analogue of np.repeat
        return torch.repeat_interleave(x, repeats, dim=dim)

    return repeat
@pytorch_funcify.register(Unique)
def pytorch_funcify_Unique(op, **kwargs):
    """Build the PyTorch implementation of the ``Unique`` op.

    Raises
    ------
    NotImplementedError
        If ``op.return_index`` is set, since PyTorch's ``unique`` does not
        expose the indices of first occurrences.
    """
    if op.return_index:
        # TODO: evaluate whether is worth implementing this param
        # (see https://github.com/pytorch/pytorch/issues/36748)
        raise NotImplementedError("return_index is not implemented for pytorch")

    # Bind the op's settings once; the closure only needs these values.
    dim = op.axis
    with_inverse = op.return_inverse
    with_counts = op.return_counts

    def unique(x):
        return torch.unique(
            x,
            sorted=True,
            return_inverse=with_inverse,
            return_counts=with_counts,
            dim=dim,
        )

    return unique
......@@ -41,3 +41,61 @@ def test_pytorch_CumOp(axis, dtype):
out = pt.cumprod(a, axis=axis)
fgraph = FunctionGraph([a], [out])
compare_pytorch_and_py(fgraph, [test_value])
@pytest.mark.parametrize(
    "axis, repeats",
    [
        (0, (1, 2, 3)),
        (1, (3, 3)),
        pytest.param(
            None,
            3,
            marks=pytest.mark.xfail(reason="Reshape not implemented"),
        ),
    ],
)
def test_pytorch_Repeat(axis, repeats):
    """Check that the PyTorch backend matches Python for ``pt.repeat``."""
    x = pt.matrix("a", dtype="float64")
    x_val = np.arange(6, dtype="float64").reshape((3, 2))

    graph = FunctionGraph([x], [pt.repeat(x, repeats, axis=axis)])
    compare_pytorch_and_py(graph, [x_val])
@pytest.mark.parametrize("axis", [None, 0, 1])
def test_pytorch_Unique_axis(axis):
    """Check that the PyTorch backend matches Python for ``pt.unique`` per axis."""
    x = pt.matrix("a", dtype="float64")
    x_val = np.array(
        [[1.0, 1.0, 2.0], [1.0, 1.0, 2.0], [3.0, 3.0, 0.0]], dtype="float64"
    )

    graph = FunctionGraph([x], [pt.unique(x, axis=axis)])
    compare_pytorch_and_py(graph, [x_val])
@pytest.mark.parametrize("return_inverse", [False, True])
@pytest.mark.parametrize("return_counts", [False, True])
@pytest.mark.parametrize(
    "return_index",
    (False, pytest.param(True, marks=pytest.mark.xfail(raises=NotImplementedError))),
)
def test_pytorch_Unique_params(return_index, return_inverse, return_counts):
    """Check the PyTorch backend against Python across ``pt.unique`` flag combos.

    ``return_index=True`` is expected to raise ``NotImplementedError`` on the
    PyTorch backend (xfail).
    """
    x = pt.matrix("a", dtype="float64")
    x_val = np.array(
        [[1.0, 1.0, 2.0], [1.0, 1.0, 2.0], [3.0, 3.0, 0.0]], dtype="float64"
    )

    outputs = pt.unique(
        x,
        return_index=return_index,
        return_inverse=return_inverse,
        return_counts=return_counts,
        axis=0,
    )
    # With extra returns requested, pt.unique yields a list; compare only the
    # unique values themselves.
    first = outputs[0] if isinstance(outputs, list) else outputs

    graph = FunctionGraph([x], [first])
    compare_pytorch_and_py(graph, [x_val])
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register or sign in to comment