提交 ff092688 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Constant fold branches of variadic add/mul

上级 5a462e98
import itertools import itertools
import operator
import sys import sys
from collections import Counter, defaultdict, deque from collections import Counter, defaultdict, deque
from collections.abc import Generator from collections.abc import Generator
from functools import cache from functools import cache, reduce
from typing import TypeVar from typing import TypeVar
from warnings import warn from warnings import warn
...@@ -16,11 +17,11 @@ from pytensor.graph.basic import Apply, Constant, Variable, ancestors, io_toposo ...@@ -16,11 +17,11 @@ from pytensor.graph.basic import Apply, Constant, Variable, ancestors, io_toposo
from pytensor.graph.features import ReplaceValidate from pytensor.graph.features import ReplaceValidate
from pytensor.graph.fg import Output from pytensor.graph.fg import Output
from pytensor.graph.rewriting.basic import ( from pytensor.graph.rewriting.basic import (
EquilibriumGraphRewriter,
GraphRewriter, GraphRewriter,
copy_stack_trace, copy_stack_trace,
in2out, in2out,
node_rewriter, node_rewriter,
out2in,
) )
from pytensor.graph.rewriting.db import SequenceDB from pytensor.graph.rewriting.db import SequenceDB
from pytensor.graph.utils import InconsistencyError, MethodNotDefined from pytensor.graph.utils import InconsistencyError, MethodNotDefined
...@@ -29,13 +30,15 @@ from pytensor.tensor.basic import ( ...@@ -29,13 +30,15 @@ from pytensor.tensor.basic import (
MakeVector, MakeVector,
alloc, alloc,
cast, cast,
constant,
get_underlying_scalar_constant_value, get_underlying_scalar_constant_value,
) )
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import exp from pytensor.tensor.math import add, exp, mul
from pytensor.tensor.rewriting.basic import ( from pytensor.tensor.rewriting.basic import (
alloc_like, alloc_like,
broadcasted_by,
register_canonicalize, register_canonicalize,
register_specialize, register_specialize,
) )
...@@ -542,8 +545,8 @@ def local_upcast_elemwise_constant_inputs(fgraph, node): ...@@ -542,8 +545,8 @@ def local_upcast_elemwise_constant_inputs(fgraph, node):
return rval return rval
@node_rewriter([Elemwise]) @node_rewriter([add, mul])
def local_add_mul_fusion(fgraph, node): def flatten_nested_add_mul(fgraph, node):
"""Fuse consecutive add or mul in one such node with more inputs. """Fuse consecutive add or mul in one such node with more inputs.
It is better to fuse add/mul that way then in a Composite node as It is better to fuse add/mul that way then in a Composite node as
...@@ -554,27 +557,16 @@ def local_add_mul_fusion(fgraph, node): ...@@ -554,27 +557,16 @@ def local_add_mul_fusion(fgraph, node):
This rewrite is almost useless after the AlgebraicCanonizer is used, This rewrite is almost useless after the AlgebraicCanonizer is used,
but it catches a few edge cases that are not canonicalized by it but it catches a few edge cases that are not canonicalized by it
""" """
if not ( s_op = node.op.scalar_op
isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, ps.Add | ps.Mul)
):
return False
s_op = node.op.scalar_op.__class__
new_inp = [] new_inp = []
fused = False fused = False
nb_inputs = len(node.inputs)
max_inputs = float("inf")
if hasattr(node.op, "max_inputs"):
max_inputs = node.op.max_inputs(node)
for inp in node.inputs: for inp in node.inputs:
if ( if (
inp.owner inp.owner
and isinstance(inp.owner.op, Elemwise) and isinstance(inp.owner.op, Elemwise)
and isinstance(inp.owner.op.scalar_op, s_op) and inp.owner.op.scalar_op == s_op
and
# Do not duplicate the operation. # Do not duplicate the operation.
len(fgraph.clients[inp]) == 1 and len(fgraph.clients[inp]) == 1
and (nb_inputs + len(inp.owner.inputs) - 1) <= max_inputs
): ):
new_inp.extend(inp.owner.inputs) new_inp.extend(inp.owner.inputs)
fused = True fused = True
...@@ -590,7 +582,7 @@ def local_add_mul_fusion(fgraph, node): ...@@ -590,7 +582,7 @@ def local_add_mul_fusion(fgraph, node):
# Do the recursion here to help lower the number of # Do the recursion here to help lower the number of
# FusionOptimizer iteration. # FusionOptimizer iteration.
if output.owner: if output.owner:
output2 = local_add_mul_fusion.transform(fgraph, output.owner) output2 = flatten_nested_add_mul.transform(fgraph, output.owner)
if output2: if output2:
return output2 return output2
return [output] return [output]
...@@ -1237,6 +1229,76 @@ def local_inline_composite_constants(fgraph, node): ...@@ -1237,6 +1229,76 @@ def local_inline_composite_constants(fgraph, node):
return new_outputs return new_outputs
@node_rewriter(tracks=[add, mul])
def constant_fold_branches_of_add_mul(fgraph, node):
    """Merge multiple constant inputs of a variadic ``add``/``mul`` node.

    Constant inputs are greedily combined at rewrite time with the
    corresponding Python-level operation, but only when the combined
    array is no larger than one of the constants being merged (checked
    with ``broadcasted_by``), so the fold never increases the memory
    held by the graph's constants.

    Returns ``None`` when fewer than two constants exist or when no
    merge is memory-safe; otherwise returns the replacement output.
    """
    constants = [inp for inp in node.inputs if isinstance(inp, TensorConstant)]
    if len(constants) <= 1:
        # Need at least two constants before any folding is possible.
        return None

    # Python/NumPy-level counterpart of the node's elemwise operation.
    combine = operator.mul if node.op == mul else operator.add

    folded = list(constants)
    # Repeatedly look for a "reference" constant that can absorb other
    # constants without being broadcast up to a larger shape.
    while len(folded) > 1:
        for ref_idx, ref in enumerate(folded):
            partners = [
                other
                for other_idx, other in enumerate(folded)
                if other_idx != ref_idx and not broadcasted_by(ref, other)
            ]
            if partners:
                # Fold the partners' values into the reference value and
                # restart the scan with the shrunken constant list.
                merged_value = reduce(combine, (c.data for c in partners), ref.data)
                group = [ref, *partners]
                folded = [
                    constant(merged_value),
                    *(c for c in folded if c not in group),
                ]
                break
        else:  # no-break: no reference could absorb anything — done.
            break

    if len(folded) == len(constants):
        # Nothing was merged; report that the rewrite does not apply.
        return None

    non_constants = [inp for inp in node.inputs if not isinstance(inp, TensorConstant)]
    new_out = node.op(
        *folded,
        *non_constants,
    )
    copy_stack_trace(node.outputs[0], new_out)
    return [new_out]
