提交 d9e8728a authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Do not skip validation between consecutive Elemwise inplace replacements

上级 7d091be3
...@@ -7,7 +7,6 @@ and inplace operations. ...@@ -7,7 +7,6 @@ and inplace operations.
import itertools import itertools
from collections import deque from collections import deque
import pytensor
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.basic import Constant from pytensor.graph.basic import Constant
from pytensor.graph.features import AlreadyThere, Bookkeeper from pytensor.graph.features import AlreadyThere, Bookkeeper
...@@ -223,7 +222,7 @@ def _build_droot_impact(destroy_handler): ...@@ -223,7 +222,7 @@ def _build_droot_impact(destroy_handler):
return droot, impact, root_destroyer return droot, impact, root_destroyer
def fast_inplace_check(fgraph, inputs): def inplace_candidates(fgraph, inputs, protected_inputs=None):
""" """
Return the variables in inputs that are possible candidate for as inputs of Return the variables in inputs that are possible candidate for as inputs of
inplace operation. inplace operation.
...@@ -234,22 +233,28 @@ def fast_inplace_check(fgraph, inputs): ...@@ -234,22 +233,28 @@ def fast_inplace_check(fgraph, inputs):
Inputs Variable that you want to use as inplace destination. Inputs Variable that you want to use as inplace destination.
""" """
Supervisor = pytensor.compile.function.types.Supervisor if protected_inputs is None:
protected_inputs = list( from pytensor.compile.function.types import Supervisor
itertools.chain.from_iterable(
f.protected for f in fgraph._features if isinstance(f, Supervisor) protected_inputs = set(
itertools.chain.from_iterable(
f.protected for f in fgraph._features if isinstance(f, Supervisor)
)
)
protected_inputs.update(fgraph.outputs)
has_destroyers = fgraph.has_destroyers
return [
inp
# Remove duplicates, while preserving order by using dict.fromkeys
for inp in dict.fromkeys(inputs)
if (
not isinstance(inp, Constant)
and inp not in protected_inputs
and not has_destroyers([inp])
) )
)
protected_inputs.extend(fgraph.outputs)
inputs = [
i
for i in inputs
if not isinstance(i, Constant)
and not fgraph.has_destroyers([i])
and i not in protected_inputs
] ]
return inputs
class DestroyHandler(Bookkeeper): class DestroyHandler(Bookkeeper):
......
import itertools
from pytensor.compile import Supervisor
from pytensor.compile.mode import optdb from pytensor.compile.mode import optdb
from pytensor.graph import Constant, node_rewriter from pytensor.graph import Constant, node_rewriter
from pytensor.graph.destroyhandler import inplace_candidates
from pytensor.graph.replace import vectorize_node from pytensor.graph.replace import vectorize_node
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, out2in from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, out2in
from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
...@@ -274,25 +272,19 @@ def blockwise_inplace(fgraph, node): ...@@ -274,25 +272,19 @@ def blockwise_inplace(fgraph, node):
batch_ndim = blockwise_op.batch_ndim(node) batch_ndim = blockwise_op.batch_ndim(node)
out_batch_bcast = node.outputs[0].type.broadcastable[:batch_ndim] out_batch_bcast = node.outputs[0].type.broadcastable[:batch_ndim]
protected_inputs = [ inputs = node.inputs
f.protected for f in fgraph._features if isinstance(f, Supervisor) candidate_inputs = set(
] inplace_candidates(
protected_inputs = list(itertools.chain.from_iterable(protected_inputs)) fgraph,
protected_inputs.extend(fgraph.outputs) [
allowed_inplace_inputs = [ inp
idx for inp in inputs
for idx, inp in enumerate(node.inputs) if inp.type.broadcastable[:batch_ndim] == out_batch_bcast
if ],
(
# Constants would need to be recreated every time if inplaced
not isinstance(inp, Constant)
# We can only inplace on inputs that are not being broadcasted
# As those are reused across iterations of Blockwise
and node.inputs[idx].type.broadcastable[:batch_ndim] == out_batch_bcast
# Inputs that are marked as protected or destroyed can't be inplaced
and not fgraph.has_destroyers([inp])
and inp not in protected_inputs
) )
)
allowed_inplace_inputs = [
i for i, inp in enumerate(inputs) if inp in candidate_inputs
] ]
if not allowed_inplace_inputs: if not allowed_inplace_inputs:
......
import itertools import itertools
import operator import operator
import sys import sys
from collections import Counter, defaultdict, deque from collections import defaultdict, deque
from collections.abc import Generator from collections.abc import Generator
from functools import cache, reduce from functools import cache, reduce
from typing import TypeVar from typing import TypeVar
from warnings import warn from warnings import warn
import pytensor
import pytensor.scalar.basic as ps import pytensor.scalar.basic as ps
from pytensor import clone_replace, compile from pytensor import clone_replace, compile
from pytensor.compile.function.types import Supervisor
from pytensor.compile.mode import get_target_language from pytensor.compile.mode import get_target_language
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph import FunctionGraph from pytensor.graph import FunctionGraph
from pytensor.graph.basic import Apply, Constant, Variable, ancestors, io_toposort from pytensor.graph.basic import Apply, Variable, ancestors
from pytensor.graph.destroyhandler import DestroyHandler, inplace_candidates
from pytensor.graph.features import ReplaceValidate from pytensor.graph.features import ReplaceValidate
from pytensor.graph.fg import Output from pytensor.graph.fg import Output
from pytensor.graph.rewriting.basic import ( from pytensor.graph.rewriting.basic import (
...@@ -43,7 +44,7 @@ from pytensor.tensor.rewriting.basic import ( ...@@ -43,7 +44,7 @@ from pytensor.tensor.rewriting.basic import (
register_specialize, register_specialize,
) )
from pytensor.tensor.shape import shape_padleft from pytensor.tensor.shape import shape_padleft
from pytensor.tensor.variable import TensorConstant from pytensor.tensor.variable import TensorConstant, TensorVariable
class InplaceElemwiseOptimizer(GraphRewriter): class InplaceElemwiseOptimizer(GraphRewriter):
...@@ -51,31 +52,9 @@ class InplaceElemwiseOptimizer(GraphRewriter): ...@@ -51,31 +52,9 @@ class InplaceElemwiseOptimizer(GraphRewriter):
This is parameterized so that it works for `Elemwise` `Op`\s. This is parameterized so that it works for `Elemwise` `Op`\s.
""" """
def __init__(self, OP):
self.op = OP
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
from pytensor.graph.destroyhandler import DestroyHandler
fgraph.attach_feature(DestroyHandler()) fgraph.attach_feature(DestroyHandler())
@classmethod
def print_profile(cls, stream, prof, level=0):
blanc = " " * level
print(blanc, cls.__name__, prof["opt"].op, file=stream)
for k in [
"node_before",
"nb_call_replace",
"nb_call_validate",
"nb_inconsistent",
]:
print(blanc, k, prof[k], file=stream)
ndim = prof["ndim"]
if ndim:
print(blanc, "ndim", "nb", file=stream)
for n in sorted(ndim):
print(blanc, n, ndim[n], file=stream)
def apply(self, fgraph): def apply(self, fgraph):
r""" r"""
...@@ -92,8 +71,7 @@ class InplaceElemwiseOptimizer(GraphRewriter): ...@@ -92,8 +71,7 @@ class InplaceElemwiseOptimizer(GraphRewriter):
(x + y) * (x * y) -> (x += y) *= (x * y) or (x + y) *= (x *= y) (x + y) * (x * y) -> (x += y) *= (x * y) or (x + y) *= (x *= y)
""" """
# We should not validate too often as this takes too much time to # We should not validate too often as this takes too much time to execute!
# execute!
# It is the _dfs_toposort() fct in pytensor/graph/destroyhandler.py # It is the _dfs_toposort() fct in pytensor/graph/destroyhandler.py
# that takes so much time. # that takes so much time.
# Should we try to use another lib that does toposort? # Should we try to use another lib that does toposort?
...@@ -111,244 +89,199 @@ class InplaceElemwiseOptimizer(GraphRewriter): ...@@ -111,244 +89,199 @@ class InplaceElemwiseOptimizer(GraphRewriter):
# Then I think it is the [io_?]toposort (need to validate) so check if # Then I think it is the [io_?]toposort (need to validate) so check if
# the solution is also applicable there. # the solution is also applicable there.
# We execute `validate` after this number of change. # 2025: The above comment is not specific to Elemwise, if we have concerns about this approach, we should
# tackle them in a more general way. The whole try/except approach is probably suboptimal.
# We can consider restricting inputs with static shapes that are large enough.
def create_inplace_node(node, inplace_pattern):
op = node.op
scalar_op = op.scalar_op
inplace_pattern = {i: o for i, [o] in inplace_pattern.items()}
if hasattr(scalar_op, "make_new_inplace"):
new_scalar_op = scalar_op.make_new_inplace(
ps.transfer_type(
*[
inplace_pattern.get(i, o.dtype)
for i, o in enumerate(node.outputs)
]
)
)
else:
new_scalar_op = type(scalar_op)(
ps.transfer_type(
*[
inplace_pattern.get(i, None)
for i in range(len(node.outputs))
]
)
)
return type(op)(new_scalar_op, inplace_pattern).make_node(*node.inputs)
if config.tensor__insert_inplace_optimizer_validate_nb != -1:
warn(
"tensor__insert_inplace_optimizer_validate_nb config is deprecated. Setting it will fail in a future release.",
FutureWarning,
)
prof = { prof = {
"opt": self, "opt": self,
"node_before": len(fgraph.apply_nodes), "node_before": len(fgraph.apply_nodes),
"nb_call_replace": 0, "nb_eager_inconsistent": 0,
"nb_call_validate": 0,
"nb_inconsistent": 0, "nb_inconsistent": 0,
"ndim": Counter(), "nb_replaced": 0,
} }
large_graph = len(fgraph.apply_nodes) > 500
check_each_change = config.tensor__insert_inplace_optimizer_validate_nb protected_inputs = set(
if check_each_change == -1:
if len(fgraph.apply_nodes) > 500:
check_each_change = 10
else:
check_each_change = 1
nb_change_no_validate = 0
chk = fgraph.checkpoint()
if fgraph.update_mapping:
update_outs = [fgraph.outputs[i] for i in fgraph.update_mapping]
else:
update_outs = []
Supervisor = pytensor.compile.function.types.Supervisor
protected_inputs = list(
itertools.chain.from_iterable( itertools.chain.from_iterable(
f.protected for f in fgraph._features if isinstance(f, Supervisor) f.protected for f in fgraph._features if isinstance(f, Supervisor)
) )
) )
protected_inputs.extend(fgraph.outputs) protected_inputs.update(fgraph.outputs)
for node in list(io_toposort(fgraph.inputs, fgraph.outputs)): root_destroyer = fgraph.destroy_handler.root_destroyer
op = node.op
if not isinstance(op, self.op): update_mapping = fgraph.update_mapping or {}
continue op_updates: dict[TensorVariable, TensorVariable] = {
# If big graph and the outputs are scalar, do not make it out: fgraph.inputs[update_mapping[out_idx]]
# inplace. for out_idx, out in enumerate(fgraph.outputs)
if ( if (
check_each_change != 1 out_idx in update_mapping
and and out.owner
# If multiple outputs, they must all have the same size, and isinstance(out.owner.op, Elemwise)
# so only check the first. )
getattr(node.outputs[0].type, "ndim", -1) == 0 }
): set_op_updates = set(op_updates.keys())
for node in fgraph.toposort():
if not isinstance(node.op, Elemwise) or node.op.destroy_map:
continue continue
if op.inplace_pattern: # If big graph and the outputs are scalar, do not make it inplace.
# Maybe this isn't needed anymore, but I don't want to if large_graph and all(node.outputs[0].type.broadcastable):
# rish regression now. This case only happen if the continue
# original node add already some inplace patter and we
# still try to add more pattern.
baseline = op.inplace_pattern candidate_inputs = [
candidate_outputs = [ (node.inputs.index(inp), inp)
i for i in range(len(node.outputs)) if i not in baseline for inp in inplace_candidates(
] fgraph,
# node inputs that are Constant, already destroyed, node.inputs,
# or fgraph protected inputs and fgraph outputs can't be used as protected_inputs=protected_inputs,
# inplace target. )
# Remove here as faster. ]
candidate_inputs = [ if not candidate_inputs:
i return []
for i in range(len(node.inputs))
if i not in baseline.values() candidate_pairs = [
and not isinstance(node.inputs[i], Constant) ((o, out), (i, inp))
# the next line should not be costly most of the time. for o, out in enumerate(node.outputs)
and not fgraph.has_destroyers([node.inputs[i]]) for i, inp in candidate_inputs
and node.inputs[i] not in protected_inputs if inp.type == out.type
] ]
else:
baseline = [] if not candidate_pairs:
candidate_outputs = range(len(node.outputs)) continue
# node inputs that are Constant, already destroyed,
# fgraph protected inputs and fgraph outputs can't be used as inplace
# target.
# Remove here as faster.
candidate_inputs = [
i
for i in range(len(node.inputs))
if not isinstance(node.inputs[i], Constant)
and not fgraph.has_destroyers([node.inputs[i]])
and node.inputs[i] not in protected_inputs
]
verbose = False sorted_candidate_pairs = candidate_pairs
if op_updates and (node_updates := set(node.outputs) & set_op_updates):
raised_warning = not verbose # If the fgraph has updates, we try to prioritize in-placing on the pairs that correspond to the update
direct_update_pairs = []
for candidate_output in candidate_outputs: indirect_update_pairs = []
# If the output of the node can be established as an update other_update_pairs = []
# output of the fgraph, visit the candidate_inputs in an order for pair in candidate_pairs:
# that will improve the chances of making the node operate ((o, out), (i, inp)) = pair
# inplace on the input it's meant to update if out in node_updates:
candidate_out_var = node.outputs[candidate_output] direct_update_inp = op_updates[out]
sorted_candidate_inputs = candidate_inputs if direct_update_inp is inp:
# This pair is the whole graph update
if candidate_out_var in update_outs: direct_update_pairs.append(pair)
# The candidate output is an update. Sort the continue
# variables in candidate_inputs in the following order: elif (inp_node := inp.owner) is not None and any(
# - Vars corresponding to the actual updated input root_destroyer.get(up_inp, None) is inp_node
# (best case scenario is for the node that procudes for up_inp in op_updates.values()
# an update to operate inplace on the variable to
# update)
# - Vars computed inplace on the updates input (second
# best scenario if for the node to work inplace on
# a variable obtained by a chain of inplace on the
# variable to update. In some cases, this will be
# equivalent to operating inplace on the variable to
# update)
# - Remaining variables
updated_inputs = []
for i, f_out in enumerate(fgraph.outputs):
if f_out is candidate_out_var and i in fgraph.update_mapping:
updated_inp_idx = fgraph.update_mapping[i]
updated_inputs.append(fgraph.inputs[updated_inp_idx])
updated_vars = []
vars_from_inplace = []
other_vars = []
for inp_idx in candidate_inputs:
inp = node.inputs[inp_idx]
if inp in updated_inputs:
# the candidate input is the actual updated input
updated_vars.append(inp_idx)
elif (
hasattr(fgraph, "destroy_handler")
and inp.owner
and any(
fgraph.destroy_handler.root_destroyer.get(up_inp, None)
is inp.owner
for up_inp in updated_inputs
)
): ):
# the candidate input is a variable computed # This pair connects to an updated input
# inplace on the updated input via a sequence of indirect_update_pairs.append(pair)
# one or more inplace operations continue
vars_from_inplace.append(inp_idx) other_update_pairs.append(pair)
else:
other_vars.append(inp_idx)
sorted_candidate_inputs = ( sorted_candidate_pairs = (
updated_vars + vars_from_inplace + other_vars direct_update_pairs + indirect_update_pairs + other_update_pairs
) )
for candidate_input in sorted_candidate_inputs: # Try in-placing all outputs at once
# remove inputs that don't have the same dtype as the output tried_inputs = set()
if ( inplace_pattern = {}
node.inputs[candidate_input].type for (o, _), (i, _) in sorted_candidate_pairs:
!= node.outputs[candidate_output].type if o not in inplace_pattern and i not in tried_inputs:
): inplace_pattern[o] = [i]
continue tried_inputs.add(i)
inplace_node = create_inplace_node(node, inplace_pattern)
if inplace_node.op.destroy_map == inplace_pattern:
replacements = tuple(zip(node.outputs, inplace_node.outputs))
try:
fgraph.replace_all_validate(
replacements, reason="inplace_elemwise_optimizer"
)
except InconsistencyError:
prof["nb_eager_inconsistent"] += 1
else:
prof["nb_replaced"] += 1
continue
inplace_pattern = dict(baseline) # If it fails or doesn't match the desired inplace pattern, try one output/input at a time
inplace_pattern[candidate_output] = candidate_input tried_inputs = set()
try: inplace_pattern = {}
if hasattr(op.scalar_op, "make_new_inplace"): replaced = False
new_scal = op.scalar_op.make_new_inplace( for (o, _), (i, _) in sorted_candidate_pairs:
ps.transfer_type( if o not in inplace_pattern and i not in tried_inputs:
*[ inplace_pattern[o] = [i]
inplace_pattern.get(i, o.dtype) tried_inputs.add(i)
for i, o in enumerate(node.outputs)
] inplace_node = create_inplace_node(node, inplace_pattern)
) if inplace_node.op.destroy_map != inplace_pattern:
) # This Op can't respect this partial inplace pattern,
else: # We assume it can't support any other cases
new_scal = op.scalar_op.__class__( break
ps.transfer_type( else:
*[ replacements = tuple(zip(node.outputs, inplace_node.outputs))
inplace_pattern.get(i, None) try:
for i in range(len(node.outputs)) fgraph.replace_all_validate(
] replacements, reason="inplace_elemwise_optimizer"
)
) )
new_outputs = self.op(new_scal, inplace_pattern)( node = inplace_node
*node.inputs, return_list=True replaced = True
) except InconsistencyError:
new_node = new_outputs[0].owner prof["nb_inconsistent"] += 1
# The input, not the output caused inconsistencies
inplace_pattern.pop(o)
prof["nb_replaced"] += replaced
for r, new_r in zip(node.outputs, new_outputs, strict=True):
prof["nb_call_replace"] += 1
fgraph.replace(
r, new_r, reason="inplace_elemwise_optimizer"
)
nb_change_no_validate += 1
prof["ndim"][candidate_out_var.ndim] += 1
if nb_change_no_validate >= check_each_change:
prof["nb_call_validate"] += 1
fgraph.validate()
chk = fgraph.checkpoint()
nb_change_no_validate = 0
except (ValueError, InconsistencyError) as e:
prof["nb_inconsistent"] += 1
if check_each_change != 1 and not raised_warning:
print( # noqa: T201
(
"Some inplace rewriting was not "
"performed due to an unexpected error:"
),
file=sys.stderr,
)
print(e, file=sys.stderr) # noqa: T201
raised_warning = True
fgraph.revert(chk)
continue
candidate_inputs.remove(candidate_input)
node = new_node
baseline = inplace_pattern
break
if nb_change_no_validate > 0:
try:
fgraph.validate()
except Exception:
if not raised_warning:
print( # noqa: T201
(
"Some inplace rewriting was not "
"performed due to an unexpected error"
),
file=sys.stderr,
)
fgraph.revert(chk)
return prof return prof
@classmethod
def print_profile(cls, stream, prof, level=0):
blanc = " " * level
print(blanc, cls.__name__, file=stream)
for k in [
"node_before",
"nb_eager_inconsistent",
"nb_inconsistent",
"nb_replaced",
]:
print(blanc, k, prof[k], file=stream)
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
print( print(
f"{' ' * level}{self.__class__.__name__} ({self.op})", f"{' ' * level}{self.__class__.__name__}",
file=stream, file=stream,
) )
return inplace_elemwise_optimizer
inplace_elemwise_optimizer = InplaceElemwiseOptimizer(Elemwise)
compile.optdb.register( compile.optdb.register(
"inplace_elemwise_opt", "inplace_elemwise",
inplace_elemwise_optimizer, InplaceElemwiseOptimizer(),
"inplace_opt", # for historic reason "inplace_elemwise_opt", # for historic reason
"inplace_elemwise_optimizer", "inplace_elemwise_optimizer",
"fast_run", "fast_run",
"inplace", "inplace",
......
...@@ -8,6 +8,7 @@ from pytensor import In, shared ...@@ -8,6 +8,7 @@ from pytensor import In, shared
from pytensor import scalar as ps from pytensor import scalar as ps
from pytensor import tensor as pt from pytensor import tensor as pt
from pytensor.compile.function import function from pytensor.compile.function import function
from pytensor.compile.function.types import add_supervisor_to_fgraph
from pytensor.compile.mode import Mode, get_default_mode from pytensor.compile.mode import Mode, get_default_mode
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.gradient import grad from pytensor.gradient import grad
...@@ -1529,3 +1530,31 @@ def test_constant_fold_branches_add_mul(op): ...@@ -1529,3 +1530,31 @@ def test_constant_fold_branches_add_mul(op):
new_out = rewrite_graph(out, include=("add_mul_fusion",)) new_out = rewrite_graph(out, include=("add_mul_fusion",))
assert len(new_out.owner.inputs) == 3 assert len(new_out.owner.inputs) == 3
assert equal_computations([new_out], [op(py_op(a, b), c, x)]) assert equal_computations([new_out], [op(py_op(a, b), c, x)])
def test_InplaceElemwiseOptimizer_bug():
# Regression test for https://github.com/pymc-devs/pytensor/issues/1420
# This graph fails if InplaceElemwiseOptimizer were to try to skip `fgraph.validate`
# in between two invalid inplace rewrites.
z = pt.matrix("z")
z1 = ps.float64("z1")
z2 = ps.float64("z2")
out1, out2 = Elemwise(ps.Composite([z1, z2], [z1 + z2, z2 - z1]))(z[1:], z[:-1])
out = pt.exp(z[1:-1]).sum() + out1.sum() + out2.sum()
# Add 500 unrelated nodes to trigger the old special behavior
irrelevant_outs = [pt.specify_shape(z, (4, 4)) for _ in range(500)]
fgraph = FunctionGraph(inputs=[z], outputs=[out, *irrelevant_outs], clone=False)
add_supervisor_to_fgraph(fgraph, [In(z)])
# with config.change_flags(tensor__insert_inplace_optimizer_validate_nb=10):
rewrite_graph(fgraph, include=("inplace",))
pytensor.config.tensor__insert_inplace_optimizer_validate_nb = 1
with pytest.warns(
FutureWarning,
match="tensor__insert_inplace_optimizer_validate_nb config is deprecated",
):
rewrite_graph(fgraph, include=("inplace",))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论