Unverified 提交 4ea96b27 authored 作者: Diego Sandoval's avatar Diego Sandoval 提交者: GitHub

Implemented Eye Op in PyTorch (#877)

上级 ca102983
......@@ -6,7 +6,7 @@ from pytensor.compile.ops import DeepCopyOp
from pytensor.graph.fg import FunctionGraph
from pytensor.link.utils import fgraph_to_python
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Join
from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Eye, Join
@singledispatch
......@@ -100,3 +100,19 @@ def pytorch_funcify_Join(op, **kwargs):
return torch.cat(tensors, dim=axis)
return join
@pytorch_funcify.register(Eye)
def pytorch_funcify_eye(op, **kwargs):
torch_dtype = getattr(torch, op.dtype)
def eye(N, M, k):
major, minor = (M, N) if k > 0 else (N, M)
k_abs = torch.abs(k)
zeros = torch.zeros(N, M, dtype=torch_dtype)
if k_abs < major:
l_ones = torch.min(major - k_abs, minor)
return zeros.diagonal_scatter(torch.ones(l_ones, dtype=torch_dtype), k)
return zeros
return eye
......@@ -13,7 +13,7 @@ from pytensor.graph.basic import Apply
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor import alloc, arange, as_tensor, empty
from pytensor.tensor import alloc, arange, as_tensor, empty, eye
from pytensor.tensor.type import matrix, scalar, vector
......@@ -275,3 +275,22 @@ def test_pytorch_Join():
np.c_[[5.0, 6.0]].astype(config.floatX),
],
)
@pytest.mark.parametrize(
"dtype",
["int64", config.floatX],
)
def test_eye(dtype):
N = scalar("N", dtype="int64")
M = scalar("M", dtype="int64")
k = scalar("k", dtype="int64")
out = eye(N, M, k, dtype=dtype)
fn = function([N, M, k], out, mode=pytorch_mode)
for _N in range(1, 6):
for _M in range(1, 6):
for _k in list(range(_M + 2)) + [-x for x in range(1, _N + 2)]:
np.testing.assert_array_equal(fn(_N, _M, _k), np.eye(_N, _M, _k))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论