Commit 69efc68b authored by Ricardo Vieira, committed by Ricardo Vieira

Handle inplace rewrites correctly in dispatch of OpFromGraph and Scan

JAX needs no special handling because it excludes inplace rewrites.
Parent ad1af2ea
...@@ -5,6 +5,7 @@ import copyreg ...@@ -5,6 +5,7 @@ import copyreg
import logging import logging
import time import time
import warnings import warnings
from collections.abc import Sequence
from itertools import chain from itertools import chain
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
...@@ -168,6 +169,59 @@ class Supervisor(Feature): ...@@ -168,6 +169,59 @@ class Supervisor(Feature):
raise InconsistencyError(f"Trying to destroy a protected variable: {r}") raise InconsistencyError(f"Trying to destroy a protected variable: {r}")
def add_supervisor_to_fgraph(
    fgraph: FunctionGraph,
    input_specs: Sequence[SymbolicInput],
    accept_inplace: bool = False,
) -> None:
    """Attach a Supervisor Feature to ``fgraph`` so inplace rewrites can be used safely.

    Parameters
    ----------
    fgraph : FunctionGraph
        The FunctionGraph that should receive the Supervisor Feature.
    input_specs : Sequence of SymbolicInput
        Input specifications, zipped positionally with ``fgraph.inputs``.
        Inputs whose spec has ``mutable=False`` and which are not already
        destroyed by an inplace operation (when ``accept_inplace`` is True)
        are protected from inplace operations; all other inputs may be destroyed.
    accept_inplace : bool
        Whether inplace operations already present in the graph are acceptable.

    Raises
    ------
    TypeError
        If ``accept_inplace`` is False and the graph already contains inplace operations.
    """
    has_destroy_handler = hasattr(fgraph, "destroyers")

    if not (has_destroy_handler and accept_inplace):
        # Look for a pre-existing destructive operation: either reject the
        # graph outright or lazily attach a DestroyHandler so the inplace
        # operations are tracked.
        destructive_node = next(
            (node for node in fgraph.apply_nodes if node.op.destroy_map), None
        )
        if destructive_node is not None:
            if not accept_inplace:
                raise TypeError(
                    f"Graph must not contain inplace operations: {destructive_node}"
                )
            has_destroy_handler = True
            fgraph.attach_feature(DestroyHandler())

    # Protect every input that is neither explicitly mutable nor already
    # destroyed by an existing (accepted) inplace operation.
    protected_inputs = (
        inp
        for spec, inp in zip(input_specs, fgraph.inputs, strict=True)
        if not (
            spec.mutable or (has_destroy_handler and fgraph.has_destroyers([inp]))
        )
    )
    fgraph.attach_feature(Supervisor(protected_inputs))
def std_fgraph( def std_fgraph(
input_specs: list[SymbolicInput], input_specs: list[SymbolicInput],
output_specs: list[SymbolicOutput], output_specs: list[SymbolicOutput],
...@@ -229,24 +283,8 @@ def std_fgraph( ...@@ -229,24 +283,8 @@ def std_fgraph(
found_updates.extend(map(SymbolicOutput, updates)) found_updates.extend(map(SymbolicOutput, updates))
for node in fgraph.apply_nodes: add_supervisor_to_fgraph(
if node.op.destroy_map: fgraph=fgraph, input_specs=input_specs, accept_inplace=accept_inplace
if not accept_inplace:
raise TypeError(f"Graph must not contain inplace operations: {node}")
else:
fgraph.attach_feature(DestroyHandler())
break
# We need to protect all immutable inputs from inplace operations.
fgraph.attach_feature(
Supervisor(
input
for spec, input in zip(input_specs, fgraph.inputs, strict=True)
if not (
spec.mutable
or (hasattr(fgraph, "destroyers") and fgraph.has_destroyers([input]))
)
)
) )
# If named nodes are replaced, keep the name # If named nodes are replaced, keep the name
......
...@@ -138,7 +138,11 @@ class AddDestroyHandler(GraphRewriter): ...@@ -138,7 +138,11 @@ class AddDestroyHandler(GraphRewriter):
break break
if not supervisor_added: if not supervisor_added:
warnings.warn( warnings.warn(
f"A Supervisor feature is missing from {fgraph}.", (
f"A Supervisor feature is missing from {fgraph}.\n"
"This is needed for inplace rewrites. Either exclude inplace rewrites or add a Supervisor feature.\n"
"A Supervisor feature can be added via `pytensor.compile.function.types.add_supervisor_to_fgraph`."
),
stacklevel=3, stacklevel=3,
) )
......
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from pytensor.compile.mode import JAX from pytensor.compile.mode import JAX, get_mode
from pytensor.link.jax.dispatch.basic import jax_funcify from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.scan.op import Scan from pytensor.scan.op import Scan
...@@ -19,7 +19,9 @@ def jax_funcify_Scan(op: Scan, **kwargs): ...@@ -19,7 +19,9 @@ def jax_funcify_Scan(op: Scan, **kwargs):
) )
# Optimize inner graph (exclude any default rewrites that are incompatible with JAX mode) # Optimize inner graph (exclude any default rewrites that are incompatible with JAX mode)
rewriter = op.mode_instance.excluding(*JAX._optimizer.exclude).optimizer rewriter = (
get_mode(op.mode).including("jax").excluding(*JAX._optimizer.exclude).optimizer
)
rewriter(op.fgraph) rewriter(op.fgraph)
scan_inner_func = jax_funcify(op.fgraph, **kwargs) scan_inner_func = jax_funcify(op.fgraph, **kwargs)
......
...@@ -16,9 +16,10 @@ from numba.core.errors import NumbaWarning, TypingError ...@@ -16,9 +16,10 @@ from numba.core.errors import NumbaWarning, TypingError
from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401 from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401
from numba.extending import box, overload from numba.extending import box, overload
from pytensor import config from pytensor import In, config
from pytensor.compile import NUMBA from pytensor.compile import NUMBA
from pytensor.compile.builders import OpFromGraph from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function.types import add_supervisor_to_fgraph
from pytensor.compile.ops import DeepCopyOp from pytensor.compile.ops import DeepCopyOp
from pytensor.graph.basic import Apply from pytensor.graph.basic import Apply
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
...@@ -430,7 +431,13 @@ def numba_funcify_OpFromGraph(op, node=None, **kwargs): ...@@ -430,7 +431,13 @@ def numba_funcify_OpFromGraph(op, node=None, **kwargs):
# TODO: Not sure this is the right place to do this, should we have a rewrite that # TODO: Not sure this is the right place to do this, should we have a rewrite that
# explicitly triggers the optimization of the inner graphs of OpFromGraph? # explicitly triggers the optimization of the inner graphs of OpFromGraph?
# The C-code defers it to the make_thunk phase # The C-code defers it to the make_thunk phase
NUMBA.optimizer(op.fgraph) fgraph = op.fgraph
add_supervisor_to_fgraph(
fgraph=fgraph,
input_specs=[In(x, borrow=True, mutable=False) for x in fgraph.inputs],
accept_inplace=True,
)
NUMBA.optimizer(fgraph)
fgraph_fn = numba_njit(numba_funcify(op.fgraph, **kwargs)) fgraph_fn = numba_njit(numba_funcify(op.fgraph, **kwargs))
if len(op.fgraph.outputs) == 1: if len(op.fgraph.outputs) == 1:
......
...@@ -4,7 +4,9 @@ import numpy as np ...@@ -4,7 +4,9 @@ import numpy as np
from numba import types from numba import types
from numba.extending import overload from numba.extending import overload
from pytensor.compile.mode import NUMBA from pytensor import In
from pytensor.compile.function.types import add_supervisor_to_fgraph
from pytensor.compile.mode import NUMBA, get_mode
from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import ( from pytensor.link.numba.dispatch.basic import (
create_arg_string, create_arg_string,
...@@ -59,11 +61,18 @@ def numba_funcify_Scan(op, node, **kwargs): ...@@ -59,11 +61,18 @@ def numba_funcify_Scan(op, node, **kwargs):
# explicitly triggers the optimization of the inner graphs of Scan? # explicitly triggers the optimization of the inner graphs of Scan?
# The C-code defers it to the make_thunk phase # The C-code defers it to the make_thunk phase
rewriter = ( rewriter = (
op.mode_instance.including("numba") get_mode(op.mode)
.including("numba")
.excluding(*NUMBA._optimizer.exclude) .excluding(*NUMBA._optimizer.exclude)
.optimizer .optimizer
) )
rewriter(op.fgraph) fgraph = op.fgraph
add_supervisor_to_fgraph(
fgraph=fgraph,
input_specs=[In(x, borrow=True, mutable=False) for x in fgraph.inputs],
accept_inplace=True,
)
rewriter(fgraph)
scan_inner_func = numba_basic.numba_njit(numba_funcify(op.fgraph)) scan_inner_func = numba_basic.numba_njit(numba_funcify(op.fgraph))
......
...@@ -5,8 +5,10 @@ import numpy as np ...@@ -5,8 +5,10 @@ import numpy as np
import torch import torch
import torch.compiler import torch.compiler
from pytensor import In
from pytensor.compile import PYTORCH from pytensor.compile import PYTORCH
from pytensor.compile.builders import OpFromGraph from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function.types import add_supervisor_to_fgraph
from pytensor.compile.ops import DeepCopyOp from pytensor.compile.ops import DeepCopyOp
from pytensor.graph.basic import Constant from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
...@@ -185,6 +187,13 @@ def pytorch_funcify_OpFromGraph(op, node, **kwargs): ...@@ -185,6 +187,13 @@ def pytorch_funcify_OpFromGraph(op, node, **kwargs):
kwargs.pop("storage_map", None) kwargs.pop("storage_map", None)
# Apply inner rewrites # Apply inner rewrites
PYTORCH.optimizer(op.fgraph) PYTORCH.optimizer(op.fgraph)
fgraph = op.fgraph
add_supervisor_to_fgraph(
fgraph=fgraph,
input_specs=[In(x, borrow=True, mutable=False) for x in fgraph.inputs],
accept_inplace=True,
)
PYTORCH.optimizer(fgraph)
fgraph_fn = pytorch_funcify(op.fgraph, **kwargs, squeeze_output=True) fgraph_fn = pytorch_funcify(op.fgraph, **kwargs, squeeze_output=True)
return fgraph_fn return fgraph_fn
......
...@@ -57,6 +57,7 @@ import pytensor.link.utils as link_utils ...@@ -57,6 +57,7 @@ import pytensor.link.utils as link_utils
from pytensor import tensor as pt from pytensor import tensor as pt
from pytensor.compile.builders import construct_nominal_fgraph, infer_shape from pytensor.compile.builders import construct_nominal_fgraph, infer_shape
from pytensor.compile.function.pfunc import pfunc from pytensor.compile.function.pfunc import pfunc
from pytensor.compile.function.types import add_supervisor_to_fgraph
from pytensor.compile.io import In, Out from pytensor.compile.io import In, Out
from pytensor.compile.mode import Mode, get_mode from pytensor.compile.mode import Mode, get_mode
from pytensor.compile.profiling import register_profiler_printer from pytensor.compile.profiling import register_profiler_printer
...@@ -834,8 +835,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -834,8 +835,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
self.n_outer_inputs = info.n_outer_inputs self.n_outer_inputs = info.n_outer_inputs
self.n_outer_outputs = info.n_outer_outputs self.n_outer_outputs = info.n_outer_outputs
_ = self.prepare_fgraph(self.fgraph)
if any(node.op.destroy_map for node in self.fgraph.apply_nodes): if any(node.op.destroy_map for node in self.fgraph.apply_nodes):
raise InconsistencyError( raise InconsistencyError(
"Inner-graphs must not contain in-place operations." "Inner-graphs must not contain in-place operations."
...@@ -1394,23 +1393,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1394,23 +1393,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
fgraph.update_mapping = update_mapping fgraph.update_mapping = update_mapping
from pytensor.compile.function.types import Supervisor add_supervisor_to_fgraph(
from pytensor.graph.destroyhandler import DestroyHandler fgraph=fgraph, input_specs=wrapped_inputs, accept_inplace=True
for node in fgraph.apply_nodes:
if node.op.destroy_map:
fgraph.attach_feature(DestroyHandler())
break
fgraph.attach_feature(
Supervisor(
inp
for spec, inp in zip(wrapped_inputs, fgraph.inputs, strict=True)
if not (
getattr(spec, "mutable", None)
or (hasattr(fgraph, "destroyers") and fgraph.has_destroyers([inp]))
)
)
) )
return wrapped_inputs, wrapped_outputs return wrapped_inputs, wrapped_outputs
......
...@@ -835,6 +835,39 @@ def test_OpFromGraph(): ...@@ -835,6 +835,39 @@ def test_OpFromGraph():
compare_numba_and_py([x, y, z], [out], [xv, yv, zv]) compare_numba_and_py([x, y, z], [out], [xv, yv, zv])
@pytest.mark.filterwarnings("error")
def test_ofg_inner_inplace():
    # Inplace rewrites inside an OpFromGraph must never destroy the graph
    # inputs, while intermediate results remain fair game for destruction.
    inner_x = pt.vector("x")
    write_on_input = inner_x[0].set(1)  # SetSubtensor should not inplace on x
    write_on_temp = pt.exp(inner_x)[0].set(1)  # SetSubtensor should inplace on exp(x)

    protected_ofg = OpFromGraph([inner_x], [write_on_input])
    inplace_ofg = OpFromGraph([inner_x], [write_on_temp])

    y, z = pt.vectors("y", "z")
    fn = function([y, z], [protected_ofg(y), inplace_ofg(z)], mode="NUMBA")

    # First output: the inner SetSubtensor must not have been made destructive.
    compiled_ofg0 = fn.maker.fgraph.outputs[0].owner.op
    assert isinstance(compiled_ofg0, OpFromGraph)
    assert compiled_ofg0.fgraph.outputs[0].owner.op.destroy_map == {}

    # Second output: the inner SetSubtensor should destroy its temp buffer.
    compiled_ofg1 = fn.maker.fgraph.outputs[1].owner.op
    assert isinstance(compiled_ofg1, OpFromGraph)
    assert compiled_ofg1.fgraph.outputs[0].owner.op.destroy_map == {0: [0]}

    x_test = np.array([0, 1, 1], dtype=config.floatX)
    y_test = np.array([0, 1, 1], dtype=config.floatX)
    res0, res1 = fn(x_test, y_test)

    # Check inputs were not mutated
    np.testing.assert_allclose(x_test, [0, 1, 1])
    np.testing.assert_allclose(y_test, [0, 1, 1])

    # Check outputs are correct
    np.testing.assert_allclose(res0, [1, 1, 1])
    np.testing.assert_allclose(res1, [1, np.e, np.e])
@pytest.mark.filterwarnings("error") @pytest.mark.filterwarnings("error")
def test_cache_warning_suppressed(): def test_cache_warning_suppressed():
x = pt.vector("x", shape=(5,), dtype="float64") x = pt.vector("x", shape=(5,), dtype="float64")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论