Commit d9e8728a, authored by Ricardo Vieira, committed by Ricardo Vieira

Do not skip validation between consecutive Elemwise inplace replacements

Parent commit: 7d091be3
......@@ -7,7 +7,6 @@ and inplace operations.
import itertools
from collections import deque
import pytensor
from pytensor.configdefaults import config
from pytensor.graph.basic import Constant
from pytensor.graph.features import AlreadyThere, Bookkeeper
......@@ -223,7 +222,7 @@ def _build_droot_impact(destroy_handler):
return droot, impact, root_destroyer
def inplace_candidates(fgraph, inputs, protected_inputs=None):
    """Return the variables in `inputs` that are valid candidates for inplace operations.

    A variable can be destroyed by an inplace operation only if it is not a
    constant, is not protected, and is not already destroyed by another node
    in the graph.

    Parameters
    ----------
    fgraph : FunctionGraph
        Graph the candidates belong to; consulted for Supervisor features,
        outputs, and destroyer bookkeeping.
    inputs : sequence of Variable
        Variables that you want to use as destinations of an inplace operation.
        Duplicates are removed, preserving the original order.
    protected_inputs : set of Variable, optional
        Variables that must never be destroyed. If None, the protected set is
        collected from the graph's Supervisor features plus the graph outputs
        (outputs can never be destroyed).

    Returns
    -------
    list of Variable
        The subset of `inputs` (deduplicated, order preserved) that may be
        destroyed inplace.
    """
    if protected_inputs is None:
        # Imported here to avoid a circular import at module load time.
        from pytensor.compile.function.types import Supervisor

        protected_inputs = set(
            itertools.chain.from_iterable(
                f.protected for f in fgraph._features if isinstance(f, Supervisor)
            )
        )
        # Graph outputs can never be destroyed either.
        protected_inputs.update(fgraph.outputs)

    # Hoist the method lookup out of the comprehension.
    has_destroyers = fgraph.has_destroyers
    return [
        inp
        # Remove duplicates, while preserving order, by using dict.fromkeys
        for inp in dict.fromkeys(inputs)
        if (
            # Constants would need to be recreated every time if inplaced
            not isinstance(inp, Constant)
            and inp not in protected_inputs
            and not has_destroyers([inp])
        )
    ]
class DestroyHandler(Bookkeeper):
......
import itertools
from pytensor.compile import Supervisor
from pytensor.compile.mode import optdb
from pytensor.graph import Constant, node_rewriter
from pytensor.graph.destroyhandler import inplace_candidates
from pytensor.graph.replace import vectorize_node
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, out2in
from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
......@@ -274,25 +272,19 @@ def blockwise_inplace(fgraph, node):
batch_ndim = blockwise_op.batch_ndim(node)
out_batch_bcast = node.outputs[0].type.broadcastable[:batch_ndim]
protected_inputs = [
f.protected for f in fgraph._features if isinstance(f, Supervisor)
]
protected_inputs = list(itertools.chain.from_iterable(protected_inputs))
protected_inputs.extend(fgraph.outputs)
allowed_inplace_inputs = [
idx
for idx, inp in enumerate(node.inputs)
if
(
# Constants would need to be recreated every time if inplaced
not isinstance(inp, Constant)
# We can only inplace on inputs that are not being broadcasted
# As those are reused across iterations of Blockwise
and node.inputs[idx].type.broadcastable[:batch_ndim] == out_batch_bcast
# Inputs that are marked as protected or destroyed can't be inplaced
and not fgraph.has_destroyers([inp])
and inp not in protected_inputs
inputs = node.inputs
candidate_inputs = set(
inplace_candidates(
fgraph,
[
inp
for inp in inputs
if inp.type.broadcastable[:batch_ndim] == out_batch_bcast
],
)
)
allowed_inplace_inputs = [
i for i, inp in enumerate(inputs) if inp in candidate_inputs
]
if not allowed_inplace_inputs:
......
......@@ -8,6 +8,7 @@ from pytensor import In, shared
from pytensor import scalar as ps
from pytensor import tensor as pt
from pytensor.compile.function import function
from pytensor.compile.function.types import add_supervisor_to_fgraph
from pytensor.compile.mode import Mode, get_default_mode
from pytensor.configdefaults import config
from pytensor.gradient import grad
......@@ -1529,3 +1530,31 @@ def test_constant_fold_branches_add_mul(op):
new_out = rewrite_graph(out, include=("add_mul_fusion",))
assert len(new_out.owner.inputs) == 3
assert equal_computations([new_out], [op(py_op(a, b), c, x)])
def test_InplaceElemwiseOptimizer_bug():
    """Regression test for https://github.com/pymc-devs/pytensor/issues/1420.

    This graph fails if InplaceElemwiseOptimizer were to try to skip
    ``fgraph.validate`` in between two invalid inplace rewrites.
    """
    z = pt.matrix("z")
    z1 = ps.float64("z1")
    z2 = ps.float64("z2")
    out1, out2 = Elemwise(ps.Composite([z1, z2], [z1 + z2, z2 - z1]))(z[1:], z[:-1])
    out = pt.exp(z[1:-1]).sum() + out1.sum() + out2.sum()

    # Add 500 unrelated nodes to trigger the old special behavior
    irrelevant_outs = [pt.specify_shape(z, (4, 4)) for _ in range(500)]

    fgraph = FunctionGraph(inputs=[z], outputs=[out, *irrelevant_outs], clone=False)
    add_supervisor_to_fgraph(fgraph, [In(z)])
    rewrite_graph(fgraph, include=("inplace",))

    # The deprecated flag must still be honored (with a FutureWarning).
    # Restore the old value afterwards so the global config mutation does
    # not leak into other tests.
    old_value = pytensor.config.tensor__insert_inplace_optimizer_validate_nb
    try:
        pytensor.config.tensor__insert_inplace_optimizer_validate_nb = 1
        with pytest.warns(
            FutureWarning,
            match="tensor__insert_inplace_optimizer_validate_nb config is deprecated",
        ):
            rewrite_graph(fgraph, include=("inplace",))
    finally:
        pytensor.config.tensor__insert_inplace_optimizer_validate_nb = old_value
Markdown is supported
0%
You are about to add 0 people to this discussion. Proceed with caution.
Finish editing this comment first!
Register or sign in to comment