Unverified 提交 426931b0 authored 作者: Diego Sandoval's avatar Diego Sandoval 提交者: GitHub

Implements shape Ops and MakeVector in PyTorch (#926)

* Implements shape and MakeVector Ops in PyTorch - Shape - Shape_i - Reshape - SpecifyShape - Unbroadcast - MakeVector
上级 6b8df2cc
...@@ -6,4 +6,5 @@ import pytensor.link.pytorch.dispatch.scalar ...@@ -6,4 +6,5 @@ import pytensor.link.pytorch.dispatch.scalar
import pytensor.link.pytorch.dispatch.elemwise import pytensor.link.pytorch.dispatch.elemwise
import pytensor.link.pytorch.dispatch.extra_ops import pytensor.link.pytorch.dispatch.extra_ops
import pytensor.link.pytorch.dispatch.sort import pytensor.link.pytorch.dispatch.sort
import pytensor.link.pytorch.dispatch.shape
# isort: on # isort: on
from functools import singledispatch from functools import singledispatch
from types import NoneType
import torch import torch
...@@ -6,7 +7,7 @@ from pytensor.compile.ops import DeepCopyOp ...@@ -6,7 +7,7 @@ from pytensor.compile.ops import DeepCopyOp
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.link.utils import fgraph_to_python from pytensor.link.utils import fgraph_to_python
from pytensor.raise_op import CheckAndRaise from pytensor.raise_op import CheckAndRaise
from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Eye, Join from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Eye, Join, MakeVector
@singledispatch @singledispatch
...@@ -15,6 +16,11 @@ def pytorch_typify(data, dtype=None, **kwargs): ...@@ -15,6 +16,11 @@ def pytorch_typify(data, dtype=None, **kwargs):
return torch.as_tensor(data, dtype=dtype) return torch.as_tensor(data, dtype=dtype)
@pytorch_typify.register(NoneType)
def pytorch_typify_None(data, **kwargs):
return None
@singledispatch @singledispatch
def pytorch_funcify(op, node=None, storage_map=None, **kwargs): def pytorch_funcify(op, node=None, storage_map=None, **kwargs):
"""Create a PyTorch compatible function from an PyTensor `Op`.""" """Create a PyTorch compatible function from an PyTensor `Op`."""
...@@ -116,3 +122,13 @@ def pytorch_funcify_eye(op, **kwargs): ...@@ -116,3 +122,13 @@ def pytorch_funcify_eye(op, **kwargs):
return zeros return zeros
return eye return eye
@pytorch_funcify.register(MakeVector)
def pytorch_funcify_MakeVector(op, **kwargs):
torch_dtype = getattr(torch, op.dtype)
def makevector(*x):
return torch.tensor(x, dtype=torch_dtype)
return makevector
import torch
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast
@pytorch_funcify.register(Reshape)
def pytorch_funcify_Reshape(op, node, **kwargs):
def reshape(x, shape):
return torch.reshape(x, tuple(shape))
return reshape
@pytorch_funcify.register(Shape)
def pytorch_funcify_Shape(op, **kwargs):
def shape(x):
return x.shape
return shape
@pytorch_funcify.register(Shape_i)
def pytorch_funcify_Shape_i(op, **kwargs):
i = op.i
def shape_i(x):
return torch.tensor(x.shape[i])
return shape_i
@pytorch_funcify.register(SpecifyShape)
def pytorch_funcify_SpecifyShape(op, node, **kwargs):
def specifyshape(x, *shape):
assert x.ndim == len(shape)
for actual, expected in zip(x.shape, shape):
if expected is None:
continue
if actual != expected:
raise ValueError(f"Invalid shape: Expected {shape} but got {x.shape}")
return x
return specifyshape
@pytorch_funcify.register(Unbroadcast)
def pytorch_funcify_Unbroadcast(op, **kwargs):
def unbroadcast(x):
return x
return unbroadcast
...@@ -294,3 +294,10 @@ def test_eye(dtype): ...@@ -294,3 +294,10 @@ def test_eye(dtype):
for _M in range(1, 6): for _M in range(1, 6):
for _k in list(range(_M + 2)) + [-x for x in range(1, _N + 2)]: for _k in list(range(_M + 2)) + [-x for x in range(1, _N + 2)]:
np.testing.assert_array_equal(fn(_N, _M, _k), np.eye(_N, _M, _k)) np.testing.assert_array_equal(fn(_N, _M, _k), np.eye(_N, _M, _k))
def test_pytorch_MakeVector():
x = ptb.make_vector(1, 2, 3)
x_fg = FunctionGraph([], [x])
compare_pytorch_and_py(x_fg, [])
...@@ -43,18 +43,7 @@ def test_pytorch_CumOp(axis, dtype): ...@@ -43,18 +43,7 @@ def test_pytorch_CumOp(axis, dtype):
compare_pytorch_and_py(fgraph, [test_value]) compare_pytorch_and_py(fgraph, [test_value])
@pytest.mark.parametrize( @pytest.mark.parametrize("axis, repeats", [(0, (1, 2, 3)), (1, (3, 3)), (None, 3)])
"axis, repeats",
[
(0, (1, 2, 3)),
(1, (3, 3)),
pytest.param(
None,
3,
marks=pytest.mark.xfail(reason="Reshape not implemented"),
),
],
)
def test_pytorch_Repeat(axis, repeats): def test_pytorch_Repeat(axis, repeats):
a = pt.matrix("a", dtype="float64") a = pt.matrix("a", dtype="float64")
......
import numpy as np
import pytensor.tensor as pt
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.shape import Shape, Shape_i, Unbroadcast, reshape
from pytensor.tensor.type import iscalar, vector
from tests.link.pytorch.test_basic import compare_pytorch_and_py
def test_pytorch_shape_ops():
x_np = np.zeros((20, 3))
x = Shape()(pt.as_tensor_variable(x_np))
x_fg = FunctionGraph([], [x])
compare_pytorch_and_py(x_fg, [], must_be_device_array=False)
x = Shape_i(1)(pt.as_tensor_variable(x_np))
x_fg = FunctionGraph([], [x])
compare_pytorch_and_py(x_fg, [], must_be_device_array=False)
def test_pytorch_specify_shape():
in_pt = pt.matrix("in")
x = pt.specify_shape(in_pt, (4, None))
x_fg = FunctionGraph([in_pt], [x])
compare_pytorch_and_py(x_fg, [np.ones((4, 5)).astype(config.floatX)])
# When used to assert two arrays have similar shapes
in_pt = pt.matrix("in")
shape_pt = pt.matrix("shape")
x = pt.specify_shape(in_pt, shape_pt.shape)
x_fg = FunctionGraph([in_pt, shape_pt], [x])
compare_pytorch_and_py(
x_fg,
[np.ones((4, 5)).astype(config.floatX), np.ones((4, 5)).astype(config.floatX)],
)
def test_pytorch_Reshape_constant():
a = vector("a")
x = reshape(a, (2, 2))
x_fg = FunctionGraph([a], [x])
compare_pytorch_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])
def test_pytorch_Reshape_dynamic():
a = vector("a")
shape_pt = iscalar("b")
x = reshape(a, (shape_pt, shape_pt))
x_fg = FunctionGraph([a, shape_pt], [x])
compare_pytorch_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX), 2])
def test_pytorch_unbroadcast():
x_np = np.zeros((20, 1, 1))
x = Unbroadcast(0, 2)(pt.as_tensor_variable(x_np))
x_fg = FunctionGraph([], [x])
compare_pytorch_and_py(x_fg, [])
...@@ -8,16 +8,7 @@ from tests.link.pytorch.test_basic import compare_pytorch_and_py ...@@ -8,16 +8,7 @@ from tests.link.pytorch.test_basic import compare_pytorch_and_py
@pytest.mark.parametrize("func", (sort, argsort)) @pytest.mark.parametrize("func", (sort, argsort))
@pytest.mark.parametrize( @pytest.mark.parametrize("axis", [0, 1, None])
"axis",
[
pytest.param(0),
pytest.param(1),
pytest.param(
None, marks=pytest.mark.xfail(reason="Reshape Op not implemented")
),
],
)
def test_sort(func, axis): def test_sort(func, axis):
x = matrix("x", shape=(2, 2), dtype="float64") x = matrix("x", shape=(2, 2), dtype="float64")
out = func(x, axis=axis) out = func(x, axis=axis)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论