Unverified commit ba4fcbe3, authored by Ian Schweer, committed by GitHub

Implement OpFromGraph in PyTorch backend (#956)

Parent commit: 3e55a209
...@@ -30,6 +30,7 @@ from pytensor.compile.mode import ( ...@@ -30,6 +30,7 @@ from pytensor.compile.mode import (
OPT_O3, OPT_O3,
OPT_STABILIZE, OPT_STABILIZE,
OPT_UNSAFE, OPT_UNSAFE,
PYTORCH,
AddDestroyHandler, AddDestroyHandler,
AddFeatureOptimizer, AddFeatureOptimizer,
Mode, Mode,
......
...@@ -3,7 +3,10 @@ from types import NoneType ...@@ -3,7 +3,10 @@ from types import NoneType
import numpy as np import numpy as np
import torch import torch
import torch.compiler
from pytensor.compile import PYTORCH
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.ops import DeepCopyOp from pytensor.compile.ops import DeepCopyOp
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.link.utils import fgraph_to_python from pytensor.link.utils import fgraph_to_python
...@@ -150,6 +153,19 @@ def pytorch_funcify_MakeVector(op, **kwargs): ...@@ -150,6 +153,19 @@ def pytorch_funcify_MakeVector(op, **kwargs):
return makevector return makevector
@pytorch_funcify.register(OpFromGraph)
def pytorch_funcify_OpFromGraph(op, node, **kwargs):
    """Convert an ``OpFromGraph`` into a callable for the PyTorch backend.

    The inner ``FunctionGraph`` is first rewritten with the PyTorch-mode
    optimizer, then converted to a Python callable via ``pytorch_funcify``.
    """
    # storage_map applies to the outer graph only; drop it before funcifying
    # the inner graph.
    kwargs.pop("storage_map", None)

    # Apply the PyTorch-mode inner rewrites in place.
    PYTORCH.optimizer(op.fgraph)

    inner_fn = pytorch_funcify(op.fgraph, squeeze_output=True, **kwargs)

    # Disable one-step inlining so torch does not try to import local
    # functions defined inside `pytorch_funcify`.
    return torch.compiler.disable(inner_fn, recursive=False)
@pytorch_funcify.register(TensorFromScalar) @pytorch_funcify.register(TensorFromScalar)
def pytorch_funcify_TensorFromScalar(op, **kwargs): def pytorch_funcify_TensorFromScalar(op, **kwargs):
def tensorfromscalar(x): def tensorfromscalar(x):
......
...@@ -5,6 +5,7 @@ import numpy as np ...@@ -5,6 +5,7 @@ import numpy as np
import pytest import pytest
import pytensor.tensor.basic as ptb import pytensor.tensor.basic as ptb
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function import function from pytensor.compile.function import function
from pytensor.compile.mode import get_mode from pytensor.compile.mode import get_mode
from pytensor.compile.sharedvalue import SharedVariable, shared from pytensor.compile.sharedvalue import SharedVariable, shared
...@@ -14,7 +15,7 @@ from pytensor.graph.fg import FunctionGraph ...@@ -14,7 +15,7 @@ from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.raise_op import CheckAndRaise from pytensor.raise_op import CheckAndRaise
from pytensor.tensor import alloc, arange, as_tensor, empty, eye from pytensor.tensor import alloc, arange, as_tensor, empty, eye
from pytensor.tensor.type import matrix, scalar, vector from pytensor.tensor.type import matrices, matrix, scalar, vector
torch = pytest.importorskip("torch") torch = pytest.importorskip("torch")
...@@ -301,3 +302,19 @@ def test_pytorch_MakeVector(): ...@@ -301,3 +302,19 @@ def test_pytorch_MakeVector():
x_fg = FunctionGraph([], [x]) x_fg = FunctionGraph([], [x])
compare_pytorch_and_py(x_fg, []) compare_pytorch_and_py(x_fg, [])
def test_pytorch_OpFromGraph():
    """Compile a graph containing nested OpFromGraph nodes and compare
    the PyTorch backend result against the Python backend."""
    x, y, z = matrices("xyz")

    single_out_ofg = OpFromGraph([x, y], [x + y])
    multi_out_ofg = OpFromGraph([x, y], [x * y, x - y])

    prod, diff = multi_out_ofg(y, z)
    out = single_out_ofg(x, prod) + diff

    # Concrete 2x2 inputs filled with 1, 3 and 5 respectively.
    xv, yv, zv = (
        np.full((2, 2), value, dtype=config.floatX) for value in (1, 3, 5)
    )

    fgraph = FunctionGraph([x, y, z], [out])
    compare_pytorch_and_py(fgraph, [xv, yv, zv])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论