Commit 7300a687 authored by Ian Schweer, committed by Ricardo Vieira

Track generated torch files for torch compiler

Parent: 4b41e092
@@ -54,14 +54,16 @@ def pytorch_funcify_FunctionGraph(
     fgraph,
     node=None,
     fgraph_name="pytorch_funcified_fgraph",
+    conversion_func=pytorch_funcify,
     **kwargs,
 ):
+    built_kwargs = {"conversion_func": conversion_func, **kwargs}
     return fgraph_to_python(
         fgraph,
-        pytorch_funcify,
+        conversion_func,
         type_conversion_fn=pytorch_typify,
         fgraph_name=fgraph_name,
-        **kwargs,
+        **built_kwargs,
     )
@@ -173,11 +175,8 @@ def pytorch_funcify_OpFromGraph(op, node, **kwargs):
     # Apply inner rewrites
     PYTORCH.optimizer(op.fgraph)
     fgraph_fn = pytorch_funcify(op.fgraph, **kwargs, squeeze_output=True)
-    # Disable one step inlining to prevent torch from trying to import local functions
-    # defined in `pytorch_funcify`
-    return torch.compiler.disable(fgraph_fn, recursive=False)
+    return fgraph_fn


 @pytorch_funcify.register(TensorFromScalar)
...
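The `conversion_func` parameter added above makes the per-node conversion step pluggable: a caller can hand `pytorch_funcify_FunctionGraph` a wrapper that sees every functor produced while the graph is converted, which is how the linker further down registers its generated code. A minimal sketch of that hook (the `record_conversions` wrapper and the tiny graph are illustrative, not part of the commit):

import pytensor.tensor as pt
from pytensor.graph import FunctionGraph
from pytensor.link.pytorch.dispatch import pytorch_funcify

x = pt.vector("x")
fgraph = FunctionGraph([x], [2 * x])

seen = []

def record_conversions(*args, **kwargs):
    # Delegate to the normal dispatcher, but remember every
    # functor generated for a node in the graph.
    functor = pytorch_funcify(*args, **kwargs)
    seen.append(functor)
    return functor

torch_fn = pytorch_funcify(fgraph, conversion_func=record_conversions)
assert seen  # one entry per converted Apply node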
 import torch
-import torch.compiler

 from pytensor.graph import FunctionGraph
 from pytensor.link.pytorch.dispatch import pytorch_funcify
@@ -11,12 +10,13 @@ def funcify_Blockwise(op: Blockwise, node, *args, **kwargs):
     batched_dims = op.batch_ndim(node)
     core_node = op._create_dummy_core_node(node.inputs)
     core_fgraph = FunctionGraph(inputs=core_node.inputs, outputs=core_node.outputs)
-    inner_func = pytorch_funcify(core_fgraph, squeeze_output=len(node.outputs) == 1)
+    inner_func = pytorch_funcify(
+        core_fgraph, squeeze_output=len(node.outputs) == 1, **kwargs
+    )

     for _ in range(batched_dims):
         inner_func = torch.vmap(inner_func)

-    @torch.compiler.disable(recursive=False)
     def batcher(*inputs):
         op._check_runtime_broadcast(node, inputs)
         # broadcast on batched_dims
...
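For context on the hunk above: the loop lifts the core function over each batch dimension by nesting `torch.vmap`, one wrap per leading batch axis. A standalone illustration of that stacking (toy shapes, independent of PyTensor):

import torch

def core(x):
    # Core function: reduces a single vector to a scalar.
    return x.sum()

batched = core
for _ in range(2):  # two leading batch dimensions
    batched = torch.vmap(batched)

x = torch.ones(4, 3, 5)  # (batch, batch, core)
assert batched(x).shape == (4, 3)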
+import copy
 from typing import Any

 from pytensor.graph.basic import Variable
 from pytensor.link.basic import JITLinker
+from pytensor.link.utils import unique_name_generator


 class PytorchLinker(JITLinker):
     """A `Linker` that compiles NumPy-based operations using torch.compile."""

+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.gen_functors = []
+
     def input_filter(self, inp: Any) -> Any:
         from pytensor.link.pytorch.dispatch import pytorch_typify
@@ -18,14 +24,68 @@ class PytorchLinker(JITLinker):
     def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
         from pytensor.link.pytorch.dispatch import pytorch_funcify

+        # We want to have globally unique names across the
+        # entire pytensor graph, not just the subgraph
+        generator = unique_name_generator(["torch_linker"])
+
+        # Ensure that torch is aware of the generated
+        # code so we can compile without graph breaks
+        def conversion_func_register(*args, **kwargs):
+            functor = pytorch_funcify(*args, **kwargs)
+            name = kwargs["unique_name"](functor)
+            self.gen_functors.append((f"_{name}", functor))
+            return functor
+
+        built_kwargs = {
+            "unique_name": generator,
+            "conversion_func": conversion_func_register,
+            **kwargs,
+        }
         return pytorch_funcify(
-            fgraph, input_storage=input_storage, storage_map=storage_map, **kwargs
+            fgraph, input_storage=input_storage, storage_map=storage_map, **built_kwargs
         )
     def jit_compile(self, fn):
         import torch

-        return torch.compile(fn)
+        class wrapper:
+            """
+            Pytorch would fail to compile our method when trying
+            to resolve some of the methods returned from dispatch
+            calls. We want to be careful not to leak those methods,
+            so this class just holds them and provisions the expected
+            location accordingly.
+
+            https://discuss.pytorch.org/t/closures-are-being-gcd-and-causing-failures-to-compile/213319
+            """
+
+            def __init__(self, fn, gen_functors):
+                self.fn = torch.compile(fn)
+                self.gen_functors = copy.copy(gen_functors)
+
+            def __call__(self, *args, **kwargs):
+                import pytensor.link.utils
+
+                # Pin the generated functors onto the module so the
+                # compiled code can resolve them by name
+                for n, fn in self.gen_functors:
+                    setattr(pytensor.link.utils, n[1:], fn)
+
+                res = self.fn(*args, **kwargs)
+
+                # Unpin them again so no references leak
+                for n, _ in self.gen_functors:
+                    if getattr(pytensor.link.utils, n[1:], False):
+                        delattr(pytensor.link.utils, n[1:])
+
+                return res
+
+            def __del__(self):
+                del self.gen_functors
+
+        res = wrapper(fn, self.gen_functors)
+        self.gen_functors = []
+        return res
     def create_thunk_inputs(self, storage_map):
         thunk_inputs = []
...
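The `wrapper` class is the heart of the change: Dynamo was garbage-collecting, or failing to resolve, closures returned from dispatch, so the linker now pins each generated functor onto the `pytensor.link.utils` module for the duration of the call and removes it afterwards. A self-contained sketch of that pin/compile/unpin pattern (`gen_inc` is a made-up name standing in for a generated functor):

import torch
import pytensor.link.utils

def _gen_inc(x):
    # Stand-in for a functor produced by pytorch_funcify dispatch.
    return x + 1

# Pin the functor where the compiled code expects to resolve it ...
setattr(pytensor.link.utils, "gen_inc", _gen_inc)

def fn(x):
    return pytensor.link.utils.gen_inc(x)

compiled = torch.compile(fn)
print(compiled(torch.ones(3)))  # tensor([2., 2., 2.])

# ... and unpin it afterwards so no reference leaks.
delattr(pytensor.link.utils, "gen_inc")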
@@ -675,6 +675,7 @@ def fgraph_to_python(
     local_env: dict[Any, Any] | None = None,
     get_name_for_object: Callable[[Any], str] = get_name_for_object,
     squeeze_output: bool = False,
+    unique_name: Callable | None = None,
     **kwargs,
 ) -> Callable:
     """Convert a `FunctionGraph` into a regular Python function.
@@ -706,6 +707,8 @@ def fgraph_to_python(
     get_name_for_object
         A function used to provide names for the objects referenced within the
         generated function.
+    unique_name
+        A function used to generate unique names for generated code.
     squeeze_output
         If the `FunctionGraph` has only one output and this option is
         ``True``, return the single output instead of a tuple with the output.
@@ -719,8 +722,12 @@ def fgraph_to_python(
     if storage_map is None:
         storage_map = {}

-    unique_name = unique_name_generator([fgraph_name])
+    if not unique_name:
+        unique_name = unique_name_generator([fgraph_name])
+
+    # Make sure we plumb this through to nested conversions
+    kwargs["unique_name"] = unique_name

     if global_env is None:
         global_env = {}
...
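`unique_name_generator` already lives in `pytensor.link.utils`; it returns a callable that hands out collision-free names for the objects it is shown. Letting the caller supply one generator, and forwarding it through `kwargs`, is what keeps names unique across nested `fgraph_to_python` calls rather than merely within each subgraph. A small sketch of the intended call shape (mirroring the linker's usage above; the exact names produced are an implementation detail):

from pytensor.link.utils import unique_name_generator

unique_name = unique_name_generator(["torch_linker"])

def functor():
    # Stand-in for a dispatch-generated callable.
    pass

# The same generator is reused for every conversion, so a second
# functor with a clashing __name__ would receive a suffixed name.
print(unique_name(functor))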
@@ -22,6 +22,7 @@ from pytensor.tensor.type import matrices, matrix, scalar, vector

 torch = pytest.importorskip("torch")
+torch_dispatch = pytest.importorskip("pytensor.link.pytorch.dispatch.basic")


 optimizer = RewriteDatabaseQuery(
@@ -335,7 +336,7 @@ def test_pytorch_OpFromGraph():
     ofg_2 = OpFromGraph([x, y], [x * y, x - y])
     o1, o2 = ofg_2(y, z)

-    out = ofg_1(x, o1) + o2
+    out = ofg_1(x, o1) / o2

     xv = np.ones((2, 2), dtype=config.floatX)
     yv = np.ones((2, 2), dtype=config.floatX) * 3
@@ -343,3 +344,33 @@ def test_pytorch_OpFromGraph():

     f = FunctionGraph([x, y, z], [out])
     compare_pytorch_and_py(f, [xv, yv, zv])
+
+
+def test_pytorch_link_references():
+    import pytensor.link.utils as m
+
+    class BasicOp(Op):
+        def __init__(self):
+            super().__init__()
+
+        def make_node(self, *x):
+            return Apply(self, list(x), [xi.type() for xi in x])
+
+        def perform(self, *_):
+            raise RuntimeError("In perform")
+
+    @torch_dispatch.pytorch_funcify.register(BasicOp)
+    def fn(op, node, **kwargs):
+        def inner_fn(x):
+            assert "inner_fn" in dir(m), "not available during dispatch"
+            return x
+
+        return inner_fn
+
+    x = vector("x")
+    op = BasicOp()
+    out = op(x)
+
+    f = function([x], out, mode="PYTORCH")
+    f(torch.ones(3))
+    assert "inner_fn" not in dir(m), "function call reference leaked"
...
@@ -29,7 +29,6 @@ class TestOp(Op):

 @basic.pytorch_funcify.register(TestOp)
 def evaluate_test_op(op, **_):
-    @torch.compiler.disable(recursive=False)
     def func(a, b):
         op.call_shapes.extend(map(torch.Tensor.size, [a, b]))
         return a @ b
...