Unverified commit 4fa9bb87 authored by Ricardo Vieira, committed by GitHub

PyTorch inline constants in dispatch to avoid graph breaks (#1118)

* Split and inverse * PyTorch inline constants in dispatch to avoid graph breaks
Parent: 17748b7d
...@@ -8,6 +8,7 @@ import torch.compiler ...@@ -8,6 +8,7 @@ import torch.compiler
from pytensor.compile import PYTORCH from pytensor.compile import PYTORCH
from pytensor.compile.builders import OpFromGraph from pytensor.compile.builders import OpFromGraph
from pytensor.compile.ops import DeepCopyOp from pytensor.compile.ops import DeepCopyOp
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.ifelse import IfElse from pytensor.ifelse import IfElse
from pytensor.link.utils import fgraph_to_python from pytensor.link.utils import fgraph_to_python
...@@ -19,6 +20,7 @@ from pytensor.tensor.basic import ( ...@@ -19,6 +20,7 @@ from pytensor.tensor.basic import (
Eye, Eye,
Join, Join,
MakeVector, MakeVector,
Split,
TensorFromScalar, TensorFromScalar,
) )
...@@ -120,11 +122,20 @@ def pytorch_funcify_arange(op, **kwargs): ...@@ -120,11 +122,20 @@ def pytorch_funcify_arange(op, **kwargs):
@pytorch_funcify.register(Join)
def pytorch_funcify_Join(op, node, **kwargs):
    """Return a torch implementation of ``Join`` (concatenation).

    When the ``axis`` input is a compile-time ``Constant`` it is inlined as a
    plain Python ``int`` so that ``torch.compile`` does not hit a graph break
    on a data-dependent ``dim`` argument.
    """
    axis = node.inputs[0]

    if isinstance(axis, Constant):
        # Inline the static axis; the runtime axis argument is ignored.
        static_axis = int(axis.data)

        def join_constant_axis(_, *tensors):
            return torch.cat(tensors, dim=static_axis)

        return join_constant_axis

    # Fallback: the axis is only known at runtime (may cause a graph break).
    def join(axis, *tensors):
        return torch.cat(tensors, dim=axis)

    return join
...@@ -172,7 +183,6 @@ def pytorch_funcify_IfElse(op, **kwargs): ...@@ -172,7 +183,6 @@ def pytorch_funcify_IfElse(op, **kwargs):
@pytorch_funcify.register(OpFromGraph) @pytorch_funcify.register(OpFromGraph)
def pytorch_funcify_OpFromGraph(op, node, **kwargs): def pytorch_funcify_OpFromGraph(op, node, **kwargs):
kwargs.pop("storage_map", None) kwargs.pop("storage_map", None)
# Apply inner rewrites # Apply inner rewrites
PYTORCH.optimizer(op.fgraph) PYTORCH.optimizer(op.fgraph)
fgraph_fn = pytorch_funcify(op.fgraph, **kwargs, squeeze_output=True) fgraph_fn = pytorch_funcify(op.fgraph, **kwargs, squeeze_output=True)
...@@ -185,3 +195,23 @@ def pytorch_funcify_TensorFromScalar(op, **kwargs): ...@@ -185,3 +195,23 @@ def pytorch_funcify_TensorFromScalar(op, **kwargs):
return torch.as_tensor(x) return torch.as_tensor(x)
return tensorfromscalar return tensorfromscalar
@pytorch_funcify.register(Split)
def pytorch_funcify_Split(op, node, **kwargs):
    """Return a torch implementation of ``Split``.

    When both the split dimension and the split sizes are compile-time
    ``Constant``s they are inlined as Python ints, which avoids the
    ``.item()`` / ``.tolist()`` calls in the fallback that break the
    ``torch.compile`` graph.
    """
    _x, dim, split_sizes = node.inputs

    if isinstance(dim, Constant) and isinstance(split_sizes, Constant):
        static_dim = int(dim.data)
        static_sizes = tuple(int(size) for size in split_sizes.data)

        def split_constant_axis_and_sizes(x, *_):
            return x.split(static_sizes, dim=static_dim)

        return split_constant_axis_and_sizes

    # Fallback: sizes/axis are runtime tensors (may cause a graph break).
    def inner_fn(x, dim, split_amounts):
        return x.split(split_amounts.tolist(), dim=dim.item())

    return inner_fn
...@@ -5,12 +5,18 @@ import torch ...@@ -5,12 +5,18 @@ import torch
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.scalar.basic import ( from pytensor.scalar.basic import (
Cast, Cast,
Invert,
ScalarOp, ScalarOp,
) )
from pytensor.scalar.loop import ScalarLoop from pytensor.scalar.loop import ScalarLoop
from pytensor.scalar.math import Softplus from pytensor.scalar.math import Softplus
@pytorch_funcify.register(Invert)
def pytorch_funcify_invert(op, node, **kwargs):
    """Return the torch callable implementing elementwise bitwise inversion."""
    return torch.bitwise_not
@pytorch_funcify.register(ScalarOp) @pytorch_funcify.register(ScalarOp)
def pytorch_funcify_ScalarOp(op, node, **kwargs): def pytorch_funcify_ScalarOp(op, node, **kwargs):
"""Return pytorch function that implements the same computation as the Scalar Op. """Return pytorch function that implements the same computation as the Scalar Op.
......
import torch import torch
from pytensor.graph.basic import Constant
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast
@pytorch_funcify.register(Reshape) @pytorch_funcify.register(Reshape)
def pytorch_funcify_Reshape(op, node, **kwargs): def pytorch_funcify_Reshape(op, node, **kwargs):
_, shape = node.inputs
if isinstance(shape, Constant):
constant_shape = tuple(int(dim) for dim in shape.data)
def reshape_constant_shape(x, *_):
return torch.reshape(x, constant_shape)
return reshape_constant_shape
else:
def reshape(x, shape): def reshape(x, shape):
return torch.reshape(x, tuple(shape)) return torch.reshape(x, tuple(shape))
......
from pytensor.graph.basic import Constant
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.subtensor import ( from pytensor.tensor.subtensor import (
AdvancedIncSubtensor, AdvancedIncSubtensor,
...@@ -23,7 +24,21 @@ def check_negative_steps(indices): ...@@ -23,7 +24,21 @@ def check_negative_steps(indices):
@pytorch_funcify.register(Subtensor) @pytorch_funcify.register(Subtensor)
def pytorch_funcify_Subtensor(op, node, **kwargs): def pytorch_funcify_Subtensor(op, node, **kwargs):
idx_list = op.idx_list idx_list = op.idx_list
x, *idxs = node.inputs
if all(isinstance(idx, Constant) for idx in idxs):
# Use constant indices to avoid graph break
constant_indices = indices_from_subtensor(
[int(idx.data) for idx in idxs], idx_list
)
check_negative_steps(constant_indices)
def constant_index_subtensor(x, *_):
return x[constant_indices]
return constant_index_subtensor
# Fallback that will introduce a graph break
def subtensor(x, *flattened_indices): def subtensor(x, *flattened_indices):
indices = indices_from_subtensor(flattened_indices, idx_list) indices = indices_from_subtensor(flattened_indices, idx_list)
check_negative_steps(indices) check_negative_steps(indices)
......
...@@ -37,6 +37,9 @@ class PytorchLinker(JITLinker): ...@@ -37,6 +37,9 @@ class PytorchLinker(JITLinker):
def jit_compile(self, fn): def jit_compile(self, fn):
import torch import torch
# flag that tend to help our graphs
torch._dynamo.config.capture_dynamic_output_shape_ops = True
from pytensor.link.pytorch.dispatch import pytorch_typify from pytensor.link.pytorch.dispatch import pytorch_typify
class wrapper: class wrapper:
......
...@@ -471,3 +471,53 @@ def test_ScalarLoop_Elemwise_multi_carries(): ...@@ -471,3 +471,53 @@ def test_ScalarLoop_Elemwise_multi_carries():
compare_pytorch_and_py( compare_pytorch_and_py(
f, args, assert_fn=partial(np.testing.assert_allclose, rtol=1e-6) f, args, assert_fn=partial(np.testing.assert_allclose, rtol=1e-6)
) )
rng = np.random.default_rng(42849)


# NOTE: the rng draws below must stay in this order so the seeded values
# remain reproducible across test runs.
@pytest.mark.parametrize(
    "n_splits, axis, values, sizes",
    [
        (0, 0, rng.normal(size=20).astype(config.floatX), []),
        (5, 0, rng.normal(size=5).astype(config.floatX), rng.multinomial(5, np.ones(5) / 5)),
        (5, 0, rng.normal(size=10).astype(config.floatX), rng.multinomial(10, np.ones(5) / 5)),
        (5, -1, rng.normal(size=(11, 7)).astype(config.floatX), rng.multinomial(7, np.ones(5) / 5)),
        (5, -2, rng.normal(size=(11, 7)).astype(config.floatX), rng.multinomial(11, np.ones(5) / 5)),
    ],
)
def test_Split(n_splits, axis, values, sizes):
    """Compare the PyTorch backend's ``Split`` against the Python reference."""
    x = pt.tensor("i", shape=values.shape, dtype=config.floatX)
    sizes_var = pt.vector("s", dtype="int64")
    outs = pt.split(x, sizes_var, n_splits, axis=axis)
    assert len(outs) == n_splits

    # With zero splits there is nothing to compile or compare.
    if n_splits == 0:
        return

    fgraph = FunctionGraph(
        inputs=[x, sizes_var], outputs=[outs] if n_splits == 1 else outs
    )
    compare_pytorch_and_py(fgraph, [values, sizes])
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment