提交 4134881f authored 作者: HarshvirSandhu's avatar HarshvirSandhu 提交者: Ricardo Vieira

Implement indexing operations in pytorch

上级 1a1c62bb
...@@ -471,6 +471,7 @@ PYTORCH = Mode( ...@@ -471,6 +471,7 @@ PYTORCH = Mode(
"BlasOpt", "BlasOpt",
"fusion", "fusion",
"inplace", "inplace",
"local_uint_constant_indices",
], ],
), ),
) )
......
...@@ -7,7 +7,8 @@ import pytensor.link.pytorch.dispatch.scalar ...@@ -7,7 +7,8 @@ import pytensor.link.pytorch.dispatch.scalar
import pytensor.link.pytorch.dispatch.elemwise import pytensor.link.pytorch.dispatch.elemwise
import pytensor.link.pytorch.dispatch.math import pytensor.link.pytorch.dispatch.math
import pytensor.link.pytorch.dispatch.extra_ops import pytensor.link.pytorch.dispatch.extra_ops
import pytensor.link.pytorch.dispatch.nlinalg
import pytensor.link.pytorch.dispatch.shape import pytensor.link.pytorch.dispatch.shape
import pytensor.link.pytorch.dispatch.sort import pytensor.link.pytorch.dispatch.sort
import pytensor.link.pytorch.dispatch.nlinalg import pytensor.link.pytorch.dispatch.subtensor
# isort: on # isort: on
from functools import singledispatch from functools import singledispatch
from types import NoneType from types import NoneType
import numpy as np
import torch import torch
from pytensor.compile.ops import DeepCopyOp from pytensor.compile.ops import DeepCopyOp
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.link.utils import fgraph_to_python from pytensor.link.utils import fgraph_to_python
from pytensor.raise_op import CheckAndRaise from pytensor.raise_op import CheckAndRaise
from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Eye, Join, MakeVector from pytensor.tensor.basic import (
Alloc,
AllocEmpty,
ARange,
Eye,
Join,
MakeVector,
TensorFromScalar,
)
@singledispatch @singledispatch
def pytorch_typify(data, dtype=None, **kwargs): def pytorch_typify(data, **kwargs):
r"""Convert instances of PyTensor `Type`\s to PyTorch types.""" raise NotImplementedError(f"pytorch_typify is not implemented for {type(data)}")
@pytorch_typify.register(np.ndarray)
@pytorch_typify.register(torch.Tensor)
def pytorch_typify_tensor(data, dtype=None, **kwargs):
return torch.as_tensor(data, dtype=dtype) return torch.as_tensor(data, dtype=dtype)
@pytorch_typify.register(slice)
@pytorch_typify.register(NoneType) @pytorch_typify.register(NoneType)
def pytorch_typify_None(data, **kwargs): @pytorch_typify.register(np.number)
return None def pytorch_typify_no_conversion_needed(data, **kwargs):
return data
@singledispatch @singledispatch
...@@ -132,3 +148,11 @@ def pytorch_funcify_MakeVector(op, **kwargs): ...@@ -132,3 +148,11 @@ def pytorch_funcify_MakeVector(op, **kwargs):
return torch.tensor(x, dtype=torch_dtype) return torch.tensor(x, dtype=torch_dtype)
return makevector return makevector
@pytorch_funcify.register(TensorFromScalar)
def pytorch_funcify_TensorFromScalar(op, **kwargs):
def tensorfromscalar(x):
return torch.as_tensor(x)
return tensorfromscalar
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
AdvancedSubtensor,
AdvancedSubtensor1,
IncSubtensor,
Subtensor,
indices_from_subtensor,
)
from pytensor.tensor.type_other import MakeSlice, SliceType
def check_negative_steps(indices):
    """Reject any ``slice`` in *indices* that has a negative step.

    PyTorch tensors do not support negative-step slicing, so we raise
    ``NotImplementedError`` eagerly instead of letting torch fail opaquely.
    Non-slice entries (integers, arrays, ...) are ignored.
    """
    for idx in indices:
        if not isinstance(idx, slice):
            continue
        step = idx.step
        if step is not None and step < 0:
            raise NotImplementedError(
                "Negative step sizes are not supported in Pytorch"
            )
@pytorch_funcify.register(Subtensor)
def pytorch_funcify_Subtensor(op, node, **kwargs):
    """Build a basic-indexing function for a `Subtensor` op."""
    # Static index pattern; runtime scalar inputs are slotted into it.
    static_idx_list = op.idx_list

    def subtensor(x, *flattened_indices):
        # Reassemble the concrete index tuple from the static pattern and
        # the flattened runtime index values, then index the tensor.
        idxs = indices_from_subtensor(flattened_indices, static_idx_list)
        check_negative_steps(idxs)
        return x[idxs]

    return subtensor