# Dedicated rewrite database for variadic add/mul cleanups, registered in
# the main optimizer database under fast_run, just before Elemwise fusion.
add_mul_fusion_seqopt = SequenceDB()
compile.optdb.register(
    "add_mul_fusion",
    add_mul_fusion_seqopt,
    "fast_run",
    position=48,  # Before Elemwise fusion
)
# Step 0: flatten nested add(add(...)) / mul(mul(...)) chains into a single
# variadic node, so constants from different nesting levels become siblings.
add_mul_fusion_seqopt.register(
    flatten_nested_add_mul.__name__,
    out2in(flatten_nested_add_mul, ignore_newtrees=False),
    "fast_run",
    position=0,
)
# Step 1: fold together the constant inputs exposed by the flattening above.
add_mul_fusion_seqopt.register(
    constant_fold_branches_of_add_mul.__name__,
    in2out(constant_fold_branches_of_add_mul, ignore_newtrees=True),
    "fast_run",
    position=1,
)
# Register fusion database just before AddDestroyHandler(49.5) (inplace rewrites) # Register fusion database just before AddDestroyHandler(49.5) (inplace rewrites)
fuse_seqopt = SequenceDB() fuse_seqopt = SequenceDB()
compile.optdb.register( compile.optdb.register(
...@@ -1248,14 +1310,6 @@ compile.optdb.register( ...@@ -1248,14 +1310,6 @@ compile.optdb.register(
"FusionOptimizer", "FusionOptimizer",
position=49, position=49,
) )
fuse_seqopt.register(
"local_add_mul_fusion",
EquilibriumGraphRewriter(rewriters=[local_add_mul_fusion], max_use_ratio=1000),
"fast_run",
"fusion",
position=0,
)
fuse_seqopt.register( fuse_seqopt.register(
"composite_elemwise_fusion", "composite_elemwise_fusion",
FusionOptimizer(), FusionOptimizer(),
...@@ -1279,7 +1333,7 @@ fuse_seqopt.register( ...@@ -1279,7 +1333,7 @@ fuse_seqopt.register(
) )
fuse_seqopt.register( fuse_seqopt.register(
"local_inline_composite_constants", "local_inline_composite_constants",
in2out(local_inline_composite_constants), in2out(local_inline_composite_constants, ignore_newtrees=True),
"fast_run", "fast_run",
"fusion", "fusion",
position=20, position=20,
......
...@@ -238,6 +238,7 @@ class TestFusion: ...@@ -238,6 +238,7 @@ class TestFusion:
include=[ include=[
"canonicalize", "canonicalize",
"fusion", "fusion",
"add_mul_fusion",
"inplace", "inplace",
], ],
exclude=["cxx_only", "BlasOpt"], exclude=["cxx_only", "BlasOpt"],
...@@ -1507,3 +1508,24 @@ def test_local_useless_dimshuffle_makevector(): ...@@ -1507,3 +1508,24 @@ def test_local_useless_dimshuffle_makevector():
) )
assert y_rewritten_fg.outputs[0] == a assert y_rewritten_fg.outputs[0] == a
@pytest.mark.parametrize("op", (add, mul))
def test_constant_fold_branches_add_mul(op):
    """``add_mul_fusion`` folds constant branches only when memory-safe."""
    rng = np.random.default_rng()
    np_op = np.multiply if op is mul else np.add

    x = pt.vector("x")
    a = rng.normal(size=(1, 512, 5))
    b = rng.normal(size=(1, 512, 1))

    # `a` and `b` fold into a single constant: their combination is no
    # larger than `a`, so the rewritten node has just two inputs.
    graph = op(op(a, x), b)
    rewritten = rewrite_graph(graph, include=("add_mul_fusion",))
    assert len(rewritten.owner.inputs) == 2
    assert equal_computations([rewritten], [op(np_op(a, b), x)])

    # c shouldn't be folded as it would increase the memory usage
    c = rng.normal(size=(1024, 1, 1))
    graph = op(op(op(a, x), c), b)
    rewritten = rewrite_graph(graph, include=("add_mul_fusion",))
    assert len(rewritten.owner.inputs) == 3
    assert equal_computations([rewritten], [op(np_op(a, b), c, x)])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论