提交 2c84b496 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Cleanup Fusion rewrites

* Move local_add_mul_fusion to `rewriting/elemwise` and remove unused/duplicated TestAddMulFusion tests * Use EquilibriumGraphRewriter for local_add_mul_fusion * Do not register optional rewrites if tensor__local_elemwise_fusion flag is disabled
上级 6c582e5b
......@@ -427,7 +427,7 @@ class SequenceDB(RewriteDatabase):
position_cutoff = tags[0].position_cutoff
# The RewriteDatabaseQuery instance might contain extra rewrites which need
# to be added the the sequence of rewrites (don't alter the
# to be added to the sequence of rewrites (don't alter the
# original dictionary)
if len(tags[0].extra_rewrites) > 0:
position_dict = position_dict.copy()
......
......@@ -13,6 +13,7 @@ from pytensor.graph.basic import Apply, Constant, io_toposort
from pytensor.graph.features import ReplaceValidate
from pytensor.graph.op import compute_test_value, get_test_value
from pytensor.graph.rewriting.basic import (
EquilibriumGraphRewriter,
GraphRewriter,
copy_stack_trace,
in2out,
......@@ -529,6 +530,60 @@ def local_upcast_elemwise_constant_inputs(fgraph, node):
return rval
@node_rewriter([Elemwise])
def local_add_mul_fusion(fgraph, node):
    """Fuse consecutive add or mul in one such node with more inputs.

    It is better to fuse add/mul that way than in a Composite node as
    this makes the inner graph of the Composite smaller. This allows to
    put more computation in a Composite before hitting the max
    recursion limit when pickling Composite.

    This rewrite is almost useless after the AlgebraicCanonizer is used,
    but it catches a few edge cases that are not canonicalized by it.
    """
    # Defensive re-check: the decorator already filters for Elemwise, but
    # only flat Add/Mul chains can be merged into a single node here.
    if not isinstance(node.op, Elemwise) or not isinstance(
        node.op.scalar_op, (aes.Add, aes.Mul)
    ):
        return False

    # Only absorb inputs produced by the *same* scalar op class
    # (add into add, mul into mul).
    s_op = node.op.scalar_op.__class__
    new_inp = []
    fused = False
    nb_inputs = len(node.inputs)
    # Some Ops cap how many inputs a single node may take; respect that
    # limit when flattening, otherwise treat it as unbounded.
    max_inputs = float("inf")
    if hasattr(node.op, "max_inputs"):
        max_inputs = node.op.max_inputs(node)
    for inp in node.inputs:
        if (
            inp.owner
            and isinstance(inp.owner.op, Elemwise)
            and isinstance(inp.owner.op.scalar_op, s_op)
            and
            # Do not duplicate the operation.
            len(fgraph.clients[inp]) == 1
            and (nb_inputs + len(inp.owner.inputs) - 1) <= max_inputs
        ):
            # Inline the child's inputs in place of the child itself.
            new_inp.extend(inp.owner.inputs)
            fused = True
        else:
            new_inp.append(inp)

    # We can not compare the number of inputs as Mul and Add could have
    # 0 or 1 inputs in some corner cases.
    if fused:
        output = node.op(*new_inp)
        copy_stack_trace(node.outputs[0], output)

        # Do the recursion here to help lower the number of
        # FusionOptimizer iteration.
        if output.owner:
            output2 = local_add_mul_fusion.transform(fgraph, output.owner)
            if output2:
                return output2
        return [output]
def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None):
r"""Create a recursive function that fuses `Elemwise` `Op`\s.
......@@ -901,6 +956,13 @@ class FusionOptimizer(GraphRewriter):
if config.tensor__local_elemwise_fusion:
# Must be after gpu(48.5) and before AddDestroyHandler(49.5)
fuse_seqopt = SequenceDB()
fuse_seqopt.register(
"local_add_mul_fusion",
EquilibriumGraphRewriter(rewriters=[local_add_mul_fusion], max_use_ratio=1000),
"fast_run",
"fusion",
position=0,
)
fuse_seqopt.register(
"composite_elemwise_fusion",
FusionOptimizer(local_elemwise_fusion),
......@@ -917,15 +979,6 @@ if config.tensor__local_elemwise_fusion:
"FusionOptimizer",
position=49,
)
else:
compile.optdb.register( # type: ignore
"elemwise_fusion",
FusionOptimizer(local_elemwise_fusion),
"fusion",
"local_elemwise_fusion",
"FusionOptimizer",
position=49,
)
@register_canonicalize
......
......@@ -92,7 +92,6 @@ from pytensor.tensor.rewriting.basic import (
register_uncanonicalize,
register_useless,
)
from pytensor.tensor.rewriting.elemwise import FusionOptimizer, fuse_seqopt
from pytensor.tensor.shape import Shape, Shape_i
from pytensor.tensor.subtensor import Subtensor
from pytensor.tensor.type import (
......@@ -2966,66 +2965,6 @@ def local_grad_log_erfc_neg(fgraph, node):
return [ret]
def local_add_mul_fusion(fgraph, node):
    """Fuse consecutive add or mul in one such node with more inputs.

    It is better to fuse add/mul that way than in a Composite node as
    this makes the inner graph of the Composite smaller. This allows to
    put more computation in a Composite before hitting the max
    recursion limit when pickling Composite.
    """
    # Only flat chains of the same scalar op (Add or Mul) can be merged.
    if not isinstance(node.op, Elemwise) or not isinstance(
        node.op.scalar_op, (aes.Add, aes.Mul)
    ):
        return False

    # Only absorb inputs produced by the same scalar op class.
    s_op = node.op.scalar_op.__class__
    new_inp = []
    fused = False
    nb_inputs = len(node.inputs)
    # Respect a per-Op cap on the number of inputs, if the Op defines one;
    # otherwise treat it as unbounded.
    max_inputs = float("inf")
    if hasattr(node.op, "max_inputs"):
        max_inputs = node.op.max_inputs(node)
    for inp in node.inputs:
        if (
            inp.owner
            and isinstance(inp.owner.op, Elemwise)
            and isinstance(inp.owner.op.scalar_op, s_op)
            and
            # Do not duplicate the operation.
            len(fgraph.clients[inp]) == 1
            and (nb_inputs + len(inp.owner.inputs) - 1) <= max_inputs
        ):
            # Inline the child's inputs in place of the child itself.
            new_inp.extend(inp.owner.inputs)
            fused = True
        else:
            new_inp.append(inp)

    # We can not compare the number of inputs as Mul and Add could have
    # 0 or 1 inputs in some corner cases.
    if fused:
        output = node.op(*new_inp)
        copy_stack_trace(node.outputs[0], output)

        # Do the recursion here to help lower the number of
        # FusionOptimizer iteration.
        if output.owner:
            # NOTE: this variant recurses by calling the function directly
            # (it is not wrapped by @node_rewriter here).
            output2 = local_add_mul_fusion(fgraph, output.owner)
            if output2:
                return output2
        return [output]
# Register the add/mul flattening pass at position 0 so it runs before
# the Composite elemwise fusion passes in the fusion sequence database.
fuse_seqopt.register(
    "local_add_mul_fusion",
    FusionOptimizer(local_add_mul_fusion),
    "fast_run",
    "fusion",
    position=0,
)
def _skip_mul_1(r):
if r.owner and r.owner.op == mul:
not_is_1 = [i for i in r.owner.inputs if not _is_1(i)]
......
......@@ -4,9 +4,9 @@ import numpy as np
import pytest
import pytensor
import pytensor.scalar as aes
import pytensor.tensor as at
from pytensor import scalar as aes
from pytensor import shared
from pytensor import tensor as at
from pytensor.compile.function import function
from pytensor.compile.mode import Mode, get_default_mode
from pytensor.configdefaults import config
......@@ -263,9 +263,8 @@ def test_local_useless_dimshuffle_in_reshape():
class TestFusion:
rewrites = RewriteDatabaseQuery(
include=[
"local_elemwise_fusion",
"composite_elemwise_fusion",
"canonicalize",
"fusion",
"inplace",
],
exclude=["cxx_only", "BlasOpt"],
......@@ -1007,22 +1006,10 @@ class TestFusion:
)
def test_add_mul_fusion_inplace(self):
rewrites = RewriteDatabaseQuery(
include=[
"local_elemwise_fusion",
"composite_elemwise_fusion",
"canonicalize",
"inplace",
],
exclude=["cxx_only", "BlasOpt"],
)
mode = Mode(self.mode.linker, rewrites)
x, y, z = dmatrices("xyz")
out = dot(x, y) + x + y + z
f = function([x, y, z], out, mode=mode)
f = function([x, y, z], out, mode=self.mode)
topo = [n for n in f.maker.fgraph.toposort()]
assert len(topo) == 2
assert topo[-1].op.inplace_pattern
......@@ -1050,8 +1037,7 @@ class TestFusion:
mode = Mode(linker="cvm")
mode._optimizer = mode._optimizer.including(
"local_elemwise_fusion",
"composite_elemwise_fusion",
"fusion",
"canonicalize",
"inplace",
)
......@@ -1073,18 +1059,6 @@ class TestFusion:
are checked.
"""
rewrites = RewriteDatabaseQuery(
include=[
"local_elemwise_fusion",
"composite_elemwise_fusion",
"canonicalize",
],
exclude=["cxx_only", "BlasOpt"],
)
mode = Mode(self.mode.linker, rewrites)
x, y, z = dmatrices("xyz")
x.tag.test_value = test_value
......@@ -1101,7 +1075,7 @@ class TestFusion:
):
out = x * y + z
with cm:
f = function([x, y, z], out, mode=mode)
f = function([x, y, z], out, mode=self.mode)
if test_value.size != 0:
# Confirm that the fusion happened
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论