@pytorch_funcify.register(MakeSlice)
def pytorch_funcify_makeslice(op, **kwargs):
    """Build a function that assembles a ``slice`` from its component inputs.

    The runtime inputs ``x`` are the individual slice components
    (e.g. start, stop, step), so they must be unpacked into ``slice``.
    """

    def makeslice(*x):
        # BUG FIX: the original returned ``slice(x)``, which builds
        # ``slice(stop=x)`` with the *whole tuple* as the stop value.
        # Unpack the components instead.
        return slice(*x)

    return makeslice
@pytorch_funcify.register(AdvancedSubtensor1)
@pytorch_funcify.register(AdvancedSubtensor)
def pytorch_funcify_AdvSubtensor(op, node, **kwargs):
    """Build an advanced-indexing read function.

    Advanced indices arrive as runtime inputs, so they are simply packed
    into a tuple and applied to the tensor after the negative-step check.
    """

    def advsubtensor(x, *ilist):
        check_negative_steps(ilist)
        return x[ilist]

    return advsubtensor
@pytorch_funcify.register(IncSubtensor)
def pytorch_funcify_IncSubtensor(op, node, **kwargs):
    """Build a set/increment function for basic-indexed assignment.

    ``op.set_instead_of_inc`` selects between ``x[idx] = y`` and
    ``x[idx] += y``; unless ``op.inplace`` the input is cloned first so the
    caller's buffer is never mutated.
    """
    static_idx_list = op.idx_list
    do_inplace = op.inplace
    overwrite = op.set_instead_of_inc

    def incsubtensor(x, y, *flattened_indices):
        idxs = indices_from_subtensor(flattened_indices, static_idx_list)
        check_negative_steps(idxs)
        if not do_inplace:
            # Preserve the input tensor; operate on a copy.
            x = x.clone()
        if overwrite:
            x[idxs] = y
        else:
            x[idxs] += y
        return x

    return incsubtensor
@pytorch_funcify.register(AdvancedIncSubtensor)
@pytorch_funcify.register(AdvancedIncSubtensor1)
def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
    """Build a set/increment function for advanced-indexed assignment.

    Three variants, chosen at funcify time:

    * set: plain ``x[indices] = y``;
    * increment ignoring duplicate indices: plain ``x[indices] += y``;
    * increment accumulating duplicates: ``Tensor.index_put_`` with
      ``accumulate=True`` (the ``np.add.at`` equivalent), which does not
      support slice indices — those are rejected up front.
    """
    do_inplace = op.inplace
    # AdvancedIncSubtensor1 has no `ignore_duplicates` attribute.
    ignore_dups = getattr(op, "ignore_duplicates", False)

    if op.set_instead_of_inc:

        def adv_set_subtensor(x, y, *indices):
            check_negative_steps(indices)
            out = x if do_inplace else x.clone()
            out[indices] = y.type_as(out)
            return out

        return adv_set_subtensor

    if ignore_dups:

        def adv_inc_subtensor_no_duplicates(x, y, *indices):
            check_negative_steps(indices)
            out = x if do_inplace else x.clone()
            out[indices] += y.type_as(out)
            return out

        return adv_inc_subtensor_no_duplicates

    # Accumulating duplicates: index_put_ cannot take slice indices.
    if any(isinstance(idx.type, SliceType) for idx in node.inputs[2:]):
        raise NotImplementedError(
            "IncSubtensor with potential duplicates indexes and slice indexing not implemented in PyTorch"
        )

    def adv_inc_subtensor(x, y, *indices):
        # No negative-step check needed: slices were rejected above.
        out = x if do_inplace else x.clone()
        out.index_put_(indices, y.type_as(out), accumulate=True)
        return out

    return adv_inc_subtensor
