提交 ec45e25f authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Do not reject PatternNodeRewriter due unrelated multiple clients

上级 2143d851
......@@ -1616,14 +1616,6 @@ class PatternNodeRewriter(NodeRewriter):
from etuples.core import ExpressionTuple
from unification import reify, unify
# TODO: We shouldn't need to iterate like this.
if not self.allow_multiple_clients and any(
len(fgraph.clients.get(v)) > 1
for v in vars_between(fgraph.inputs, node.outputs)
if v not in fgraph.inputs
):
return False
if get_nodes and self.get_nodes is not None:
for real_node in self.get_nodes(fgraph, node):
if real_node == "output":
......@@ -1648,6 +1640,15 @@ class PatternNodeRewriter(NodeRewriter):
if self.values_eq_approx:
ret.tag.values_eq_approx = self.values_eq_approx
if not self.allow_multiple_clients:
input_vars = list(s.values())
if any(
len(fgraph.clients[v]) > 1
for v in vars_between(input_vars, node.inputs)
if v not in input_vars
):
return False
if ret.owner:
if not (
len(node.outputs) == len(ret.owner.outputs)
......
......@@ -50,8 +50,11 @@ class AssertNoChanges(Feature):
raise AssertionError()
def OpKeyPatternNodeRewriter(p1, p2, ign=False):
return OpKeyGraphRewriter(PatternNodeRewriter(p1, p2), ignore_newtrees=ign)
def OpKeyPatternNodeRewriter(p1, p2, allow_multiple_clients=False, ign=False):
return OpKeyGraphRewriter(
PatternNodeRewriter(p1, p2, allow_multiple_clients=allow_multiple_clients),
ignore_newtrees=ign,
)
def WalkingPatternNodeRewriter(p1, p2, ign=True):
......@@ -207,13 +210,70 @@ class TestPatternNodeRewriter:
assert str(g) == "FunctionGraph(Op2(Op1(x, x), Op3(x, y)))"
def test_allow_multiple_clients(self):
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e0 = op1(x, y)
# `e0` has multiple clients (i.e. the `op4` and `op3` nodes)
e = op3(op4(e0), e0)
g = FunctionGraph([x, y, z], [e])
OpKeyPatternNodeRewriter((op4, (op1, "x", "y")), (op3, "x", "y")).rewrite(g)
assert str(g) == "FunctionGraph(Op3(Op4(*1 -> Op1(x, y)), *1))"
x, y, z = inputs = MyVariable("x"), MyVariable("y"), MyVariable("z")
w = op1(x, y)
# `w` has multiple clients (i.e. the `op4` and `op3` nodes)
e = op3(op4(w), w)
# By default, allow_multiple_clients is False
# So the replacement should fail
outputs = [e]
g = FunctionGraph(inputs, outputs, copy_inputs=False)
OpKeyPatternNodeRewriter(
(op4, (op1, "x", "y")),
(op3, "x", "y"),
).rewrite(g)
assert equal_computations(g.outputs, outputs)
# Now it should be fine
g = FunctionGraph(inputs, outputs, copy_inputs=False)
OpKeyPatternNodeRewriter(
(op4, (op1, "x", "y")),
(op3, "x", "y"),
allow_multiple_clients=True,
).rewrite(g)
assert equal_computations(g.outputs, [op3(op3(x, y), w)])
# The fact that the inputs of the pattern have multiple clients should not matter
g = FunctionGraph(inputs, outputs, copy_inputs=False)
OpKeyPatternNodeRewriter(
(op3, (op4, "w"), "w"),
(op3, "w", "w"),
allow_multiple_clients=False,
).rewrite(g)
assert equal_computations(g.outputs, [op3(w, w)])
# The fact that are multiple clients above the inputs of the pattern should not matter
v = op4(e)
e1 = op4(v)
e2 = op1(x, x) # Irrelevant reuse of x that should not block rewrite either
e3 = op1(v, v) # Relevant reuse of v that should block rewrite
outputs = [e1, e2]
g = FunctionGraph(inputs, outputs, copy_inputs=False)
OpKeyPatternNodeRewriter(
(op4, (op4, "e")),
"e",
allow_multiple_clients=False,
).rewrite(g)
assert equal_computations(g.outputs, [e, e2])
outputs = [e1, e3]
g = FunctionGraph([x, y, z], outputs, copy_inputs=False)
OpKeyPatternNodeRewriter(
(op4, (op4, "e")),
"e",
allow_multiple_clients=False,
).rewrite(g)
assert equal_computations(g.outputs, outputs)
g = FunctionGraph(inputs, outputs, copy_inputs=False)
OpKeyPatternNodeRewriter(
(op4, (op4, "e")),
"e",
allow_multiple_clients=True,
).rewrite(g)
assert equal_computations(g.outputs, [e, e3])
def test_eq(self):
# replacing the whole graph
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论