Unverified 提交 e57e25bf authored 作者: Harshvir Sandhu's avatar Harshvir Sandhu 提交者: GitHub

PyTorch support for Join and CAReduce Ops (#869)

上级 df769f6c
......@@ -6,7 +6,7 @@ from pytensor.compile.ops import DeepCopyOp
from pytensor.graph.fg import FunctionGraph
from pytensor.link.utils import fgraph_to_python
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor.basic import Alloc, AllocEmpty, ARange
from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Join
@singledispatch
......@@ -89,3 +89,14 @@ def pytorch_funcify_arange(op, **kwargs):
return torch.arange(start, stop, step, dtype=dtype)
return arange
@pytorch_funcify.register(Join)
def pytorch_funcify_Join(op, **kwargs):
    """Return a PyTorch implementation of the ``Join`` op.

    The returned callable takes the join axis followed by the inputs to
    concatenate, and returns their concatenation along that axis.
    """

    def join(axis, *tensors):
        # Inputs could also be tuples, which have no ``ndim``, so coerce each
        # one to a tensor. ``torch.as_tensor`` hands existing tensors back
        # unchanged, whereas ``torch.tensor`` would copy them, detach them
        # from autograd and emit a copy-construct ``UserWarning``.
        tensors = [torch.as_tensor(tensor) for tensor in tensors]
        return torch.cat(tensors, dim=axis)

    return join
......@@ -2,6 +2,7 @@ import torch
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import All, Any, Max, Min, Prod, Sum
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
......@@ -37,6 +38,69 @@ def pytorch_funcify_DimShuffle(op, **kwargs):
return dimshuffle
@pytorch_funcify.register(Sum)
def pytorch_funcify_sum(op, **kwargs):
    """Build the PyTorch callable for a ``Sum`` op.

    The returned function sums its input over ``op.axis`` using
    ``torch.sum``.
    """

    def torch_sum(x):
        reduce_over = op.axis
        return torch.sum(x, dim=reduce_over)

    return torch_sum
@pytorch_funcify.register(All)
def pytorch_funcify_all(op, **kwargs):
    """Build the PyTorch callable for an ``All`` op.

    The returned function applies ``torch.all`` to its input over
    ``op.axis``.
    """

    def torch_all(x):
        reduce_over = op.axis
        return torch.all(x, dim=reduce_over)

    return torch_all
@pytorch_funcify.register(Prod)
def pytorch_funcify_prod(op, **kwargs):
    """Return a PyTorch implementation of the ``Prod`` op.

    When ``op.axis`` is a tuple, the returned callable multiplies the input
    along each listed axis, one axis at a time; otherwise it returns the
    product of all elements.
    """

    def torch_prod(x):
        if isinstance(op.axis, tuple):
            # Resolve in-range negative axes against the input rank *before*
            # sorting. Every reduction drops a dimension, so a raw negative
            # axis would be re-interpreted against the shrunken tensor and
            # could select the wrong dimension (e.g. ``(-1, -2)``).
            # Out-of-range values are left untouched so torch still raises.
            ndim = x.ndim
            axes = {d + ndim if -ndim <= d < 0 else d for d in op.axis}
            # Reduce from the highest axis down so the remaining (lower)
            # axis indices stay valid after each reduction.
            for d in sorted(axes, reverse=True):
                x = torch.prod(x, dim=d)
            return x
        else:
            return torch.prod(x.flatten(), dim=0)

    return torch_prod
@pytorch_funcify.register(Any)
def pytorch_funcify_any(op, **kwargs):
    """Build the PyTorch callable for an ``Any`` op.

    The returned function applies ``torch.any`` to its input over
    ``op.axis``.
    """

    def torch_any(x):
        reduce_over = op.axis
        return torch.any(x, dim=reduce_over)

    return torch_any
@pytorch_funcify.register(Max)
def pytorch_funcify_max(op, **kwargs):
    """Return a PyTorch implementation of the ``Max`` op.

    When ``op.axis`` is a tuple, the returned callable takes the maximum of
    the input along each listed axis, one axis at a time; otherwise it
    returns the maximum over all elements.
    """

    def torch_max(x):
        if isinstance(op.axis, tuple):
            # Resolve in-range negative axes against the input rank *before*
            # sorting. Every reduction drops a dimension, so a raw negative
            # axis would be re-interpreted against the shrunken tensor and
            # could select the wrong dimension (e.g. ``(-1, -2)``).
            # Out-of-range values are left untouched so torch still raises.
            ndim = x.ndim
            axes = {d + ndim if -ndim <= d < 0 else d for d in op.axis}
            # Reduce from the highest axis down so the remaining (lower)
            # axis indices stay valid after each reduction.
            for d in sorted(axes, reverse=True):
                # With ``dim`` given, ``torch.max`` returns (values, indices).
                x = torch.max(x, dim=d).values
            return x
        else:
            return torch.max(x.flatten(), dim=0).values

    return torch_max
@pytorch_funcify.register(Min)
def pytorch_funcify_min(op, **kwargs):
    """Return a PyTorch implementation of the ``Min`` op.

    When ``op.axis`` is a tuple, the returned callable takes the minimum of
    the input along each listed axis, one axis at a time; otherwise it
    returns the minimum over all elements.
    """

    def torch_min(x):
        if isinstance(op.axis, tuple):
            # Resolve in-range negative axes against the input rank *before*
            # sorting. Every reduction drops a dimension, so a raw negative
            # axis would be re-interpreted against the shrunken tensor and
            # could select the wrong dimension (e.g. ``(-1, -2)``).
            # Out-of-range values are left untouched so torch still raises.
            ndim = x.ndim
            axes = {d + ndim if -ndim <= d < 0 else d for d in op.axis}
            # Reduce from the highest axis down so the remaining (lower)
            # axis indices stay valid after each reduction.
            for d in sorted(axes, reverse=True):
                # With ``dim`` given, ``torch.min`` returns (values, indices).
                x = torch.min(x, dim=d).values
            return x
        else:
            return torch.min(x.flatten(), dim=0).values

    return torch_min
@pytorch_funcify.register(Softmax)
def pytorch_funcify_Softmax(op, **kwargs):
axis = op.axis
......
......@@ -4,6 +4,7 @@ from functools import partial
import numpy as np
import pytest
import pytensor.tensor.basic as ptb
from pytensor.compile.function import function
from pytensor.compile.mode import get_mode
from pytensor.compile.sharedvalue import SharedVariable, shared
......@@ -13,7 +14,7 @@ from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor import alloc, arange, as_tensor, empty
from pytensor.tensor.type import scalar, vector
from pytensor.tensor.type import matrix, scalar, vector
torch = pytest.importorskip("torch")
......@@ -235,3 +236,42 @@ def test_arange():
FunctionGraph([start, stop, step], [out]),
[np.array(1), np.array(10), np.array(2)],
)
def test_pytorch_Join():
    """Compare the PyTorch and Python results of ``join`` on two matrices.

    Axis 0 and axis 1 are each checked twice: once with equally shaped
    inputs and once with inputs that differ along the join axis.
    """
    a = matrix("a")
    b = matrix("b")

    # Input pairs to evaluate, keyed by the axis being joined.
    inputs_by_axis = {
        0: [
            [
                np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
                np.c_[[4.0, 5.0, 6.0]].astype(config.floatX),
            ],
            [
                np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
                np.c_[[4.0, 5.0]].astype(config.floatX),
            ],
        ],
        1: [
            [
                np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
                np.c_[[4.0, 5.0, 6.0]].astype(config.floatX),
            ],
            [
                np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX),
                np.c_[[5.0, 6.0]].astype(config.floatX),
            ],
        ],
    }

    for join_axis, input_pairs in inputs_by_axis.items():
        joined = ptb.join(join_axis, a, b)
        joined_fg = FunctionGraph([a, b], [joined])
        for pair in input_pairs:
            compare_pytorch_and_py(joined_fg, pair)