...@@ -66,10 +66,10 @@ def compare_pytorch_and_py( ...@@ -66,10 +66,10 @@ def compare_pytorch_and_py(
py_res = pytensor_py_fn(*test_inputs) py_res = pytensor_py_fn(*test_inputs)
if len(fgraph.outputs) > 1: if len(fgraph.outputs) > 1:
for j, p in zip(pytorch_res, py_res): for pytorch_res_i, py_res_i in zip(pytorch_res, py_res):
assert_fn(j.cpu(), p) assert_fn(pytorch_res_i.detach().cpu().numpy(), py_res_i)
else: else:
assert_fn([pytorch_res[0].cpu()], py_res) assert_fn(pytorch_res[0].detach().cpu().numpy(), py_res[0])
return pytensor_torch_fn, pytorch_res return pytensor_torch_fn, pytorch_res
......
import contextlib
import numpy as np
import pytest
import pytensor.scalar as ps
import pytensor.tensor as pt
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import inc_subtensor, set_subtensor
from pytensor.tensor import subtensor as pt_subtensor
from tests.link.pytorch.test_basic import compare_pytorch_and_py
def test_pytorch_Subtensor():
    """Basic (static and symbolic) indexing agrees with the Python backend."""
    shape = (3, 4, 5)
    x_pt = pt.tensor("x", shape=shape, dtype="int")
    x_np = np.arange(np.prod(shape)).reshape(shape)
    # Single-element indexing.
    out_pt = x_pt[1, 2, 0]
    assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
    out_fg = FunctionGraph([x_pt], [out_pt])
    compare_pytorch_and_py(out_fg, [x_np])
    # Open-ended slice.
    out_pt = x_pt[1:, 1, :]
    assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
    out_fg = FunctionGraph([x_pt], [out_pt])
    compare_pytorch_and_py(out_fg, [x_np])
    # Stop-only slice.
    out_pt = x_pt[:2, 1, :]
    assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
    out_fg = FunctionGraph([x_pt], [out_pt])
    compare_pytorch_and_py(out_fg, [x_np])
    # Start-and-stop slice.
    out_pt = x_pt[1:2, 1, :]
    assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
    out_fg = FunctionGraph([x_pt], [out_pt])
    compare_pytorch_and_py(out_fg, [x_np])
    # symbolic index
    a_pt = ps.int64("a")
    a_np = 1
    out_pt = x_pt[a_pt, 2, a_pt:2]
    assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
    out_fg = FunctionGraph([x_pt, a_pt], [out_pt])
    compare_pytorch_and_py(out_fg, [x_np, a_np])
    # Negative-step slicing is unsupported by the PyTorch backend.
    with pytest.raises(
        NotImplementedError, match="Negative step sizes are not supported in Pytorch"
    ):
        out_pt = x_pt[::-1]
        out_fg = FunctionGraph([x_pt], [out_pt])
        compare_pytorch_and_py(out_fg, [x_np])
def test_pytorch_AdvSubtensor():
    """Advanced (array/boolean/mixed) indexing agrees with the Python backend."""
    shape = (3, 4, 5)
    x_pt = pt.tensor("x", shape=shape, dtype="int")
    x_np = np.arange(np.prod(shape)).reshape(shape)
    # Integer-vector indexing along the first axis (AdvancedSubtensor1).
    out_pt = pt_subtensor.advanced_subtensor1(x_pt, [1, 2])
    assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor1)
    out_fg = FunctionGraph([x_pt], [out_pt])
    compare_pytorch_and_py(out_fg, [x_np])
    # Paired integer-array indices across two axes.
    out_pt = x_pt[[1, 2], [2, 3]]
    assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
    out_fg = FunctionGraph([x_pt], [out_pt])
    compare_pytorch_and_py(out_fg, [x_np])
    # Mixed advanced (array) and basic (slice) indexing.
    out_pt = x_pt[[1, 2], 1:]
    assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
    out_fg = FunctionGraph([x_pt], [out_pt])
    compare_pytorch_and_py(out_fg, [x_np])
    out_pt = x_pt[[1, 2], :, [3, 4]]
    assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
    out_fg = FunctionGraph([x_pt], [out_pt])
    compare_pytorch_and_py(out_fg, [x_np])
    # Advanced index combined with a new axis (None).
    out_pt = x_pt[[1, 2], None]
    out_fg = FunctionGraph([x_pt], [out_pt])
    compare_pytorch_and_py(out_fg, [x_np])
    # Symbolic scalar used both as an advanced and a basic index.
    a_pt = ps.int64("a")
    a_np = 2
    out_pt = x_pt[[1, a_pt], a_pt]
    out_fg = FunctionGraph([x_pt, a_pt], [out_pt])
    compare_pytorch_and_py(out_fg, [x_np, a_np])
    # boolean indices
    out_pt = x_pt[np.random.binomial(1, 0.5, size=(3, 4, 5)).astype(bool)]
    out_fg = FunctionGraph([x_pt], [out_pt])
    compare_pytorch_and_py(out_fg, [x_np])
    # Symbolic boolean mask.
    a_pt = pt.tensor3("a", dtype="bool")
    a_np = np.random.binomial(1, 0.5, size=(3, 4, 5)).astype(bool)
    out_pt = x_pt[a_pt]
    out_fg = FunctionGraph([x_pt, a_pt], [out_pt])
    compare_pytorch_and_py(out_fg, [x_np, a_np])
    # Negative-step slicing is unsupported by the PyTorch backend.
    with pytest.raises(
        NotImplementedError, match="Negative step sizes are not supported in Pytorch"
    ):
        out_pt = x_pt[[1, 2], ::-1]
        out_fg = FunctionGraph([x_pt], [out_pt])
        assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
        compare_pytorch_and_py(out_fg, [x_np])
@pytest.mark.parametrize("subtensor_op", [set_subtensor, inc_subtensor])
def test_pytorch_IncSubtensor(subtensor_op):
    """Basic-indexed set/inc_subtensor agrees with the Python backend."""
    x = pt.tensor3("x")
    x_val = np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX)

    # Scalar update at a single position.
    update = pt.as_tensor_variable(np.array(-10.0, dtype=config.floatX))
    out = subtensor_op(x[1, 2, 3], update)
    assert isinstance(out.owner.op, pt_subtensor.IncSubtensor)
    fgraph = FunctionGraph([x], [out])
    compare_pytorch_and_py(fgraph, [x_val])

    # Test different type update
    update = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype("float32"))
    out = subtensor_op(x[:2, 0, 0], update)
    assert isinstance(out.owner.op, pt_subtensor.IncSubtensor)
    fgraph = FunctionGraph([x], [out])
    compare_pytorch_and_py(fgraph, [x_val])

    # Sliced region in a middle dimension.
    out = subtensor_op(x[0, 1:3, 0], update)
    assert isinstance(out.owner.op, pt_subtensor.IncSubtensor)
    fgraph = FunctionGraph([x], [out])
    compare_pytorch_and_py(fgraph, [x_val])
