Unverified 提交 4fa9bb87 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: GitHub

PyTorch inline constants in dispatch to avoid graph breaks (#1118)

* Split and inverse * PyTorch inline constants in dispatch to avoid graph breaks
上级 17748b7d
......@@ -8,6 +8,7 @@ import torch.compiler
from pytensor.compile import PYTORCH
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.ops import DeepCopyOp
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.ifelse import IfElse
from pytensor.link.utils import fgraph_to_python
......@@ -19,6 +20,7 @@ from pytensor.tensor.basic import (
Eye,
Join,
MakeVector,
Split,
TensorFromScalar,
)
......@@ -120,14 +122,23 @@ def pytorch_funcify_arange(op, **kwargs):
@pytorch_funcify.register(Join)
def pytorch_funcify_Join(op, **kwargs):
def join(axis, *tensors):
# tensors could also be tuples, and in this case they don't have a ndim
tensors = [torch.tensor(tensor) for tensor in tensors]
def pytorch_funcify_Join(op, node, **kwargs):
axis = node.inputs[0]
return torch.cat(tensors, dim=axis)
if isinstance(axis, Constant):
axis = int(axis.data)
return join
def join_constant_axis(_, *tensors):
return torch.cat(tensors, dim=axis)
return join_constant_axis
else:
def join(axis, *tensors):
return torch.cat(tensors, dim=axis)
return join
@pytorch_funcify.register(Eye)
......@@ -172,7 +183,6 @@ def pytorch_funcify_IfElse(op, **kwargs):
@pytorch_funcify.register(OpFromGraph)
def pytorch_funcify_OpFromGraph(op, node, **kwargs):
kwargs.pop("storage_map", None)
# Apply inner rewrites
PYTORCH.optimizer(op.fgraph)
fgraph_fn = pytorch_funcify(op.fgraph, **kwargs, squeeze_output=True)
......@@ -185,3 +195,23 @@ def pytorch_funcify_TensorFromScalar(op, **kwargs):
return torch.as_tensor(x)
return tensorfromscalar
@pytorch_funcify.register(Split)
def pytorch_funcify_Split(op, node, **kwargs):
x, dim, split_sizes = node.inputs
if isinstance(dim, Constant) and isinstance(split_sizes, Constant):
dim = int(dim.data)
split_sizes = tuple(int(size) for size in split_sizes.data)
def split_constant_axis_and_sizes(x, *_):
return x.split(split_sizes, dim=dim)
return split_constant_axis_and_sizes
else:
def inner_fn(x, dim, split_amounts):
return x.split(split_amounts.tolist(), dim=dim.item())
return inner_fn
......@@ -5,12 +5,18 @@ import torch
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.scalar.basic import (
Cast,
Invert,
ScalarOp,
)
from pytensor.scalar.loop import ScalarLoop
from pytensor.scalar.math import Softplus
@pytorch_funcify.register(Invert)
def pytorch_funcify_invert(op, node, **kwargs):
return torch.bitwise_not
@pytorch_funcify.register(ScalarOp)
def pytorch_funcify_ScalarOp(op, node, **kwargs):
"""Return pytorch function that implements the same computation as the Scalar Op.
......
import torch
from pytensor.graph.basic import Constant
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast
@pytorch_funcify.register(Reshape)
def pytorch_funcify_Reshape(op, node, **kwargs):
def reshape(x, shape):
return torch.reshape(x, tuple(shape))
_, shape = node.inputs
return reshape
if isinstance(shape, Constant):
constant_shape = tuple(int(dim) for dim in shape.data)
def reshape_constant_shape(x, *_):
return torch.reshape(x, constant_shape)
return reshape_constant_shape
else:
def reshape(x, shape):
return torch.reshape(x, tuple(shape))
return reshape
@pytorch_funcify.register(Shape)
......
from pytensor.graph.basic import Constant
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
......@@ -23,7 +24,21 @@ def check_negative_steps(indices):
@pytorch_funcify.register(Subtensor)
def pytorch_funcify_Subtensor(op, node, **kwargs):
idx_list = op.idx_list
x, *idxs = node.inputs
if all(isinstance(idx, Constant) for idx in idxs):
# Use constant indices to avoid graph break
constant_indices = indices_from_subtensor(
[int(idx.data) for idx in idxs], idx_list
)
check_negative_steps(constant_indices)
def constant_index_subtensor(x, *_):
return x[constant_indices]
return constant_index_subtensor
# Fallback that will introduce a graph break
def subtensor(x, *flattened_indices):
indices = indices_from_subtensor(flattened_indices, idx_list)
check_negative_steps(indices)
......
......@@ -37,6 +37,9 @@ class PytorchLinker(JITLinker):
def jit_compile(self, fn):
import torch
# flag that tend to help our graphs
torch._dynamo.config.capture_dynamic_output_shape_ops = True
from pytensor.link.pytorch.dispatch import pytorch_typify
class wrapper:
......
......@@ -471,3 +471,53 @@ def test_ScalarLoop_Elemwise_multi_carries():
compare_pytorch_and_py(
f, args, assert_fn=partial(np.testing.assert_allclose, rtol=1e-6)
)
rng = np.random.default_rng(42849)
@pytest.mark.parametrize(
"n_splits, axis, values, sizes",
[
(
0,
0,
rng.normal(size=20).astype(config.floatX),
[],
),
(
5,
0,
rng.normal(size=5).astype(config.floatX),
rng.multinomial(5, np.ones(5) / 5),
),
(
5,
0,
rng.normal(size=10).astype(config.floatX),
rng.multinomial(10, np.ones(5) / 5),
),
(
5,
-1,
rng.normal(size=(11, 7)).astype(config.floatX),
rng.multinomial(7, np.ones(5) / 5),
),
(
5,
-2,
rng.normal(size=(11, 7)).astype(config.floatX),
rng.multinomial(11, np.ones(5) / 5),
),
],
)
def test_Split(n_splits, axis, values, sizes):
i = pt.tensor("i", shape=values.shape, dtype=config.floatX)
s = pt.vector("s", dtype="int64")
g = pt.split(i, s, n_splits, axis=axis)
assert len(g) == n_splits
if n_splits == 0:
return
g_fg = FunctionGraph(inputs=[i, s], outputs=[g] if n_splits == 1 else g)
compare_pytorch_and_py(g_fg, [values, sizes])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论