提交 d9e8728a authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Do not skip validation between consecutive Elemwise inplace replacements

上级 7d091be3
......@@ -7,7 +7,6 @@ and inplace operations.
import itertools
from collections import deque
import pytensor
from pytensor.configdefaults import config
from pytensor.graph.basic import Constant
from pytensor.graph.features import AlreadyThere, Bookkeeper
......@@ -223,7 +222,7 @@ def _build_droot_impact(destroy_handler):
return droot, impact, root_destroyer
def fast_inplace_check(fgraph, inputs):
def inplace_candidates(fgraph, inputs, protected_inputs=None):
"""
Return the variables in inputs that are possible candidate for as inputs of
inplace operation.
......@@ -234,22 +233,28 @@ def fast_inplace_check(fgraph, inputs):
Inputs Variable that you want to use as inplace destination.
"""
Supervisor = pytensor.compile.function.types.Supervisor
protected_inputs = list(
itertools.chain.from_iterable(
f.protected for f in fgraph._features if isinstance(f, Supervisor)
if protected_inputs is None:
from pytensor.compile.function.types import Supervisor
protected_inputs = set(
itertools.chain.from_iterable(
f.protected for f in fgraph._features if isinstance(f, Supervisor)
)
)
protected_inputs.update(fgraph.outputs)
has_destroyers = fgraph.has_destroyers
return [
inp
# Remove duplicates, while preserving order by using dict.fromkeys
for inp in dict.fromkeys(inputs)
if (
not isinstance(inp, Constant)
and inp not in protected_inputs
and not has_destroyers([inp])
)
)
protected_inputs.extend(fgraph.outputs)
inputs = [
i
for i in inputs
if not isinstance(i, Constant)
and not fgraph.has_destroyers([i])
and i not in protected_inputs
]
return inputs
class DestroyHandler(Bookkeeper):
......
import itertools
from pytensor.compile import Supervisor
from pytensor.compile.mode import optdb
from pytensor.graph import Constant, node_rewriter
from pytensor.graph.destroyhandler import inplace_candidates
from pytensor.graph.replace import vectorize_node
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, out2in
from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
......@@ -274,25 +272,19 @@ def blockwise_inplace(fgraph, node):
batch_ndim = blockwise_op.batch_ndim(node)
out_batch_bcast = node.outputs[0].type.broadcastable[:batch_ndim]
protected_inputs = [
f.protected for f in fgraph._features if isinstance(f, Supervisor)
]
protected_inputs = list(itertools.chain.from_iterable(protected_inputs))
protected_inputs.extend(fgraph.outputs)
allowed_inplace_inputs = [
idx
for idx, inp in enumerate(node.inputs)
if
(
# Constants would need to be recreated every time if inplaced
not isinstance(inp, Constant)
# We can only inplace on inputs that are not being broadcasted
# As those are reused across iterations of Blockwise
and node.inputs[idx].type.broadcastable[:batch_ndim] == out_batch_bcast
# Inputs that are marked as protected or destroyed can't be inplaced
and not fgraph.has_destroyers([inp])
and inp not in protected_inputs
inputs = node.inputs
candidate_inputs = set(
inplace_candidates(
fgraph,
[
inp
for inp in inputs
if inp.type.broadcastable[:batch_ndim] == out_batch_bcast
],
)
)
allowed_inplace_inputs = [
i for i, inp in enumerate(inputs) if inp in candidate_inputs
]
if not allowed_inplace_inputs:
......
import itertools
import operator
import sys
from collections import Counter, defaultdict, deque
from collections import defaultdict, deque
from collections.abc import Generator
from functools import cache, reduce
from typing import TypeVar
from warnings import warn
import pytensor
import pytensor.scalar.basic as ps
from pytensor import clone_replace, compile
from pytensor.compile.function.types import Supervisor
from pytensor.compile.mode import get_target_language
from pytensor.configdefaults import config
from pytensor.graph import FunctionGraph
from pytensor.graph.basic import Apply, Constant, Variable, ancestors, io_toposort
from pytensor.graph.basic import Apply, Variable, ancestors
from pytensor.graph.destroyhandler import DestroyHandler, inplace_candidates
from pytensor.graph.features import ReplaceValidate
from pytensor.graph.fg import Output
from pytensor.graph.rewriting.basic import (
......@@ -43,7 +44,7 @@ from pytensor.tensor.rewriting.basic import (
register_specialize,
)
from pytensor.tensor.shape import shape_padleft
from pytensor.tensor.variable import TensorConstant
from pytensor.tensor.variable import TensorConstant, TensorVariable
class InplaceElemwiseOptimizer(GraphRewriter):
......@@ -51,31 +52,9 @@ class InplaceElemwiseOptimizer(GraphRewriter):
This is parameterized so that it works for `Elemwise` `Op`\s.
"""
def __init__(self, OP):
self.op = OP
def add_requirements(self, fgraph):
from pytensor.graph.destroyhandler import DestroyHandler
fgraph.attach_feature(DestroyHandler())
@classmethod
def print_profile(cls, stream, prof, level=0):
blanc = " " * level
print(blanc, cls.__name__, prof["opt"].op, file=stream)
for k in [
"node_before",
"nb_call_replace",
"nb_call_validate",
"nb_inconsistent",
]:
print(blanc, k, prof[k], file=stream)
ndim = prof["ndim"]
if ndim:
print(blanc, "ndim", "nb", file=stream)
for n in sorted(ndim):
print(blanc, n, ndim[n], file=stream)
def apply(self, fgraph):
r"""
......@@ -92,8 +71,7 @@ class InplaceElemwiseOptimizer(GraphRewriter):
(x + y) * (x * y) -> (x += y) *= (x * y) or (x + y) *= (x *= y)
"""
# We should not validate too often as this takes too much time to
# execute!
# We should not validate too often as this takes too much time to execute!
# It is the _dfs_toposort() fct in pytensor/graph/destroyhandler.py
# that takes so much time.
# Should we try to use another lib that does toposort?
......@@ -111,244 +89,199 @@ class InplaceElemwiseOptimizer(GraphRewriter):
# Then I think it is the [io_?]toposort (need to validate) so check if
# the solution is also applicable there.
# We execute `validate` after this number of change.
# 2025: The above comment is not specific to Elemwise, if we have concerns about this approach, we should
# tackle them in a more general way. The whole try/except approach is probably suboptimal.
# We can consider restricting inputs with static shapes that are large enough.
def create_inplace_node(node, inplace_pattern):
op = node.op
scalar_op = op.scalar_op
inplace_pattern = {i: o for i, [o] in inplace_pattern.items()}
if hasattr(scalar_op, "make_new_inplace"):
new_scalar_op = scalar_op.make_new_inplace(
ps.transfer_type(
*[
inplace_pattern.get(i, o.dtype)
for i, o in enumerate(node.outputs)
]
)
)
else:
new_scalar_op = type(scalar_op)(
ps.transfer_type(
*[
inplace_pattern.get(i, None)
for i in range(len(node.outputs))
]
)
)
return type(op)(new_scalar_op, inplace_pattern).make_node(*node.inputs)
if config.tensor__insert_inplace_optimizer_validate_nb != -1:
warn(
"tensor__insert_inplace_optimizer_validate_nb config is deprecated. Setting it will fail in a future release.",
FutureWarning,
)
prof = {
"opt": self,
"node_before": len(fgraph.apply_nodes),
"nb_call_replace": 0,
"nb_call_validate": 0,
"nb_eager_inconsistent": 0,
"nb_inconsistent": 0,
"ndim": Counter(),
"nb_replaced": 0,
}
large_graph = len(fgraph.apply_nodes) > 500
check_each_change = config.tensor__insert_inplace_optimizer_validate_nb
if check_each_change == -1:
if len(fgraph.apply_nodes) > 500:
check_each_change = 10
else:
check_each_change = 1
nb_change_no_validate = 0
chk = fgraph.checkpoint()
if fgraph.update_mapping:
update_outs = [fgraph.outputs[i] for i in fgraph.update_mapping]
else:
update_outs = []
Supervisor = pytensor.compile.function.types.Supervisor
protected_inputs = list(
protected_inputs = set(
itertools.chain.from_iterable(
f.protected for f in fgraph._features if isinstance(f, Supervisor)
)
)
protected_inputs.extend(fgraph.outputs)
for node in list(io_toposort(fgraph.inputs, fgraph.outputs)):
op = node.op
if not isinstance(op, self.op):
continue
# If big graph and the outputs are scalar, do not make it
# inplace.
protected_inputs.update(fgraph.outputs)
root_destroyer = fgraph.destroy_handler.root_destroyer
update_mapping = fgraph.update_mapping or {}
op_updates: dict[TensorVariable, TensorVariable] = {
out: fgraph.inputs[update_mapping[out_idx]]
for out_idx, out in enumerate(fgraph.outputs)
if (
check_each_change != 1
and
# If multiple outputs, they must all have the same size,
# so only check the first.
getattr(node.outputs[0].type, "ndim", -1) == 0
):
out_idx in update_mapping
and out.owner
and isinstance(out.owner.op, Elemwise)
)
}
set_op_updates = set(op_updates.keys())
for node in fgraph.toposort():
if not isinstance(node.op, Elemwise) or node.op.destroy_map:
continue
if op.inplace_pattern:
# Maybe this isn't needed anymore, but I don't want to
# rish regression now. This case only happen if the
# original node add already some inplace patter and we
# still try to add more pattern.
# If big graph and the outputs are scalar, do not make it inplace.
if large_graph and all(node.outputs[0].type.broadcastable):
continue
baseline = op.inplace_pattern
candidate_outputs = [
i for i in range(len(node.outputs)) if i not in baseline
]
# node inputs that are Constant, already destroyed,
# or fgraph protected inputs and fgraph outputs can't be used as
# inplace target.
# Remove here as faster.
candidate_inputs = [
i
for i in range(len(node.inputs))
if i not in baseline.values()
and not isinstance(node.inputs[i], Constant)
# the next line should not be costly most of the time.
and not fgraph.has_destroyers([node.inputs[i]])
and node.inputs[i] not in protected_inputs
]
else:
baseline = []
candidate_outputs = range(len(node.outputs))
# node inputs that are Constant, already destroyed,
# fgraph protected inputs and fgraph outputs can't be used as inplace
# target.
# Remove here as faster.
candidate_inputs = [
i
for i in range(len(node.inputs))
if not isinstance(node.inputs[i], Constant)
and not fgraph.has_destroyers([node.inputs[i]])
and node.inputs[i] not in protected_inputs
]
candidate_inputs = [
(node.inputs.index(inp), inp)
for inp in inplace_candidates(
fgraph,
node.inputs,
protected_inputs=protected_inputs,
)
]
if not candidate_inputs:
return []
candidate_pairs = [
((o, out), (i, inp))
for o, out in enumerate(node.outputs)
for i, inp in candidate_inputs
if inp.type == out.type
]
if not candidate_pairs:
continue
verbose = False
raised_warning = not verbose
for candidate_output in candidate_outputs:
# If the output of the node can be established as an update
# output of the fgraph, visit the candidate_inputs in an order
# that will improve the chances of making the node operate
# inplace on the input it's meant to update
candidate_out_var = node.outputs[candidate_output]
sorted_candidate_inputs = candidate_inputs
if candidate_out_var in update_outs:
# The candidate output is an update. Sort the
# variables in candidate_inputs in the following order:
# - Vars corresponding to the actual updated input
# (best case scenario is for the node that procudes
# an update to operate inplace on the variable to
# update)
# - Vars computed inplace on the updates input (second
# best scenario if for the node to work inplace on
# a variable obtained by a chain of inplace on the
# variable to update. In some cases, this will be
# equivalent to operating inplace on the variable to
# update)
# - Remaining variables
updated_inputs = []
for i, f_out in enumerate(fgraph.outputs):
if f_out is candidate_out_var and i in fgraph.update_mapping:
updated_inp_idx = fgraph.update_mapping[i]
updated_inputs.append(fgraph.inputs[updated_inp_idx])
updated_vars = []
vars_from_inplace = []
other_vars = []
for inp_idx in candidate_inputs:
inp = node.inputs[inp_idx]
if inp in updated_inputs:
# the candidate input is the actual updated input
updated_vars.append(inp_idx)
elif (
hasattr(fgraph, "destroy_handler")
and inp.owner
and any(
fgraph.destroy_handler.root_destroyer.get(up_inp, None)
is inp.owner
for up_inp in updated_inputs
)
sorted_candidate_pairs = candidate_pairs
if op_updates and (node_updates := set(node.outputs) & set_op_updates):
# If the fgraph has updates, we try to prioritize in-placing on the pairs that correspond to the update
direct_update_pairs = []
indirect_update_pairs = []
other_update_pairs = []
for pair in candidate_pairs:
((o, out), (i, inp)) = pair
if out in node_updates:
direct_update_inp = op_updates[out]
if direct_update_inp is inp:
# This pair is the whole graph update
direct_update_pairs.append(pair)
continue
elif (inp_node := inp.owner) is not None and any(
root_destroyer.get(up_inp, None) is inp_node
for up_inp in op_updates.values()
):
# the candidate input is a variable computed
# inplace on the updated input via a sequence of
# one or more inplace operations
vars_from_inplace.append(inp_idx)
else:
other_vars.append(inp_idx)
# This pair connects to an updated input
indirect_update_pairs.append(pair)
continue
other_update_pairs.append(pair)
sorted_candidate_inputs = (
updated_vars + vars_from_inplace + other_vars
)
sorted_candidate_pairs = (
direct_update_pairs + indirect_update_pairs + other_update_pairs
)
for candidate_input in sorted_candidate_inputs:
# remove inputs that don't have the same dtype as the output
if (
node.inputs[candidate_input].type
!= node.outputs[candidate_output].type
):
continue
# Try in-placing all outputs at once
tried_inputs = set()
inplace_pattern = {}
for (o, _), (i, _) in sorted_candidate_pairs:
if o not in inplace_pattern and i not in tried_inputs:
inplace_pattern[o] = [i]
tried_inputs.add(i)
inplace_node = create_inplace_node(node, inplace_pattern)
if inplace_node.op.destroy_map == inplace_pattern:
replacements = tuple(zip(node.outputs, inplace_node.outputs))
try:
fgraph.replace_all_validate(
replacements, reason="inplace_elemwise_optimizer"
)
except InconsistencyError:
prof["nb_eager_inconsistent"] += 1
else:
prof["nb_replaced"] += 1
continue
inplace_pattern = dict(baseline)
inplace_pattern[candidate_output] = candidate_input
try:
if hasattr(op.scalar_op, "make_new_inplace"):
new_scal = op.scalar_op.make_new_inplace(
ps.transfer_type(
*[
inplace_pattern.get(i, o.dtype)
for i, o in enumerate(node.outputs)
]
)
)
else:
new_scal = op.scalar_op.__class__(
ps.transfer_type(
*[
inplace_pattern.get(i, None)
for i in range(len(node.outputs))
]
)
# If it fails or doesn't match the desired inplace pattern, try one output/input at a time
tried_inputs = set()
inplace_pattern = {}
replaced = False
for (o, _), (i, _) in sorted_candidate_pairs:
if o not in inplace_pattern and i not in tried_inputs:
inplace_pattern[o] = [i]
tried_inputs.add(i)
inplace_node = create_inplace_node(node, inplace_pattern)
if inplace_node.op.destroy_map != inplace_pattern:
# This Op can't respect this partial inplace pattern,
# We assume it can't support any other cases
break
else:
replacements = tuple(zip(node.outputs, inplace_node.outputs))
try:
fgraph.replace_all_validate(
replacements, reason="inplace_elemwise_optimizer"
)
new_outputs = self.op(new_scal, inplace_pattern)(
*node.inputs, return_list=True
)
new_node = new_outputs[0].owner
node = inplace_node
replaced = True
except InconsistencyError:
prof["nb_inconsistent"] += 1
# The input, not the output caused inconsistencies
inplace_pattern.pop(o)
prof["nb_replaced"] += replaced
for r, new_r in zip(node.outputs, new_outputs, strict=True):
prof["nb_call_replace"] += 1
fgraph.replace(
r, new_r, reason="inplace_elemwise_optimizer"
)
nb_change_no_validate += 1
prof["ndim"][candidate_out_var.ndim] += 1
if nb_change_no_validate >= check_each_change:
prof["nb_call_validate"] += 1
fgraph.validate()
chk = fgraph.checkpoint()
nb_change_no_validate = 0
except (ValueError, InconsistencyError) as e:
prof["nb_inconsistent"] += 1
if check_each_change != 1 and not raised_warning:
print( # noqa: T201
(
"Some inplace rewriting was not "
"performed due to an unexpected error:"
),
file=sys.stderr,
)
print(e, file=sys.stderr) # noqa: T201
raised_warning = True
fgraph.revert(chk)
continue
candidate_inputs.remove(candidate_input)
node = new_node
baseline = inplace_pattern
break
if nb_change_no_validate > 0:
try:
fgraph.validate()
except Exception:
if not raised_warning:
print( # noqa: T201
(
"Some inplace rewriting was not "
"performed due to an unexpected error"
),
file=sys.stderr,
)
fgraph.revert(chk)
return prof
@classmethod
def print_profile(cls, stream, prof, level=0):
blanc = " " * level
print(blanc, cls.__name__, file=stream)
for k in [
"node_before",
"nb_eager_inconsistent",
"nb_inconsistent",
"nb_replaced",
]:
print(blanc, k, prof[k], file=stream)
def print_summary(self, stream=sys.stdout, level=0, depth=-1):
print(
f"{' ' * level}{self.__class__.__name__} ({self.op})",
f"{' ' * level}{self.__class__.__name__}",
file=stream,
)
return inplace_elemwise_optimizer
inplace_elemwise_optimizer = InplaceElemwiseOptimizer(Elemwise)
compile.optdb.register(
"inplace_elemwise_opt",
inplace_elemwise_optimizer,
"inplace_opt", # for historic reason
"inplace_elemwise",
InplaceElemwiseOptimizer(),
"inplace_elemwise_opt", # for historic reason
"inplace_elemwise_optimizer",
"fast_run",
"inplace",
......
......@@ -8,6 +8,7 @@ from pytensor import In, shared
from pytensor import scalar as ps
from pytensor import tensor as pt
from pytensor.compile.function import function
from pytensor.compile.function.types import add_supervisor_to_fgraph
from pytensor.compile.mode import Mode, get_default_mode
from pytensor.configdefaults import config
from pytensor.gradient import grad
......@@ -1529,3 +1530,31 @@ def test_constant_fold_branches_add_mul(op):
new_out = rewrite_graph(out, include=("add_mul_fusion",))
assert len(new_out.owner.inputs) == 3
assert equal_computations([new_out], [op(py_op(a, b), c, x)])
def test_InplaceElemwiseOptimizer_bug():
# Regression test for https://github.com/pymc-devs/pytensor/issues/1420
# This graph fails if InplaceElemwiseOptimizer were to try to skip `fgraph.validate`
# in between two invalid inplace rewrites.
z = pt.matrix("z")
z1 = ps.float64("z1")
z2 = ps.float64("z2")
out1, out2 = Elemwise(ps.Composite([z1, z2], [z1 + z2, z2 - z1]))(z[1:], z[:-1])
out = pt.exp(z[1:-1]).sum() + out1.sum() + out2.sum()
# Add 500 unrelated nodes to trigger the old special behavior
irrelevant_outs = [pt.specify_shape(z, (4, 4)) for _ in range(500)]
fgraph = FunctionGraph(inputs=[z], outputs=[out, *irrelevant_outs], clone=False)
add_supervisor_to_fgraph(fgraph, [In(z)])
# with config.change_flags(tensor__insert_inplace_optimizer_validate_nb=10):
rewrite_graph(fgraph, include=("inplace",))
pytensor.config.tensor__insert_inplace_optimizer_validate_nb = 1
with pytest.warns(
FutureWarning,
match="tensor__insert_inplace_optimizer_validate_nb config is deprecated",
):
rewrite_graph(fgraph, include=("inplace",))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论