def inc_subtensor_ignore_duplicates(x, y):
    # Partial application of `inc_subtensor` with `ignore_duplicates=True`,
    # so it can be used as a parametrize case alongside set/inc_subtensor.
    return inc_subtensor(x, y, ignore_duplicates=True)
@pytest.mark.parametrize(
    "advsubtensor_op", [set_subtensor, inc_subtensor, inc_subtensor_ignore_duplicates]
)
def test_pytorch_AvdancedIncSubtensor(advsubtensor_op):
    # NOTE(review): "Avdanced" is a typo for "Advanced"; kept as-is since
    # pytest collects tests by this name.
    """Advanced-indexed set/inc_subtensor agrees with the Python backend."""
    rng = np.random.default_rng(42)
    x_pt = pt.tensor3("x")
    x_test = (np.arange(3 * 4 * 5) + 1).reshape((3, 4, 5)).astype(config.floatX)
    st_pt = pt.as_tensor_variable(
        rng.uniform(-1, 1, size=(2, 4, 5)).astype(config.floatX)
    )
    # Distinct advanced indices along the first axis.
    out_pt = advsubtensor_op(x_pt[np.r_[0, 2]], st_pt)
    assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
    out_fg = FunctionGraph([x_pt], [out_pt])
    compare_pytorch_and_py(out_fg, [x_test])
    # Repeated indices
    out_pt = advsubtensor_op(x_pt[np.r_[0, 0]], st_pt)
    assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
    out_fg = FunctionGraph([x_pt], [out_pt])
    compare_pytorch_and_py(out_fg, [x_test])
    # Mixing advanced and basic indexing
    if advsubtensor_op is inc_subtensor:
        # PyTorch does not support `np.add.at` equivalent with slices
        expectation = pytest.raises(NotImplementedError)
    else:
        expectation = contextlib.nullcontext()
    st_pt = pt.as_tensor_variable(x_test[[0, 2], 0, :3])
    out_pt = advsubtensor_op(x_pt[[0, 0], 0, :3], st_pt)
    assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
    out_fg = FunctionGraph([x_pt], [out_pt])
    with expectation:
        compare_pytorch_and_py(out_fg, [x_test])
    # Test different dtype update
    st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype("float32"))
    out_pt = advsubtensor_op(x_pt[[0, 2], 0, 0], st_pt)
    assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
    out_fg = FunctionGraph([x_pt], [out_pt])
    compare_pytorch_and_py(out_fg, [x_test])
    # Boolean indices
    out_pt = advsubtensor_op(x_pt[x_pt > 5], 1.0)
    assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
    out_fg = FunctionGraph([x_pt], [out_pt])
    compare_pytorch_and_py(out_fg, [x_test])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论