Commit 1ebd078a authored by ricardoV94, committed by Ricardo Vieira

Reuse Elemwise inplace machinery for Blockwise

Parent 1d94ed68
......@@ -2,7 +2,7 @@ from pytensor.compile.mode import optdb
from pytensor.graph import Constant, node_rewriter
from pytensor.graph.destroyhandler import inplace_candidates
from pytensor.graph.replace import vectorize_node
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, out2in
from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.math import Dot
......@@ -11,6 +11,7 @@ from pytensor.tensor.rewriting.basic import (
register_specialize,
register_stabilize,
)
from pytensor.tensor.rewriting.elemwise import InplaceGraphOptimizer
from pytensor.tensor.shape import Reshape
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
......@@ -260,19 +261,15 @@ def local_blockwise_of_subtensor(fgraph, node):
return [x[(*none_slices, *core_idxs)]]
@node_rewriter(tracks=[Blockwise], inplace=True)
def blockwise_inplace(fgraph, node):
blockwise_op = node.op
if blockwise_op.destroy_map:
# Op already has inplace
return
class InplaceBlockwiseOptimizer(InplaceGraphOptimizer):
op = Blockwise
# Find out valid inputs for inplacing
def filter_candidate_pairs(self, fgraph, node, protected_inputs):
blockwise_op = node.op
batch_ndim = blockwise_op.batch_ndim(node)
out_batch_bcast = node.outputs[0].type.broadcastable[:batch_ndim]
inputs = node.inputs
candidate_inputs = set(
inplace_candidates(
fgraph,
......@@ -281,21 +278,36 @@ def blockwise_inplace(fgraph, node):
for inp in inputs
if inp.type.broadcastable[:batch_ndim] == out_batch_bcast
],
protected_inputs=protected_inputs,
)
)
allowed_inplace_inputs = [
i for i, inp in enumerate(inputs) if inp in candidate_inputs
]
destroy_map = blockwise_op.core_op.inplace_on_inputs(
allowed_inplace_inputs=allowed_inplace_inputs
).destroy_map
if not allowed_inplace_inputs:
return None
if not destroy_map:
return []
outputs = node.outputs
return [
((out_idx, outputs[out_idx]), (inp_idx, inputs[inp_idx]))
for out_idx, inp_idxs in destroy_map.items()
for inp_idx in inp_idxs
]
def create_inplace_node(self, node, inplace_pattern):
blockwise_op = node.op
allowed_inplace_inputs = tuple(v[0] for v in inplace_pattern.values())
inplace_core_op = blockwise_op.core_op.inplace_on_inputs(
allowed_inplace_inputs=allowed_inplace_inputs
)
if not inplace_core_op.destroy_map:
return None
return node
# Check Op is not trying to inplace on non-candidate inputs
for destroyed_inputs in inplace_core_op.destroy_map.values():
......@@ -306,7 +318,7 @@ def blockwise_inplace(fgraph, node):
)
# Recreate core_op with inplace
inplace_blockwise_op = Blockwise(
inplace_blockwise_op = type(blockwise_op)(
core_op=inplace_core_op,
signature=blockwise_op.signature,
name=blockwise_op.name,
......@@ -314,14 +326,12 @@ def blockwise_inplace(fgraph, node):
destroy_map=inplace_core_op.destroy_map,
)
out = inplace_blockwise_op.make_node(*node.inputs).outputs
copy_stack_trace(node.outputs, out)
return out
return inplace_blockwise_op.make_node(*node.inputs)
optdb.register(
"blockwise_inplace",
in2out(blockwise_inplace),
InplaceBlockwiseOptimizer(),
"fast_run",
"inplace",
position=50.1,
......
import abc
import itertools
import operator
import sys
from collections import defaultdict, deque
from collections.abc import Generator
from collections.abc import Generator, Sequence
from functools import cache, reduce
from typing import TypeVar
from warnings import warn
......@@ -12,7 +13,7 @@ from pytensor import clone_replace, compile
from pytensor.compile.function.types import Supervisor
from pytensor.compile.mode import get_target_language
from pytensor.configdefaults import config
from pytensor.graph import FunctionGraph
from pytensor.graph import FunctionGraph, Op
from pytensor.graph.basic import Apply, Variable, ancestors
from pytensor.graph.destroyhandler import DestroyHandler, inplace_candidates
from pytensor.graph.features import ReplaceValidate
......@@ -47,22 +48,31 @@ from pytensor.tensor.shape import shape_padleft
from pytensor.tensor.variable import TensorConstant, TensorVariable
class InplaceElemwiseOptimizer(GraphRewriter):
r"""
This is parameterized so that it works for `Elemwise` `Op`\s.
"""
class InplaceGraphOptimizer(GraphRewriter):
op: type[Op]
def add_requirements(self, fgraph):
fgraph.attach_feature(DestroyHandler())
    @abc.abstractmethod
    def filter_candidate_pairs(
        self, fgraph: FunctionGraph, node: Apply, protected_inputs: Sequence[Variable]
    ) -> Sequence[tuple[tuple[int, Variable], tuple[int, Variable]]]:
        """Return the (output, input) pairs of `node` eligible for inplacing.

        Each entry has the form ``((out_idx, out_var), (in_idx, in_var))``,
        meaning that output may overwrite the storage of that input.
        Inputs listed in `protected_inputs` must not appear in any pair.
        An empty sequence means the node cannot be made inplace.
        """
        pass
    @abc.abstractmethod
    def create_inplace_node(
        self, node: Apply, inplace_pattern: dict[int, Sequence[int]]
    ) -> Apply:
        """Build a node equivalent to `node` that destroys inputs per `inplace_pattern`.

        `inplace_pattern` maps an output index to the list of input indices
        that output should destroy. Implementations need not honor the full
        pattern: the caller compares the returned node's ``op.destroy_map``
        against the requested pattern before performing any replacement.
        """
        pass
def apply(self, fgraph):
r"""
Attempts to replace all `Elemwise`\s by versions of them that operate
inplace. It operates greedily: for each `Elemwise` that is encountered,
for each output, it tries each input to see if it can operate inplace
on that input. If so, it makes the change and goes to the next output
or `Elemwise`.
Attempts to replace all `Op`\s by versions of them that operate
inplace. It operates greedily: for each `Op` that is encountered,
it tries to inplace all the valid inputs at once (if the Op supports it),
if that fails, it tries to inplace one input at a time.
Examples
--------
......@@ -93,36 +103,13 @@ class InplaceElemwiseOptimizer(GraphRewriter):
# tackle them in a more general way. The whole try/except approach is probably suboptimal.
# We can consider restricting inputs with static shapes that are large enough.
def create_inplace_node(node, inplace_pattern):
op = node.op
scalar_op = op.scalar_op
inplace_pattern = {i: o for i, [o] in inplace_pattern.items()}
if hasattr(scalar_op, "make_new_inplace"):
new_scalar_op = scalar_op.make_new_inplace(
ps.transfer_type(
*[
inplace_pattern.get(i, o.dtype)
for i, o in enumerate(node.outputs)
]
)
)
else:
new_scalar_op = type(scalar_op)(
ps.transfer_type(
*[
inplace_pattern.get(i, None)
for i in range(len(node.outputs))
]
)
)
return type(op)(new_scalar_op, inplace_pattern).make_node(*node.inputs)
if config.tensor__insert_inplace_optimizer_validate_nb != -1:
warn(
"tensor__insert_inplace_optimizer_validate_nb config is deprecated. Setting it will fail in a future release.",
FutureWarning,
)
reason = f"{self.op}_inplace_optimizer"
prof = {
"opt": self,
"node_before": len(fgraph.apply_nodes),
......@@ -140,6 +127,7 @@ class InplaceElemwiseOptimizer(GraphRewriter):
protected_inputs.update(fgraph.outputs)
root_destroyer = fgraph.destroy_handler.root_destroyer
self_op = self.op
update_mapping = fgraph.update_mapping or {}
op_updates: dict[TensorVariable, TensorVariable] = {
out: fgraph.inputs[update_mapping[out_idx]]
......@@ -147,36 +135,22 @@ class InplaceElemwiseOptimizer(GraphRewriter):
if (
out_idx in update_mapping
and out.owner
and isinstance(out.owner.op, Elemwise)
and isinstance(out.owner.op, self_op)
)
}
set_op_updates = set(op_updates.keys())
for node in fgraph.toposort():
if not isinstance(node.op, Elemwise) or node.op.destroy_map:
if not isinstance(node.op, self_op) or node.op.destroy_map:
continue
# If big graph and the outputs are scalar, do not make it inplace.
if large_graph and all(node.outputs[0].type.broadcastable):
continue
candidate_inputs = [
(node.inputs.index(inp), inp)
for inp in inplace_candidates(
fgraph,
node.inputs,
protected_inputs=protected_inputs,
candidate_pairs = self.filter_candidate_pairs(
fgraph, node, protected_inputs
)
]
if not candidate_inputs:
return []
candidate_pairs = [
((o, out), (i, inp))
for o, out in enumerate(node.outputs)
for i, inp in candidate_inputs
if inp.type == out.type
]
if not candidate_pairs:
continue
......@@ -216,13 +190,11 @@ class InplaceElemwiseOptimizer(GraphRewriter):
inplace_pattern[o] = [i]
tried_inputs.add(i)
inplace_node = create_inplace_node(node, inplace_pattern)
inplace_node = self.create_inplace_node(node, inplace_pattern)
if inplace_node.op.destroy_map == inplace_pattern:
replacements = tuple(zip(node.outputs, inplace_node.outputs))
try:
fgraph.replace_all_validate(
replacements, reason="inplace_elemwise_optimizer"
)
fgraph.replace_all_validate(replacements, reason=reason)
except InconsistencyError:
prof["nb_eager_inconsistent"] += 1
else:
......@@ -238,7 +210,7 @@ class InplaceElemwiseOptimizer(GraphRewriter):
inplace_pattern[o] = [i]
tried_inputs.add(i)
inplace_node = create_inplace_node(node, inplace_pattern)
inplace_node = self.create_inplace_node(node, inplace_pattern)
if inplace_node.op.destroy_map != inplace_pattern:
# This Op can't respect this partial inplace pattern,
# We assume it can't support any other cases
......@@ -246,9 +218,7 @@ class InplaceElemwiseOptimizer(GraphRewriter):
else:
replacements = tuple(zip(node.outputs, inplace_node.outputs))
try:
fgraph.replace_all_validate(
replacements, reason="inplace_elemwise_optimizer"
)
fgraph.replace_all_validate(replacements, reason=reason)
node = inplace_node
replaced = True
except InconsistencyError:
......@@ -278,6 +248,50 @@ class InplaceElemwiseOptimizer(GraphRewriter):
)
class InplaceElemwiseOptimizer(InplaceGraphOptimizer):
    """Inplace rewriter specialized for `Elemwise` nodes.

    Provides the two hooks required by `InplaceGraphOptimizer`:
    discovering which (output, input) pairs may share storage, and
    rebuilding the node with an inplace scalar op.
    """

    # Node filter used by the base optimizer's toposort loop.
    op = Elemwise

    def filter_candidate_pairs(self, fgraph, node, protected_inputs):
        # Pair each destroyable input with its position in node.inputs.
        # NOTE(review): `index` returns the first occurrence, so a variable
        # fed to the node twice always maps to its first input slot.
        candidate_inputs = [
            (node.inputs.index(inp), inp)
            for inp in inplace_candidates(
                fgraph,
                node.inputs,
                protected_inputs=protected_inputs,
            )
        ]
        if not candidate_inputs:
            return []
        # An output may only overwrite an input of the exact same type
        # (same dtype and broadcast pattern).
        return [
            ((o, out), (i, inp))
            for o, out in enumerate(node.outputs)
            for i, inp in candidate_inputs
            if inp.type == out.type
        ]

    def create_inplace_node(self, node, inplace_pattern):
        op = node.op
        scalar_op = op.scalar_op
        # Elemwise destroys at most one input per output: unpack the
        # singleton lists into a plain {output_idx: input_idx} mapping.
        inplace_pattern = {i: o for i, [o] in inplace_pattern.items()}
        if hasattr(scalar_op, "make_new_inplace"):
            # Scalar op knows how to rebuild itself with a new output/input
            # transfer specification.
            new_scalar_op = scalar_op.make_new_inplace(
                ps.transfer_type(
                    *[
                        inplace_pattern.get(i, o.dtype)
                        for i, o in enumerate(node.outputs)
                    ]
                )
            )
        else:
            # Fall back to reconstructing the scalar op type directly.
            new_scalar_op = type(scalar_op)(
                ps.transfer_type(
                    *[inplace_pattern.get(i, None) for i in range(len(node.outputs))]
                )
            )
        # Preserve the concrete Elemwise subclass of the original op.
        return type(op)(new_scalar_op, inplace_pattern).make_node(*node.inputs)
compile.optdb.register(
"inplace_elemwise",
InplaceElemwiseOptimizer(),
......
......@@ -8,11 +8,21 @@ import scipy.linalg
import pytensor
from pytensor import In, config, function, scan
from pytensor.compile import get_default_mode, get_mode
from pytensor.compile.function.types import add_supervisor_to_fgraph
from pytensor.gradient import grad
from pytensor.graph import Apply, Op
from pytensor.graph import Apply, FunctionGraph, Op, rewrite_graph
from pytensor.graph.replace import vectorize_graph, vectorize_node
from pytensor.raise_op import assert_op
from pytensor.tensor import diagonal, dmatrix, log, ones_like, scalar, tensor, vector
from pytensor.tensor import (
diagonal,
dmatrix,
log,
matrices,
ones_like,
scalar,
tensor,
vector,
)
from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
from pytensor.tensor.nlinalg import MatrixInverse
from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot
......@@ -698,3 +708,57 @@ def test_scan_gradient_core_type():
grad_sit_sot0.eval({vec_seq: np.ones((4, n_steps, 1))}),
np.ones((4, n_steps, 1)),
)
def test_partial_inplace():
    """Blockwise inplace rewrite: full, protected, and duplicated-input cases."""

    class CoreOp(Op):
        # Toy identity-like Op over three inputs: destroys exactly the
        # inputs listed in `inplace` and copies the rest.
        __props__ = ("inplace",)

        def __init__(self, inplace):
            self.inplace = tuple(inplace)
            self.destroy_map = {i: [i] for i in inplace}

        def inplace_on_inputs(self, allowed_inplace_inputs):
            # Accept every input the rewriter offers.
            return type(self)(inplace=allowed_inplace_inputs)

        def make_node(self, x, y, z):
            return Apply(self, [x, y, z], [x.type(), y.type(), z.type()])

        def perform(self, node, inputs, outputs):
            for idx, value in enumerate(inputs):
                if idx not in self.inplace:
                    value = value.copy()
                outputs[idx][0] = value

    blockwise_op = Blockwise(CoreOp(inplace=()), signature="(),(),()->(),(),()")
    x, y, z = matrices("xyz")

    def rewritten_destroy_map(outputs, mutable_flags):
        # Build the graph, mark input mutability, run the inplace rewrites,
        # and report the destroy_map of the rewritten Blockwise node.
        fg = FunctionGraph([x, y, z], outputs)
        add_supervisor_to_fgraph(
            fg,
            [In(inp, mutable=flag) for inp, flag in zip(fg.inputs, mutable_flags)],
        )
        rewrite_graph(fg, include=("inplace",))
        return fg.outputs[0].owner.op.destroy_map

    # All inputs can be inplaced
    assert rewritten_destroy_map(
        blockwise_op(x.T, y.T, z.T), (True, True, True)
    ) == {0: [0], 1: [1], 2: [2]}

    # Only x, z can be inplaced, y is protected
    assert rewritten_destroy_map(
        blockwise_op(x.T, y.T, z.T), (True, False, True)
    ) == {0: [0], 2: [2]}

    # Only y can be inplaced, x is reused for first and third outputs
    assert rewritten_destroy_map(
        blockwise_op(x.T, y.T, x.T), (True, True, True)
    ) == {1: [1]}
Markdown is supported
0%
You are about to add 0 people to the discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment