提交 7300a687 作者: Ian Schweer 提交者: Ricardo Vieira

Track generated torch files for torch compiler

上级 4b41e092
def pytorch_funcify_FunctionGraph(
    fgraph,
    node=None,
    fgraph_name="pytorch_funcified_fgraph",
    conversion_func=pytorch_funcify,
    **kwargs,
):
    """Convert a ``FunctionGraph`` into a Python callable via ``fgraph_to_python``.

    Parameters
    ----------
    fgraph
        The graph to convert.
    node
        Not used by this function.
    fgraph_name
        Name handed to ``fgraph_to_python`` for the generated function.
    conversion_func
        Callable used to convert the operations of the graph. Defaults to
        ``pytorch_funcify``.
    **kwargs
        Extra keyword arguments forwarded to ``fgraph_to_python``.

    Returns
    -------
    The result of ``fgraph_to_python`` for ``fgraph``.
    """
    # The converter is passed positionally (it is the op conversion function
    # for this graph) and also as a keyword, so that it travels along with the
    # remaining keyword arguments. An explicit ``conversion_func`` cannot be
    # present in ``kwargs`` because it is a named parameter of this function.
    built_kwargs = {"conversion_func": conversion_func, **kwargs}
    return fgraph_to_python(
        fgraph,
        conversion_func,
        type_conversion_fn=pytorch_typify,
        fgraph_name=fgraph_name,
        **built_kwargs,
    )
......@@ -173,11 +175,8 @@ def pytorch_funcify_OpFromGraph(op, node, **kwargs):
# Apply inner rewrites
PYTORCH.optimizer(op.fgraph)
fgraph_fn = pytorch_funcify(op.fgraph, **kwargs, squeeze_output=True)
# Disable one step inlining to prevent torch from trying to import local functions
# defined in `pytorch_funcify`
return torch.compiler.disable(fgraph_fn, recursive=False)
return fgraph_fn
@pytorch_funcify.register(TensorFromScalar)
......
import torch
import torch.compiler
from pytensor.graph import FunctionGraph
from pytensor.link.pytorch.dispatch import pytorch_funcify
......@@ -11,12 +10,13 @@ def funcify_Blockwise(op: Blockwise, node, *args, **kwargs):
batched_dims = op.batch_ndim(node)
core_node = op._create_dummy_core_node(node.inputs)
core_fgraph = FunctionGraph(inputs=core_node.inputs, outputs=core_node.outputs)
inner_func = pytorch_funcify(core_fgraph, squeeze_output=len(node.outputs) == 1)
inner_func = pytorch_funcify(
core_fgraph, squeeze_output=len(node.outputs) == 1, **kwargs
)
for _ in range(batched_dims):
inner_func = torch.vmap(inner_func)
@torch.compiler.disable(recursive=False)
def batcher(*inputs):
op._check_runtime_broadcast(node, inputs)
# broadcast on batched_dims
......
import copy
from typing import Any
from pytensor.graph.basic import Variable
from pytensor.link.basic import JITLinker
from pytensor.link.utils import unique_name_generator
class PytorchLinker(JITLinker):
"""A `Linker` that compiles NumPy-based operations using torch.compile."""
def __init__(self, *args, **kwargs):
    """Set up the base linker and an empty registry of generated functors."""
    super().__init__(*args, **kwargs)
    # ``(name, functor)`` pairs; filled while converting a graph and handed
    # over to the compiled wrapper in ``jit_compile``.
    self.gen_functors: list[tuple[str, Any]] = []
def input_filter(self, inp: Any) -> Any:
from pytensor.link.pytorch.dispatch import pytorch_typify
......@@ -18,14 +24,68 @@ class PytorchLinker(JITLinker):
def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
    """Convert ``fgraph`` with ``pytorch_funcify`` and record every generated functor.

    Parameters
    ----------
    fgraph
        The graph to convert.
    input_storage
        Forwarded unchanged to ``pytorch_funcify``.
    storage_map
        Forwarded unchanged to ``pytorch_funcify``.
    **kwargs
        Extra keyword arguments for ``pytorch_funcify``. They take precedence
        over the ``unique_name`` and ``conversion_func`` defaults set here.

    Returns
    -------
    The callable returned by ``pytorch_funcify``.
    """
    from pytensor.link.pytorch.dispatch import pytorch_funcify

    # We want to have globally unique names
    # across the entire pytensor graph, not
    # just the subgraph
    generator = unique_name_generator(["torch_linker"])

    # Ensure that torch is aware of the generated
    # code so we can compile without graph breaks
    def conversion_func_register(*args, **inner_kwargs):
        functor = pytorch_funcify(*args, **inner_kwargs)
        # The name generator is expected to arrive with the keyword
        # arguments of each conversion call.
        name = inner_kwargs["unique_name"](functor)
        self.gen_functors.append((f"_{name}", functor))
        return functor

    built_kwargs = {
        "unique_name": generator,
        "conversion_func": conversion_func_register,
        **kwargs,
    }
    return pytorch_funcify(
        fgraph, input_storage=input_storage, storage_map=storage_map, **built_kwargs
    )
def jit_compile(self, fn):
    """Wrap ``fn`` in a callable that runs it through ``torch.compile``.

    The returned object takes over the functors collected in
    ``self.gen_functors``; the linker's list is reset afterwards.

    Parameters
    ----------
    fn
        The Python callable to compile.

    Returns
    -------
    A ``wrapper`` instance; calling it calls the compiled ``fn``.
    """
    import torch

    class wrapper:
        """
        Pytorch would fail compiling our method when trying
        to resolve some of the methods returned from dispatch
        calls. We want to be careful to not leak the methods,
        so this class just holds them and provisions the expected
        location accordingly

        https://discuss.pytorch.org/t/closures-are-being-gcd-and-causing-failures-to-compile/213319
        """

        def __init__(self, fn, gen_functors):
            # Store the functors first so ``__del__`` always finds the
            # attribute, even if ``torch.compile`` raises.
            self.gen_functors = copy.copy(gen_functors)
            self.fn = torch.compile(fn)

        def __call__(self, *args, **kwargs):
            import pytensor.link.utils

            # set attrs (names are stored with a leading underscore)
            for n, functor in self.gen_functors:
                setattr(pytensor.link.utils, n[1:], functor)

            try:
                return self.fn(*args, **kwargs)
            finally:
                # unset attrs, also when the compiled call raised, so the
                # references never outlive the call
                for n, _ in self.gen_functors:
                    if hasattr(pytensor.link.utils, n[1:]):
                        delattr(pytensor.link.utils, n[1:])

        def __del__(self):
            del self.gen_functors

    res = wrapper(fn, self.gen_functors)
    # The wrapper holds its own copy; start fresh for the next conversion.
    self.gen_functors = []
    return res
def create_thunk_inputs(self, storage_map):
thunk_inputs = []
......
......@@ -675,6 +675,7 @@ def fgraph_to_python(
local_env: dict[Any, Any] | None = None,
get_name_for_object: Callable[[Any], str] = get_name_for_object,
squeeze_output: bool = False,
unique_name: Callable | None = None,
**kwargs,
) -> Callable:
"""Convert a `FunctionGraph` into a regular Python function.
......@@ -706,6 +707,8 @@ def fgraph_to_python(
get_name_for_object
A function used to provide names for the objects referenced within the
generated function.
unique_name
A function to make random function names for generated code
squeeze_output
If the `FunctionGraph` has only one output and this option is
``True``, return the single output instead of a tuple with the output.
......@@ -719,8 +722,12 @@ def fgraph_to_python(
if storage_map is None:
storage_map = {}
if not unique_name:
unique_name = unique_name_generator([fgraph_name])
# make sure we plumb this through
kwargs["unique_name"] = unique_name
if global_env is None:
global_env = {}
......
......@@ -22,6 +22,7 @@ from pytensor.tensor.type import matrices, matrix, scalar, vector
torch = pytest.importorskip("torch")
torch_dispatch = pytest.importorskip("pytensor.link.pytorch.dispatch.basic")
optimizer = RewriteDatabaseQuery(
......@@ -335,7 +336,7 @@ def test_pytorch_OpFromGraph():
ofg_2 = OpFromGraph([x, y], [x * y, x - y])
o1, o2 = ofg_2(y, z)
out = ofg_1(x, o1) + o2
out = ofg_1(x, o1) / o2
xv = np.ones((2, 2), dtype=config.floatX)
yv = np.ones((2, 2), dtype=config.floatX) * 3
......@@ -343,3 +344,33 @@ def test_pytorch_OpFromGraph():
f = FunctionGraph([x, y, z], [out])
compare_pytorch_and_py(f, [xv, yv, zv])
def test_pytorch_link_references():
    """Dispatched functions are reachable in ``pytensor.link.utils`` only during a call.

    ``inner_fn`` asserts that its own name is listed in the module while it
    runs; after the compiled function returns, the name must be gone again.
    """
    import pytensor.link.utils as utils_mod

    class BasicOp(Op):
        def __init__(self):
            super().__init__()

        def make_node(self, *inputs):
            outputs = [inp.type() for inp in inputs]
            return Apply(self, list(inputs), outputs)

        def perform(self, *_):
            # The PYTORCH mode must never fall back to the Python implementation.
            raise RuntimeError("In perform")

    @torch_dispatch.pytorch_funcify.register(BasicOp)
    def funcify_basic_op(op, node, **kwargs):
        def inner_fn(x):
            assert "inner_fn" in dir(utils_mod), "not available during dispatch"
            return x

        return inner_fn

    x = vector("x")
    basic_op = BasicOp()
    f = function([x], basic_op(x), mode="PYTORCH")
    f(torch.ones(3))
    assert "inner_fn" not in dir(utils_mod), "function call reference leaked"
......@@ -29,7 +29,6 @@ class TestOp(Op):
@basic.pytorch_funcify.register(TestOp)
def evaluate_test_op(op, **_):
@torch.compiler.disable(recursive=False)
def func(a, b):
op.call_shapes.extend(map(torch.Tensor.size, [a, b]))
return a @ b
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论