提交 1ebd078a authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Reuse Elemwise inplace machinery for Blockwise

上级 1d94ed68
...@@ -2,7 +2,7 @@ from pytensor.compile.mode import optdb ...@@ -2,7 +2,7 @@ from pytensor.compile.mode import optdb
from pytensor.graph import Constant, node_rewriter from pytensor.graph import Constant, node_rewriter
from pytensor.graph.destroyhandler import inplace_candidates from pytensor.graph.destroyhandler import inplace_candidates
from pytensor.graph.replace import vectorize_node from pytensor.graph.replace import vectorize_node
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, out2in from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.math import Dot from pytensor.tensor.math import Dot
...@@ -11,6 +11,7 @@ from pytensor.tensor.rewriting.basic import ( ...@@ -11,6 +11,7 @@ from pytensor.tensor.rewriting.basic import (
register_specialize, register_specialize,
register_stabilize, register_stabilize,
) )
from pytensor.tensor.rewriting.elemwise import InplaceGraphOptimizer
from pytensor.tensor.shape import Reshape from pytensor.tensor.shape import Reshape
from pytensor.tensor.subtensor import ( from pytensor.tensor.subtensor import (
AdvancedIncSubtensor, AdvancedIncSubtensor,
...@@ -260,68 +261,77 @@ def local_blockwise_of_subtensor(fgraph, node): ...@@ -260,68 +261,77 @@ def local_blockwise_of_subtensor(fgraph, node):
return [x[(*none_slices, *core_idxs)]] return [x[(*none_slices, *core_idxs)]]
@node_rewriter(tracks=[Blockwise], inplace=True) class InplaceBlockwiseOptimizer(InplaceGraphOptimizer):
def blockwise_inplace(fgraph, node): op = Blockwise
blockwise_op = node.op
def filter_candidate_pairs(self, fgraph, node, protected_inputs):
if blockwise_op.destroy_map: blockwise_op = node.op
# Op already has inplace batch_ndim = blockwise_op.batch_ndim(node)
return out_batch_bcast = node.outputs[0].type.broadcastable[:batch_ndim]
inputs = node.inputs
# Find out valid inputs for inplacing
batch_ndim = blockwise_op.batch_ndim(node) candidate_inputs = set(
out_batch_bcast = node.outputs[0].type.broadcastable[:batch_ndim] inplace_candidates(
fgraph,
inputs = node.inputs [
candidate_inputs = set( inp
inplace_candidates( for inp in inputs
fgraph, if inp.type.broadcastable[:batch_ndim] == out_batch_bcast
[ ],
inp protected_inputs=protected_inputs,
for inp in inputs )
if inp.type.broadcastable[:batch_ndim] == out_batch_bcast
],
) )
)
allowed_inplace_inputs = [
i for i, inp in enumerate(inputs) if inp in candidate_inputs
]
if not allowed_inplace_inputs: allowed_inplace_inputs = [
return None i for i, inp in enumerate(inputs) if inp in candidate_inputs
]
destroy_map = blockwise_op.core_op.inplace_on_inputs(
allowed_inplace_inputs=allowed_inplace_inputs
).destroy_map
if not destroy_map:
return []
outputs = node.outputs
return [
((out_idx, outputs[out_idx]), (inp_idx, inputs[inp_idx]))
for out_idx, inp_idxs in destroy_map.items()
for inp_idx in inp_idxs
]
inplace_core_op = blockwise_op.core_op.inplace_on_inputs( def create_inplace_node(self, node, inplace_pattern):
allowed_inplace_inputs=allowed_inplace_inputs blockwise_op = node.op
) allowed_inplace_inputs = tuple(v[0] for v in inplace_pattern.values())
inplace_core_op = blockwise_op.core_op.inplace_on_inputs(
allowed_inplace_inputs=allowed_inplace_inputs
)
if not inplace_core_op.destroy_map: if not inplace_core_op.destroy_map:
return None return node
# Check Op is not trying to inplace on non-candidate inputs # Check Op is not trying to inplace on non-candidate inputs
for destroyed_inputs in inplace_core_op.destroy_map.values(): for destroyed_inputs in inplace_core_op.destroy_map.values():
for destroyed_input in destroyed_inputs: for destroyed_input in destroyed_inputs:
if destroyed_input not in allowed_inplace_inputs: if destroyed_input not in allowed_inplace_inputs:
raise ValueError( raise ValueError(
f"Op {blockwise_op.core_op} destroy_map does not respect allowed_inplace_inputs {allowed_inplace_inputs}" f"Op {blockwise_op.core_op} destroy_map does not respect allowed_inplace_inputs {allowed_inplace_inputs}"
) )
# Recreate core_op with inplace # Recreate core_op with inplace
inplace_blockwise_op = Blockwise( inplace_blockwise_op = type(blockwise_op)(
core_op=inplace_core_op, core_op=inplace_core_op,
signature=blockwise_op.signature, signature=blockwise_op.signature,
name=blockwise_op.name, name=blockwise_op.name,
gufunc_spec=blockwise_op.gufunc_spec, gufunc_spec=blockwise_op.gufunc_spec,
destroy_map=inplace_core_op.destroy_map, destroy_map=inplace_core_op.destroy_map,
) )
out = inplace_blockwise_op.make_node(*node.inputs).outputs return inplace_blockwise_op.make_node(*node.inputs)
copy_stack_trace(node.outputs, out)
return out
optdb.register( optdb.register(
"blockwise_inplace", "blockwise_inplace",
in2out(blockwise_inplace), InplaceBlockwiseOptimizer(),
"fast_run", "fast_run",
"inplace", "inplace",
position=50.1, position=50.1,
......
import abc
import itertools import itertools
import operator import operator
import sys import sys
from collections import defaultdict, deque from collections import defaultdict, deque
from collections.abc import Generator from collections.abc import Generator, Sequence
from functools import cache, reduce from functools import cache, reduce
from typing import TypeVar from typing import TypeVar
from warnings import warn from warnings import warn
...@@ -12,7 +13,7 @@ from pytensor import clone_replace, compile ...@@ -12,7 +13,7 @@ from pytensor import clone_replace, compile
from pytensor.compile.function.types import Supervisor from pytensor.compile.function.types import Supervisor
from pytensor.compile.mode import get_target_language from pytensor.compile.mode import get_target_language
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph import FunctionGraph from pytensor.graph import FunctionGraph, Op
from pytensor.graph.basic import Apply, Variable, ancestors from pytensor.graph.basic import Apply, Variable, ancestors
from pytensor.graph.destroyhandler import DestroyHandler, inplace_candidates from pytensor.graph.destroyhandler import DestroyHandler, inplace_candidates
from pytensor.graph.features import ReplaceValidate from pytensor.graph.features import ReplaceValidate
...@@ -47,22 +48,31 @@ from pytensor.tensor.shape import shape_padleft ...@@ -47,22 +48,31 @@ from pytensor.tensor.shape import shape_padleft
from pytensor.tensor.variable import TensorConstant, TensorVariable from pytensor.tensor.variable import TensorConstant, TensorVariable
class InplaceElemwiseOptimizer(GraphRewriter): class InplaceGraphOptimizer(GraphRewriter):
r""" op: type[Op]
This is parameterized so that it works for `Elemwise` `Op`\s.
"""
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
fgraph.attach_feature(DestroyHandler()) fgraph.attach_feature(DestroyHandler())
@abc.abstractmethod
def filter_candidate_pairs(
self, fgraph: FunctionGraph, node: Apply, protected_inputs: Sequence[Variable]
) -> Sequence[tuple[tuple[int, Variable], tuple[int, Variable]]]:
pass
@abc.abstractmethod
def create_inplace_node(
self, node: Apply, inplace_pattern: dict[int, Sequence[int]]
) -> Apply:
pass
def apply(self, fgraph): def apply(self, fgraph):
r""" r"""
Attempts to replace all `Elemwise`\s by versions of them that operate Attempts to replace all `Op`\s by versions of them that operate
inplace. It operates greedily: for each `Elemwise` that is encountered, inplace. It operates greedily: for each `Op` that is encountered,
for each output, it tries each input to see if it can operate inplace it tries to inplace all the valid inputs at once (if the Op supports it),
on that input. If so, it makes the change and goes to the next output if that fails, it tries to inplace one input at a time.
or `Elemwise`.
Examples Examples
-------- --------
...@@ -93,36 +103,13 @@ class InplaceElemwiseOptimizer(GraphRewriter): ...@@ -93,36 +103,13 @@ class InplaceElemwiseOptimizer(GraphRewriter):
# tackle them in a more general way. The whole try/except approach is probably suboptimal. # tackle them in a more general way. The whole try/except approach is probably suboptimal.
# We can consider restricting inputs with static shapes that are large enough. # We can consider restricting inputs with static shapes that are large enough.
def create_inplace_node(node, inplace_pattern):
op = node.op
scalar_op = op.scalar_op
inplace_pattern = {i: o for i, [o] in inplace_pattern.items()}
if hasattr(scalar_op, "make_new_inplace"):
new_scalar_op = scalar_op.make_new_inplace(
ps.transfer_type(
*[
inplace_pattern.get(i, o.dtype)
for i, o in enumerate(node.outputs)
]
)
)
else:
new_scalar_op = type(scalar_op)(
ps.transfer_type(
*[
inplace_pattern.get(i, None)
for i in range(len(node.outputs))
]
)
)
return type(op)(new_scalar_op, inplace_pattern).make_node(*node.inputs)
if config.tensor__insert_inplace_optimizer_validate_nb != -1: if config.tensor__insert_inplace_optimizer_validate_nb != -1:
warn( warn(
"tensor__insert_inplace_optimizer_validate_nb config is deprecated. Setting it will fail in a future release.", "tensor__insert_inplace_optimizer_validate_nb config is deprecated. Setting it will fail in a future release.",
FutureWarning, FutureWarning,
) )
reason = f"{self.op}_inplace_optimizer"
prof = { prof = {
"opt": self, "opt": self,
"node_before": len(fgraph.apply_nodes), "node_before": len(fgraph.apply_nodes),
...@@ -140,6 +127,7 @@ class InplaceElemwiseOptimizer(GraphRewriter): ...@@ -140,6 +127,7 @@ class InplaceElemwiseOptimizer(GraphRewriter):
protected_inputs.update(fgraph.outputs) protected_inputs.update(fgraph.outputs)
root_destroyer = fgraph.destroy_handler.root_destroyer root_destroyer = fgraph.destroy_handler.root_destroyer
self_op = self.op
update_mapping = fgraph.update_mapping or {} update_mapping = fgraph.update_mapping or {}
op_updates: dict[TensorVariable, TensorVariable] = { op_updates: dict[TensorVariable, TensorVariable] = {
out: fgraph.inputs[update_mapping[out_idx]] out: fgraph.inputs[update_mapping[out_idx]]
...@@ -147,36 +135,22 @@ class InplaceElemwiseOptimizer(GraphRewriter): ...@@ -147,36 +135,22 @@ class InplaceElemwiseOptimizer(GraphRewriter):
if ( if (
out_idx in update_mapping out_idx in update_mapping
and out.owner and out.owner
and isinstance(out.owner.op, Elemwise) and isinstance(out.owner.op, self_op)
) )
} }
set_op_updates = set(op_updates.keys()) set_op_updates = set(op_updates.keys())
for node in fgraph.toposort(): for node in fgraph.toposort():
if not isinstance(node.op, Elemwise) or node.op.destroy_map: if not isinstance(node.op, self_op) or node.op.destroy_map:
continue continue
# If big graph and the outputs are scalar, do not make it inplace. # If big graph and the outputs are scalar, do not make it inplace.
if large_graph and all(node.outputs[0].type.broadcastable): if large_graph and all(node.outputs[0].type.broadcastable):
continue continue
candidate_inputs = [ candidate_pairs = self.filter_candidate_pairs(
(node.inputs.index(inp), inp) fgraph, node, protected_inputs
for inp in inplace_candidates( )
fgraph,
node.inputs,
protected_inputs=protected_inputs,
)
]
if not candidate_inputs:
return []
candidate_pairs = [
((o, out), (i, inp))
for o, out in enumerate(node.outputs)
for i, inp in candidate_inputs
if inp.type == out.type
]
if not candidate_pairs: if not candidate_pairs:
continue continue
...@@ -216,13 +190,11 @@ class InplaceElemwiseOptimizer(GraphRewriter): ...@@ -216,13 +190,11 @@ class InplaceElemwiseOptimizer(GraphRewriter):
inplace_pattern[o] = [i] inplace_pattern[o] = [i]
tried_inputs.add(i) tried_inputs.add(i)
inplace_node = create_inplace_node(node, inplace_pattern) inplace_node = self.create_inplace_node(node, inplace_pattern)
if inplace_node.op.destroy_map == inplace_pattern: if inplace_node.op.destroy_map == inplace_pattern:
replacements = tuple(zip(node.outputs, inplace_node.outputs)) replacements = tuple(zip(node.outputs, inplace_node.outputs))
try: try:
fgraph.replace_all_validate( fgraph.replace_all_validate(replacements, reason=reason)
replacements, reason="inplace_elemwise_optimizer"
)
except InconsistencyError: except InconsistencyError:
prof["nb_eager_inconsistent"] += 1 prof["nb_eager_inconsistent"] += 1
else: else:
...@@ -238,7 +210,7 @@ class InplaceElemwiseOptimizer(GraphRewriter): ...@@ -238,7 +210,7 @@ class InplaceElemwiseOptimizer(GraphRewriter):
inplace_pattern[o] = [i] inplace_pattern[o] = [i]
tried_inputs.add(i) tried_inputs.add(i)
inplace_node = create_inplace_node(node, inplace_pattern) inplace_node = self.create_inplace_node(node, inplace_pattern)
if inplace_node.op.destroy_map != inplace_pattern: if inplace_node.op.destroy_map != inplace_pattern:
# This Op can't respect this partial inplace pattern, # This Op can't respect this partial inplace pattern,
# We assume it can't support any other cases # We assume it can't support any other cases
...@@ -246,9 +218,7 @@ class InplaceElemwiseOptimizer(GraphRewriter): ...@@ -246,9 +218,7 @@ class InplaceElemwiseOptimizer(GraphRewriter):
else: else:
replacements = tuple(zip(node.outputs, inplace_node.outputs)) replacements = tuple(zip(node.outputs, inplace_node.outputs))
try: try:
fgraph.replace_all_validate( fgraph.replace_all_validate(replacements, reason=reason)
replacements, reason="inplace_elemwise_optimizer"
)
node = inplace_node node = inplace_node
replaced = True replaced = True
except InconsistencyError: except InconsistencyError:
...@@ -278,6 +248,50 @@ class InplaceElemwiseOptimizer(GraphRewriter): ...@@ -278,6 +248,50 @@ class InplaceElemwiseOptimizer(GraphRewriter):
) )
class InplaceElemwiseOptimizer(InplaceGraphOptimizer):
op = Elemwise
def filter_candidate_pairs(self, fgraph, node, protected_inputs):
candidate_inputs = [
(node.inputs.index(inp), inp)
for inp in inplace_candidates(
fgraph,
node.inputs,
protected_inputs=protected_inputs,
)
]
if not candidate_inputs:
return []
return [
((o, out), (i, inp))
for o, out in enumerate(node.outputs)
for i, inp in candidate_inputs
if inp.type == out.type
]
def create_inplace_node(self, node, inplace_pattern):
op = node.op
scalar_op = op.scalar_op
inplace_pattern = {i: o for i, [o] in inplace_pattern.items()}
if hasattr(scalar_op, "make_new_inplace"):
new_scalar_op = scalar_op.make_new_inplace(
ps.transfer_type(
*[
inplace_pattern.get(i, o.dtype)
for i, o in enumerate(node.outputs)
]
)
)
else:
new_scalar_op = type(scalar_op)(
ps.transfer_type(
*[inplace_pattern.get(i, None) for i in range(len(node.outputs))]
)
)
return type(op)(new_scalar_op, inplace_pattern).make_node(*node.inputs)
compile.optdb.register( compile.optdb.register(
"inplace_elemwise", "inplace_elemwise",
InplaceElemwiseOptimizer(), InplaceElemwiseOptimizer(),
......
...@@ -8,11 +8,21 @@ import scipy.linalg ...@@ -8,11 +8,21 @@ import scipy.linalg
import pytensor import pytensor
from pytensor import In, config, function, scan from pytensor import In, config, function, scan
from pytensor.compile import get_default_mode, get_mode from pytensor.compile import get_default_mode, get_mode
from pytensor.compile.function.types import add_supervisor_to_fgraph
from pytensor.gradient import grad from pytensor.gradient import grad
from pytensor.graph import Apply, Op from pytensor.graph import Apply, FunctionGraph, Op, rewrite_graph
from pytensor.graph.replace import vectorize_graph, vectorize_node from pytensor.graph.replace import vectorize_graph, vectorize_node
from pytensor.raise_op import assert_op from pytensor.raise_op import assert_op
from pytensor.tensor import diagonal, dmatrix, log, ones_like, scalar, tensor, vector from pytensor.tensor import (
diagonal,
dmatrix,
log,
matrices,
ones_like,
scalar,
tensor,
vector,
)
from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
from pytensor.tensor.nlinalg import MatrixInverse from pytensor.tensor.nlinalg import MatrixInverse
from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot
...@@ -698,3 +708,57 @@ def test_scan_gradient_core_type(): ...@@ -698,3 +708,57 @@ def test_scan_gradient_core_type():
grad_sit_sot0.eval({vec_seq: np.ones((4, n_steps, 1))}), grad_sit_sot0.eval({vec_seq: np.ones((4, n_steps, 1))}),
np.ones((4, n_steps, 1)), np.ones((4, n_steps, 1)),
) )
def test_partial_inplace():
class CoreOp(Op):
__props__ = ("inplace",)
def __init__(self, inplace):
self.inplace = tuple(inplace)
self.destroy_map = {i: [i] for i in inplace}
def inplace_on_inputs(self, allowed_inplace_inputs):
return type(self)(inplace=allowed_inplace_inputs)
def make_node(self, x, y, z):
return Apply(self, [x, y, z], [x.type(), y.type(), z.type()])
def perform(self, node, inputs, outputs):
[x, y, z] = inputs
if 0 not in self.inplace:
x = x.copy()
if 1 not in self.inplace:
y = y.copy()
if 2 not in self.inplace:
z = z.copy()
outputs[0][0] = x
outputs[1][0] = y
outputs[2][0] = z
core_op = CoreOp(inplace=())
blockwise_op = Blockwise(core_op, signature="(),(),()->(),(),()")
x, y, z = matrices("xyz")
# All can be inplaced
out = blockwise_op(x.T, y.T, z.T)
fgraph = FunctionGraph([x, y, z], out)
add_supervisor_to_fgraph(fgraph, [In(inp, mutable=True) for inp in fgraph.inputs])
rewrite_graph(fgraph, include=("inplace",))
assert fgraph.outputs[0].owner.op.destroy_map == {0: [0], 1: [1], 2: [2]}
# Only x, z can be inplaced, y is protected
out = blockwise_op(x.T, y.T, z.T)
fgraph = FunctionGraph([x, y, z], out)
add_supervisor_to_fgraph(
fgraph, [In(inp, mutable=(i % 2) == 0) for i, inp in enumerate(fgraph.inputs)]
)
rewrite_graph(fgraph, include=("inplace",))
assert fgraph.outputs[0].owner.op.destroy_map == {0: [0], 2: [2]}
# Only y can be inplaced, x is reused for first and third outputs
out = blockwise_op(x.T, y.T, x.T)
fgraph = FunctionGraph([x, y, z], out)
add_supervisor_to_fgraph(fgraph, [In(inp, mutable=True) for inp in fgraph.inputs])
rewrite_graph(fgraph, include=("inplace",))
assert fgraph.outputs[0].owner.op.destroy_map == {1: [1]}
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论