提交 69efc68b authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Handle inplace rewrites correctly in dispatch of OpFromGraph and Scan

JAX needs no special handling because it excludes inplace rewrites.
上级 ad1af2ea
......@@ -5,6 +5,7 @@ import copyreg
import logging
import time
import warnings
from collections.abc import Sequence
from itertools import chain
from typing import TYPE_CHECKING
......@@ -168,6 +169,59 @@ class Supervisor(Feature):
raise InconsistencyError(f"Trying to destroy a protected variable: {r}")
def add_supervisor_to_fgraph(
fgraph: FunctionGraph,
input_specs: Sequence[SymbolicInput],
accept_inplace: bool = False,
) -> None:
"""Setup Supervisor Feature in a FunctionGraph, so that inplace rewrites can be used.
Parameters
----------
fgraph: FunctionGraph
The FunctionGraph to setup the Supervisor Feature in.
input_specs: Sequence of SymbolicInput
The input specifications for the FunctionGraph.
Inputs with the attribute `mutable=False` and which are not already destroyed by an inplace operation
(if `accept_inplace` is True) will be protected from inplace operations.
Otherwise, they will be allowed to be destroyed.
accept_inplace: bool
Whether to allow inplace operations to already be present in the graph.
Raises
------
TypeError
If inplace operations are not allowed and the graph already contains inplace operations.
"""
has_destroy_handler = hasattr(fgraph, "destroyers")
if not (has_destroy_handler and accept_inplace):
# Check if fgraph already contains destructive operations,
# in which case we need to add a DestroyHandler or raise an error
for node in fgraph.apply_nodes:
if node.op.destroy_map:
if not accept_inplace:
raise TypeError(
f"Graph must not contain inplace operations: {node}"
)
else:
has_destroy_handler = True
fgraph.attach_feature(DestroyHandler())
break
# Protect all immutable inputs from inplace operations.
fgraph.attach_feature(
Supervisor(
input
for spec, input in zip(input_specs, fgraph.inputs, strict=True)
if not (
spec.mutable or has_destroy_handler and fgraph.has_destroyers([input])
)
)
)
def std_fgraph(
input_specs: list[SymbolicInput],
output_specs: list[SymbolicOutput],
......@@ -229,24 +283,8 @@ def std_fgraph(
found_updates.extend(map(SymbolicOutput, updates))
for node in fgraph.apply_nodes:
if node.op.destroy_map:
if not accept_inplace:
raise TypeError(f"Graph must not contain inplace operations: {node}")
else:
fgraph.attach_feature(DestroyHandler())
break
# We need to protect all immutable inputs from inplace operations.
fgraph.attach_feature(
Supervisor(
input
for spec, input in zip(input_specs, fgraph.inputs, strict=True)
if not (
spec.mutable
or (hasattr(fgraph, "destroyers") and fgraph.has_destroyers([input]))
)
)
add_supervisor_to_fgraph(
fgraph=fgraph, input_specs=input_specs, accept_inplace=accept_inplace
)
# If named nodes are replaced, keep the name
......
......@@ -138,7 +138,11 @@ class AddDestroyHandler(GraphRewriter):
break
if not supervisor_added:
warnings.warn(
f"A Supervisor feature is missing from {fgraph}.",
(
f"A Supervisor feature is missing from {fgraph}.\n"
"This is needed for inplace rewrites. Either exclude inplace rewrites or add a Supervisor feature.\n"
"A Supervisor feature can be added via `pytensor.compile.function.types.add_supervisor_to_fgraph`."
),
stacklevel=3,
)
......
import jax
import jax.numpy as jnp
from pytensor.compile.mode import JAX
from pytensor.compile.mode import JAX, get_mode
from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.scan.op import Scan
......@@ -19,7 +19,9 @@ def jax_funcify_Scan(op: Scan, **kwargs):
)
# Optimize inner graph (exclude any defalut rewrites that are incompatible with JAX mode)
rewriter = op.mode_instance.excluding(*JAX._optimizer.exclude).optimizer
rewriter = (
get_mode(op.mode).including("jax").excluding(*JAX._optimizer.exclude).optimizer
)
rewriter(op.fgraph)
scan_inner_func = jax_funcify(op.fgraph, **kwargs)
......
......@@ -16,9 +16,10 @@ from numba.core.errors import NumbaWarning, TypingError
from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401
from numba.extending import box, overload
from pytensor import config
from pytensor import In, config
from pytensor.compile import NUMBA
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function.types import add_supervisor_to_fgraph
from pytensor.compile.ops import DeepCopyOp
from pytensor.graph.basic import Apply
from pytensor.graph.fg import FunctionGraph
......@@ -430,7 +431,13 @@ def numba_funcify_OpFromGraph(op, node=None, **kwargs):
# TODO: Not sure this is the right place to do this, should we have a rewrite that
# explicitly triggers the optimization of the inner graphs of OpFromGraph?
# The C-code defers it to the make_thunk phase
NUMBA.optimizer(op.fgraph)
fgraph = op.fgraph
add_supervisor_to_fgraph(
fgraph=fgraph,
input_specs=[In(x, borrow=True, mutable=False) for x in fgraph.inputs],
accept_inplace=True,
)
NUMBA.optimizer(fgraph)
fgraph_fn = numba_njit(numba_funcify(op.fgraph, **kwargs))
if len(op.fgraph.outputs) == 1:
......
......@@ -4,7 +4,9 @@ import numpy as np
from numba import types
from numba.extending import overload
from pytensor.compile.mode import NUMBA
from pytensor import In
from pytensor.compile.function.types import add_supervisor_to_fgraph
from pytensor.compile.mode import NUMBA, get_mode
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import (
create_arg_string,
......@@ -59,11 +61,18 @@ def numba_funcify_Scan(op, node, **kwargs):
# explicitly triggers the optimization of the inner graphs of Scan?
# The C-code defers it to the make_thunk phase
rewriter = (
op.mode_instance.including("numba")
get_mode(op.mode)
.including("numba")
.excluding(*NUMBA._optimizer.exclude)
.optimizer
)
rewriter(op.fgraph)
fgraph = op.fgraph
add_supervisor_to_fgraph(
fgraph=fgraph,
input_specs=[In(x, borrow=True, mutable=False) for x in fgraph.inputs],
accept_inplace=True,
)
rewriter(fgraph)
scan_inner_func = numba_basic.numba_njit(numba_funcify(op.fgraph))
......
......@@ -5,8 +5,10 @@ import numpy as np
import torch
import torch.compiler
from pytensor import In
from pytensor.compile import PYTORCH
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function.types import add_supervisor_to_fgraph
from pytensor.compile.ops import DeepCopyOp
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
......@@ -185,6 +187,13 @@ def pytorch_funcify_OpFromGraph(op, node, **kwargs):
kwargs.pop("storage_map", None)
# Apply inner rewrites
PYTORCH.optimizer(op.fgraph)
fgraph = op.fgraph
add_supervisor_to_fgraph(
fgraph=fgraph,
input_specs=[In(x, borrow=True, mutable=False) for x in fgraph.inputs],
accept_inplace=True,
)
PYTORCH.optimizer(fgraph)
fgraph_fn = pytorch_funcify(op.fgraph, **kwargs, squeeze_output=True)
return fgraph_fn
......
......@@ -57,6 +57,7 @@ import pytensor.link.utils as link_utils
from pytensor import tensor as pt
from pytensor.compile.builders import construct_nominal_fgraph, infer_shape
from pytensor.compile.function.pfunc import pfunc
from pytensor.compile.function.types import add_supervisor_to_fgraph
from pytensor.compile.io import In, Out
from pytensor.compile.mode import Mode, get_mode
from pytensor.compile.profiling import register_profiler_printer
......@@ -834,8 +835,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
self.n_outer_inputs = info.n_outer_inputs
self.n_outer_outputs = info.n_outer_outputs
_ = self.prepare_fgraph(self.fgraph)
if any(node.op.destroy_map for node in self.fgraph.apply_nodes):
raise InconsistencyError(
"Inner-graphs must not contain in-place operations."
......@@ -1394,23 +1393,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
fgraph.update_mapping = update_mapping
from pytensor.compile.function.types import Supervisor
from pytensor.graph.destroyhandler import DestroyHandler
for node in fgraph.apply_nodes:
if node.op.destroy_map:
fgraph.attach_feature(DestroyHandler())
break
fgraph.attach_feature(
Supervisor(
inp
for spec, inp in zip(wrapped_inputs, fgraph.inputs, strict=True)
if not (
getattr(spec, "mutable", None)
or (hasattr(fgraph, "destroyers") and fgraph.has_destroyers([inp]))
)
)
add_supervisor_to_fgraph(
fgraph=fgraph, input_specs=wrapped_inputs, accept_inplace=True
)
return wrapped_inputs, wrapped_outputs
......
......@@ -835,6 +835,39 @@ def test_OpFromGraph():
compare_numba_and_py([x, y, z], [out], [xv, yv, zv])
@pytest.mark.filterwarnings("error")
def test_ofg_inner_inplace():
x = pt.vector("x")
set0 = x[0].set(1) # SetSubtensor should not inplace on x
exp_x = pt.exp(x)
set1 = exp_x[0].set(1) # SetSubtensor should inplace on exp_x
ofg0 = OpFromGraph([x], [set0])
ofg1 = OpFromGraph([x], [set1])
y, z = pt.vectors("y", "z")
fn = function([y, z], [ofg0(y), ofg1(z)], mode="NUMBA")
fn_ofg0 = fn.maker.fgraph.outputs[0].owner.op
assert isinstance(fn_ofg0, OpFromGraph)
fn_set0 = fn_ofg0.fgraph.outputs[0]
assert fn_set0.owner.op.destroy_map == {}
fn_ofg1 = fn.maker.fgraph.outputs[1].owner.op
assert isinstance(fn_ofg1, OpFromGraph)
fn_set1 = fn_ofg1.fgraph.outputs[0]
assert fn_set1.owner.op.destroy_map == {0: [0]}
x_test = np.array([0, 1, 1], dtype=config.floatX)
y_test = np.array([0, 1, 1], dtype=config.floatX)
res0, res1 = fn(x_test, y_test)
# Check inputs were not mutated
np.testing.assert_allclose(x_test, [0, 1, 1])
np.testing.assert_allclose(y_test, [0, 1, 1])
# Check outputs are correct
np.testing.assert_allclose(res0, [1, 1, 1])
np.testing.assert_allclose(res1, [1, np.e, np.e])
@pytest.mark.filterwarnings("error")
def test_cache_warning_suppressed():
x = pt.vector("x", shape=(5,), dtype="float64")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论