Commit 4134881f authored by HarshvirSandhu, committed by Ricardo Vieira

Implement indexing operations in pytorch

Parent 1a1c62bb
pytensor/compile/mode.py
@@ -471,6 +471,7 @@ PYTORCH = Mode(
"BlasOpt",
"fusion",
"inplace",
"local_uint_constant_indices",
],
),
)
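
The newly excluded `local_uint_constant_indices` rewrite is the only change to the PYTORCH mode here, plausibly because torch supports far fewer unsigned integer dtypes than NumPy. A minimal sketch of exercising the new indexing path through this mode (assuming a PyTensor build with torch installed; the variable names are illustrative):

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
fn = pytensor.function([x], x[1:, 0], mode="PYTORCH")  # lowered via pytorch_funcify_Subtensor
print(fn(np.arange(12.0).reshape(3, 4)))  # expected: [4. 8.]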
pytensor/link/pytorch/dispatch/__init__.py
@@ -7,7 +7,8 @@ import pytensor.link.pytorch.dispatch.scalar
import pytensor.link.pytorch.dispatch.elemwise
import pytensor.link.pytorch.dispatch.math
import pytensor.link.pytorch.dispatch.extra_ops
-import pytensor.link.pytorch.dispatch.nlinalg
import pytensor.link.pytorch.dispatch.shape
import pytensor.link.pytorch.dispatch.sort
+import pytensor.link.pytorch.dispatch.nlinalg
+import pytensor.link.pytorch.dispatch.subtensor
# isort: on
pytensor/link/pytorch/dispatch/basic.py

from functools import singledispatch
from types import NoneType
import numpy as np
import torch
from pytensor.compile.ops import DeepCopyOp
from pytensor.graph.fg import FunctionGraph
from pytensor.link.utils import fgraph_to_python
from pytensor.raise_op import CheckAndRaise
-from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Eye, Join, MakeVector
+from pytensor.tensor.basic import (
+    Alloc,
+    AllocEmpty,
+    ARange,
+    Eye,
+    Join,
+    MakeVector,
+    TensorFromScalar,
+)
@singledispatch
-def pytorch_typify(data, dtype=None, **kwargs):
-    r"""Convert instances of PyTensor `Type`\s to PyTorch types."""
+def pytorch_typify(data, **kwargs):
+    raise NotImplementedError(f"pytorch_typify is not implemented for {type(data)}")
+
+
+@pytorch_typify.register(np.ndarray)
+@pytorch_typify.register(torch.Tensor)
+def pytorch_typify_tensor(data, dtype=None, **kwargs):
+    return torch.as_tensor(data, dtype=dtype)
+
+
+@pytorch_typify.register(slice)
+@pytorch_typify.register(NoneType)
+def pytorch_typify_None(data, **kwargs):
+    return None
+
+
+@pytorch_typify.register(np.number)
+def pytorch_typify_no_conversion_needed(data, **kwargs):
+    return data
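
# Illustrative sketch (not part of the commit): how the dispatch above behaves.
_t = pytorch_typify(np.arange(3), dtype=torch.int64)
assert isinstance(_t, torch.Tensor) and _t.dtype == torch.int64
assert pytorch_typify(None) is None  # NoneType and slice both map to None
assert pytorch_typify(np.float64(1.5)) == 1.5  # np.number passes through unchanged
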
@singledispatch
@@ -132,3 +148,11 @@ def pytorch_funcify_MakeVector(op, **kwargs):
return torch.tensor(x, dtype=torch_dtype)
return makevector
+
+
+@pytorch_funcify.register(TensorFromScalar)
+def pytorch_funcify_TensorFromScalar(op, **kwargs):
+    def tensorfromscalar(x):
+        return torch.as_tensor(x)
+
+    return tensorfromscalar
pytensor/link/pytorch/dispatch/subtensor.py (new file)

from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
AdvancedSubtensor,
AdvancedSubtensor1,
IncSubtensor,
Subtensor,
indices_from_subtensor,
)
from pytensor.tensor.type_other import MakeSlice, SliceType
def check_negative_steps(indices):
for index in indices:
if isinstance(index, slice):
if index.step is not None and index.step < 0:
raise NotImplementedError(
"Negative step sizes are not supported in Pytorch"
)
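
# Illustrative sketch (not part of the commit): the guard accepts forward
# slices and plain integer indices, and rejects any reversed slice.
check_negative_steps((slice(0, 2), 3))  # fine
try:
    check_negative_steps((slice(None, None, -1),))  # i.e. x[::-1]
except NotImplementedError as err:
    print(err)  # Negative step sizes are not supported in PyTorch
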
@pytorch_funcify.register(Subtensor)
def pytorch_funcify_Subtensor(op, node, **kwargs):
idx_list = op.idx_list
def subtensor(x, *flattened_indices):
indices = indices_from_subtensor(flattened_indices, idx_list)
check_negative_steps(indices)
return x[indices]
return subtensor
@pytorch_funcify.register(MakeSlice)
def pytorch_funcify_makeslice(op, **kwargs):
    def makeslice(*x):
        # Build a real slice from the (start, stop, step) inputs; without the
        # unpacking, `slice(x)` would yield slice(None, (start, stop, step), None).
        return slice(*x)

    return makeslice
@pytorch_funcify.register(AdvancedSubtensor1)
@pytorch_funcify.register(AdvancedSubtensor)
def pytorch_funcify_AdvSubtensor(op, node, **kwargs):
def advsubtensor(x, *indices):
check_negative_steps(indices)
return x[indices]
return advsubtensor
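
# Illustrative sketch (not part of the commit): plain `x[indices]` suffices
# above because torch mirrors NumPy's advanced-indexing semantics for integer
# and boolean index arrays. Imports repeated here so the sketch is self-contained.
import numpy as np
import torch

_x_np = np.arange(60).reshape(3, 4, 5)
_x_pt = torch.as_tensor(_x_np)
_idx = ([1, 2], slice(None), [3, 4])  # mixed advanced and basic indexing
assert np.array_equal(_x_pt[_idx].numpy(), _x_np[_idx])
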
@pytorch_funcify.register(IncSubtensor)
def pytorch_funcify_IncSubtensor(op, node, **kwargs):
idx_list = op.idx_list
inplace = op.inplace
if op.set_instead_of_inc:
def set_subtensor(x, y, *flattened_indices):
indices = indices_from_subtensor(flattened_indices, idx_list)
check_negative_steps(indices)
if not inplace:
x = x.clone()
x[indices] = y
return x
return set_subtensor
else:
def inc_subtensor(x, y, *flattened_indices):
indices = indices_from_subtensor(flattened_indices, idx_list)
check_negative_steps(indices)
if not inplace:
x = x.clone()
x[indices] += y
return x
return inc_subtensor
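
# Illustrative sketch (not part of the commit): the two branches above mirror
# NumPy's `x[idx] = y` versus `x[idx] += y` update semantics.
import torch

_x = torch.ones(5)
_set = _x.clone()
_set[1:3] = 7.0   # set_instead_of_inc=True  -> tensor([1., 7., 7., 1., 1.])
_inc = _x.clone()
_inc[1:3] += 7.0  # set_instead_of_inc=False -> tensor([1., 8., 8., 1., 1.])
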
@pytorch_funcify.register(AdvancedIncSubtensor)
@pytorch_funcify.register(AdvancedIncSubtensor1)
def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
inplace = op.inplace
ignore_duplicates = getattr(op, "ignore_duplicates", False)
if op.set_instead_of_inc:
def adv_set_subtensor(x, y, *indices):
check_negative_steps(indices)
if not inplace:
x = x.clone()
x[indices] = y.type_as(x)
return x
return adv_set_subtensor
elif ignore_duplicates:
def adv_inc_subtensor_no_duplicates(x, y, *indices):
check_negative_steps(indices)
if not inplace:
x = x.clone()
x[indices] += y.type_as(x)
return x
return adv_inc_subtensor_no_duplicates
else:
if any(isinstance(idx.type, SliceType) for idx in node.inputs[2:]):
raise NotImplementedError(
"IncSubtensor with potential duplicates indexes and slice indexing not implemented in PyTorch"
)
def adv_inc_subtensor(x, y, *indices):
# Not needed because slices aren't supported
# check_negative_steps(indices)
if not inplace:
x = x.clone()
x.index_put_(indices, y.type_as(x), accumulate=True)
return x
return adv_inc_subtensor
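
# Illustrative sketch (not part of the commit): `index_put_` with
# accumulate=True is the torch analogue of `np.add.at`, so repeated indices
# all contribute, whereas a naive `x[idx] += y` applies each duplicate once.
import torch

_idx = (torch.tensor([0, 0, 2]),)
_y = torch.ones(3)
_acc = torch.zeros(3).index_put_(_idx, _y, accumulate=True)  # tensor([2., 0., 1.])
_naive = torch.zeros(3)
_naive[_idx] += _y  # _naive becomes tensor([1., 0., 1.])
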
tests/link/pytorch/test_basic.py
@@ -66,10 +66,10 @@ def compare_pytorch_and_py(
    py_res = pytensor_py_fn(*test_inputs)

    if len(fgraph.outputs) > 1:
-        for j, p in zip(pytorch_res, py_res):
-            assert_fn(j.cpu(), p)
+        for pytorch_res_i, py_res_i in zip(pytorch_res, py_res):
+            assert_fn(pytorch_res_i.detach().cpu().numpy(), py_res_i)
    else:
-        assert_fn([pytorch_res[0].cpu()], py_res)
+        assert_fn(pytorch_res[0].detach().cpu().numpy(), py_res[0])

    return pytensor_torch_fn, pytorch_res
tests/link/pytorch/test_subtensor.py (new file)

import contextlib
import numpy as np
import pytest
import pytensor.scalar as ps
import pytensor.tensor as pt
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import inc_subtensor, set_subtensor
from pytensor.tensor import subtensor as pt_subtensor
from tests.link.pytorch.test_basic import compare_pytorch_and_py
def test_pytorch_Subtensor():
shape = (3, 4, 5)
x_pt = pt.tensor("x", shape=shape, dtype="int")
x_np = np.arange(np.prod(shape)).reshape(shape)
out_pt = x_pt[1, 2, 0]
assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_pytorch_and_py(out_fg, [x_np])
out_pt = x_pt[1:, 1, :]
assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_pytorch_and_py(out_fg, [x_np])
out_pt = x_pt[:2, 1, :]
assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_pytorch_and_py(out_fg, [x_np])
out_pt = x_pt[1:2, 1, :]
assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_pytorch_and_py(out_fg, [x_np])
# symbolic index
a_pt = ps.int64("a")
a_np = 1
out_pt = x_pt[a_pt, 2, a_pt:2]
assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
out_fg = FunctionGraph([x_pt, a_pt], [out_pt])
compare_pytorch_and_py(out_fg, [x_np, a_np])
with pytest.raises(
        NotImplementedError, match="Negative step sizes are not supported in PyTorch"
):
out_pt = x_pt[::-1]
out_fg = FunctionGraph([x_pt], [out_pt])
compare_pytorch_and_py(out_fg, [x_np])
def test_pytorch_AdvSubtensor():
shape = (3, 4, 5)
x_pt = pt.tensor("x", shape=shape, dtype="int")
x_np = np.arange(np.prod(shape)).reshape(shape)
out_pt = pt_subtensor.advanced_subtensor1(x_pt, [1, 2])
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor1)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_pytorch_and_py(out_fg, [x_np])
out_pt = x_pt[[1, 2], [2, 3]]
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_pytorch_and_py(out_fg, [x_np])
out_pt = x_pt[[1, 2], 1:]
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_pytorch_and_py(out_fg, [x_np])
out_pt = x_pt[[1, 2], :, [3, 4]]
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_pytorch_and_py(out_fg, [x_np])
out_pt = x_pt[[1, 2], None]
out_fg = FunctionGraph([x_pt], [out_pt])
compare_pytorch_and_py(out_fg, [x_np])
a_pt = ps.int64("a")
a_np = 2
out_pt = x_pt[[1, a_pt], a_pt]
out_fg = FunctionGraph([x_pt, a_pt], [out_pt])
compare_pytorch_and_py(out_fg, [x_np, a_np])
# boolean indices
out_pt = x_pt[np.random.binomial(1, 0.5, size=(3, 4, 5)).astype(bool)]
out_fg = FunctionGraph([x_pt], [out_pt])
compare_pytorch_and_py(out_fg, [x_np])
a_pt = pt.tensor3("a", dtype="bool")
a_np = np.random.binomial(1, 0.5, size=(3, 4, 5)).astype(bool)
out_pt = x_pt[a_pt]
out_fg = FunctionGraph([x_pt, a_pt], [out_pt])
compare_pytorch_and_py(out_fg, [x_np, a_np])
with pytest.raises(
        NotImplementedError, match="Negative step sizes are not supported in PyTorch"
):
out_pt = x_pt[[1, 2], ::-1]
out_fg = FunctionGraph([x_pt], [out_pt])
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
compare_pytorch_and_py(out_fg, [x_np])
@pytest.mark.parametrize("subtensor_op", [set_subtensor, inc_subtensor])
def test_pytorch_IncSubtensor(subtensor_op):
x_pt = pt.tensor3("x")
x_test = np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX)
st_pt = pt.as_tensor_variable(np.array(-10.0, dtype=config.floatX))
out_pt = subtensor_op(x_pt[1, 2, 3], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_pytorch_and_py(out_fg, [x_test])
# Test different type update
st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype("float32"))
out_pt = subtensor_op(x_pt[:2, 0, 0], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_pytorch_and_py(out_fg, [x_test])
out_pt = subtensor_op(x_pt[0, 1:3, 0], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_pytorch_and_py(out_fg, [x_test])
def inc_subtensor_ignore_duplicates(x, y):
return inc_subtensor(x, y, ignore_duplicates=True)
@pytest.mark.parametrize(
"advsubtensor_op", [set_subtensor, inc_subtensor, inc_subtensor_ignore_duplicates]
)
def test_pytorch_AdvancedIncSubtensor(advsubtensor_op):
rng = np.random.default_rng(42)
x_pt = pt.tensor3("x")
x_test = (np.arange(3 * 4 * 5) + 1).reshape((3, 4, 5)).astype(config.floatX)
st_pt = pt.as_tensor_variable(
rng.uniform(-1, 1, size=(2, 4, 5)).astype(config.floatX)
)
out_pt = advsubtensor_op(x_pt[np.r_[0, 2]], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_pytorch_and_py(out_fg, [x_test])
# Repeated indices
out_pt = advsubtensor_op(x_pt[np.r_[0, 0]], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_pytorch_and_py(out_fg, [x_test])
# Mixing advanced and basic indexing
if advsubtensor_op is inc_subtensor:
# PyTorch does not support `np.add.at` equivalent with slices
expectation = pytest.raises(NotImplementedError)
else:
expectation = contextlib.nullcontext()
st_pt = pt.as_tensor_variable(x_test[[0, 2], 0, :3])
out_pt = advsubtensor_op(x_pt[[0, 0], 0, :3], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
with expectation:
compare_pytorch_and_py(out_fg, [x_test])
# Test different dtype update
st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype("float32"))
out_pt = advsubtensor_op(x_pt[[0, 2], 0, 0], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_pytorch_and_py(out_fg, [x_test])
# Boolean indices
out_pt = advsubtensor_op(x_pt[x_pt > 5], 1.0)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_pytorch_and_py(out_fg, [x_test])