......@@ -2,11 +2,12 @@ import numpy as np
import pytest
import pytensor.tensor as pt
import pytensor.tensor.math as ptm
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import elemwise as pt_elemwise
from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax
from pytensor.tensor.type import matrix, tensor, vector
from pytensor.tensor.type import matrix, tensor, tensor3, vector
from tests.link.pytorch.test_basic import compare_pytorch_and_py
......@@ -57,6 +58,46 @@ def test_pytorch_elemwise():
compare_pytorch_and_py(fg, [[0.9, 0.9]])
@pytest.mark.parametrize("fn", [ptm.sum, ptm.prod, ptm.max, ptm.min])
@pytest.mark.parametrize("axis", [None, 0, 1, (0, -1)])
def test_pytorch_careduce(fn, axis):
    """Compare PyTorch and Python results of a reduction over a 3-D input."""
    # 2x2x4 fixture: each length-4 row repeats one of the values 1..4.
    row_fill = [[1, 2], [3, 4]]
    test_value = np.array(
        [[[fill] * 4 for fill in pair] for pair in row_fill]
    ).astype(config.floatX)

    a_pt = tensor3("a")
    reduced = fn(a_pt, axis=axis)
    compare_pytorch_and_py(FunctionGraph([a_pt], [reduced]), [test_value])
@pytest.mark.parametrize("fn", [ptm.any, ptm.all])
@pytest.mark.parametrize("axis", [None, 0, 1, (0, 1)])
def test_pytorch_any_all(fn, axis):
    """Compare PyTorch and Python results of ``any``/``all`` on a matrix."""
    a_pt = matrix("a")
    reduced = fn(a_pt, axis=axis)
    graph = FunctionGraph([a_pt], [reduced])

    # Boolean 2x3 fixture.
    first_row = [True, False, True]
    second_row = [False, True, True]
    test_value = np.array([first_row, second_row])

    compare_pytorch_and_py(graph, [test_value])
@pytest.mark.parametrize("dtype", ["float64", "int64"])
@pytest.mark.parametrize("axis", [None, 0, 1])
def test_softmax(axis, dtype):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论