Unverified commit e57e25bf, authored by Harshvir Sandhu, committed by GitHub

Pytorch support for Join and Careduce Ops (#869)

Parent df769f6c
@@ -6,7 +6,7 @@ from pytensor.compile.ops import DeepCopyOp
 from pytensor.graph.fg import FunctionGraph
 from pytensor.link.utils import fgraph_to_python
 from pytensor.raise_op import CheckAndRaise
-from pytensor.tensor.basic import Alloc, AllocEmpty, ARange
+from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Join


 @singledispatch
@@ -89,3 +89,14 @@ def pytorch_funcify_arange(op, **kwargs):
         return torch.arange(start, stop, step, dtype=dtype)

     return arange
+
+
+@pytorch_funcify.register(Join)
+def pytorch_funcify_Join(op, **kwargs):
+    def join(axis, *tensors):
+        # tensors could also be tuples; in that case they don't have an ndim
+        tensors = [torch.tensor(tensor) for tensor in tensors]
+
+        return torch.cat(tensors, dim=axis)
+
+    return join
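For context: the `join` closure generated above coerces its inputs with `torch.tensor` and forwards to `torch.cat`. A minimal standalone sketch of that runtime behavior, with illustrative values:

    import torch

    # Mirror of the generated closure: coerce inputs (tuples/lists included)
    # to tensors, then concatenate along the requested axis.
    def join(axis, *tensors):
        tensors = [torch.tensor(t) for t in tensors]
        return torch.cat(tensors, dim=axis)

    a = [[1.0], [2.0], [3.0]]  # illustrative values
    b = [[4.0], [5.0]]
    out = join(0, a, b)  # tensor of shape (5, 1)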
@@ -2,6 +2,7 @@ import torch

 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
+from pytensor.tensor.math import All, Any, Max, Min, Prod, Sum
 from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
@@ -37,6 +38,69 @@ def pytorch_funcify_DimShuffle(op, **kwargs):
     return dimshuffle


+@pytorch_funcify.register(Sum)
+def pytorch_funcify_sum(op, **kwargs):
+    def torch_sum(x):
+        return torch.sum(x, dim=op.axis)
+
+    return torch_sum
+
+
+@pytorch_funcify.register(All)
+def pytorch_funcify_all(op, **kwargs):
+    def torch_all(x):
+        return torch.all(x, dim=op.axis)
+
+    return torch_all
+
+
+@pytorch_funcify.register(Prod)
+def pytorch_funcify_prod(op, **kwargs):
+    def torch_prod(x):
+        if isinstance(op.axis, tuple):
+            for d in sorted(op.axis, reverse=True):
+                x = torch.prod(x, dim=d)
+            return x
+        else:
+            return torch.prod(x.flatten(), dim=0)
+
+    return torch_prod
+
+
+@pytorch_funcify.register(Any)
+def pytorch_funcify_any(op, **kwargs):
+    def torch_any(x):
+        return torch.any(x, dim=op.axis)
+
+    return torch_any
+
+
+@pytorch_funcify.register(Max)
+def pytorch_funcify_max(op, **kwargs):
+    def torch_max(x):
+        if isinstance(op.axis, tuple):
+            for d in sorted(op.axis, reverse=True):
+                x = torch.max(x, dim=d).values
+            return x
+        else:
+            return torch.max(x.flatten(), dim=0).values
+
+    return torch_max
+
+
+@pytorch_funcify.register(Min)
+def pytorch_funcify_min(op, **kwargs):
+    def torch_min(x):
+        if isinstance(op.axis, tuple):
+            for d in sorted(op.axis, reverse=True):
+                x = torch.min(x, dim=d).values
+            return x
+        else:
+            return torch.min(x.flatten(), dim=0).values
+
+    return torch_min
+
+
 @pytorch_funcify.register(Softmax)
 def pytorch_funcify_Softmax(op, **kwargs):
     axis = op.axis
...
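A note on the tuple-axis branches above: unlike `torch.sum`, which accepts a tuple of dims directly, `torch.prod`, `torch.max`, and `torch.min` reduce a single dimension per call, which is why the dispatch loops over the axes. Reducing in descending order keeps the lower indices valid as each call removes a dimension. A minimal standalone sketch of the same pattern, with illustrative values:

    import torch

    x = torch.arange(24.0).reshape(2, 3, 4)

    # Reduce over axes (0, 2): process the highest axis first so that axis 0
    # still points at the original dimension after axis 2 disappears.
    result = x
    for d in sorted((0, 2), reverse=True):
        result = torch.max(result, dim=d).values  # .values drops argmax indices

    assert result.shape == (3,)  # axes 0 and 2 reduced away; axis 1 remains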
@@ -4,6 +4,7 @@ from functools import partial
 import numpy as np
 import pytest

+import pytensor.tensor.basic as ptb
 from pytensor.compile.function import function
 from pytensor.compile.mode import get_mode
 from pytensor.compile.sharedvalue import SharedVariable, shared
@@ -13,7 +14,7 @@ from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import Op
 from pytensor.raise_op import CheckAndRaise
 from pytensor.tensor import alloc, arange, as_tensor, empty
-from pytensor.tensor.type import scalar, vector
+from pytensor.tensor.type import matrix, scalar, vector


 torch = pytest.importorskip("torch")
@@ -235,3 +236,42 @@ def test_arange():
         FunctionGraph([start, stop, step], [out]),
         [np.array(1), np.array(10), np.array(2)],
     )
+
+
+def test_pytorch_Join():
+    a = matrix("a")
+    b = matrix("b")
+
+    x = ptb.join(0, a, b)
+    x_fg = FunctionGraph([a, b], [x])
+    compare_pytorch_and_py(
+        x_fg,
+        [
+            np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
+            np.c_[[4.0, 5.0, 6.0]].astype(config.floatX),
+        ],
+    )
+    compare_pytorch_and_py(
+        x_fg,
+        [
+            np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
+            np.c_[[4.0, 5.0]].astype(config.floatX),
+        ],
+    )
+
+    x = ptb.join(1, a, b)
+    x_fg = FunctionGraph([a, b], [x])
+    compare_pytorch_and_py(
+        x_fg,
+        [
+            np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
+            np.c_[[4.0, 5.0, 6.0]].astype(config.floatX),
+        ],
+    )
+    compare_pytorch_and_py(
+        x_fg,
+        [
+            np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX),
+            np.c_[[5.0, 6.0]].astype(config.floatX),
+        ],
+    )
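The test drives the graph through `compare_pytorch_and_py`; for context, a hedged sketch of the equivalent direct usage, assuming the "PYTORCH" compile mode registered by the PyTorch linker is available in this build:

    import numpy as np

    import pytensor.tensor.basic as ptb
    from pytensor.compile.function import function
    from pytensor.tensor.type import matrix

    a = matrix("a")
    b = matrix("b")

    # Compile the same Join graph against the PyTorch backend.
    f = function([a, b], ptb.join(0, a, b), mode="PYTORCH")
    f(np.ones((2, 2), dtype=a.dtype), np.zeros((1, 2), dtype=b.dtype))  # shape (3, 2)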
@@ -2,11 +2,12 @@ import numpy as np
 import pytest

 import pytensor.tensor as pt
+import pytensor.tensor.math as ptm
 from pytensor.configdefaults import config
 from pytensor.graph.fg import FunctionGraph
 from pytensor.tensor import elemwise as pt_elemwise
 from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax
-from pytensor.tensor.type import matrix, tensor, vector
+from pytensor.tensor.type import matrix, tensor, tensor3, vector
 from tests.link.pytorch.test_basic import compare_pytorch_and_py
@@ -57,6 +58,46 @@ def test_pytorch_elemwise():
     compare_pytorch_and_py(fg, [[0.9, 0.9]])


+@pytest.mark.parametrize("fn", [ptm.sum, ptm.prod, ptm.max, ptm.min])
+@pytest.mark.parametrize("axis", [None, 0, 1, (0, -1)])
+def test_pytorch_careduce(fn, axis):
+    a_pt = tensor3("a")
+    test_value = np.array(
+        [
+            [
+                [1, 1, 1, 1],
+                [2, 2, 2, 2],
+            ],
+            [
+                [3, 3, 3, 3],
+                [4, 4, 4, 4],
+            ],
+        ]
+    ).astype(config.floatX)
+
+    x = fn(a_pt, axis=axis)
+    x_fg = FunctionGraph([a_pt], [x])
+
+    compare_pytorch_and_py(x_fg, [test_value])
+
+
+@pytest.mark.parametrize("fn", [ptm.any, ptm.all])
+@pytest.mark.parametrize("axis", [None, 0, 1, (0, 1)])
+def test_pytorch_any_all(fn, axis):
+    a_pt = matrix("a")
+    test_value = np.array([[True, False, True], [False, True, True]])
+
+    x = fn(a_pt, axis=axis)
+    x_fg = FunctionGraph([a_pt], [x])
+
+    compare_pytorch_and_py(x_fg, [test_value])
+
+
 @pytest.mark.parametrize("dtype", ["float64", "int64"])
 @pytest.mark.parametrize("axis", [None, 0, 1])
 def test_softmax(axis, dtype):
...