提交 da86d351 authored 作者: Ricardo's avatar Ricardo 提交者: Brandon T. Willard

Fix bug where PatternSub could incorrectly replace nodes with different number of outputs

Bug was introduced in b920bd29
上级 a0604e52
......@@ -1762,9 +1762,12 @@ class PatternSub(LocalOptimizer):
ret.tag.values_eq_approx = self.values_eq_approx
if ret.owner:
if not all(
o.type.is_super(new_o.type)
for o, new_o in zip(node.outputs, ret.owner.outputs)
if not (
len(node.outputs) == len(ret.owner.outputs)
and all(
o.type.is_super(new_o.type)
for o, new_o in zip(node.outputs, ret.owner.outputs)
)
):
return False
else:
......
......@@ -14,6 +14,7 @@ from aesara.graph.opt import (
OpSub,
PatternSub,
TopoOptimizer,
in2out,
local_optimizer,
logging,
pre_constant_merge,
......@@ -36,6 +37,7 @@ from tests.graph.utils import (
op5,
op6,
op_cast_type2,
op_multiple_outputs,
op_y,
op_z,
)
......@@ -652,6 +654,24 @@ def test_patternsub_invalid_dtype(out_pattern):
assert e.type.is_super(fg.outputs[0].type)
def test_patternsub_different_output_lengths():
# Test that PatternSub won't replace nodes with different numbers of outputs
ps = PatternSub(
(op1, "x"),
("x"),
name="ps",
)
opt = in2out(ps)
x = MyVariable("x")
e1, e2 = op_multiple_outputs(x)
o = op1(e1)
fgraph = FunctionGraph(inputs=[x], outputs=[o])
opt.optimize(fgraph)
assert fgraph.outputs[0].owner.op == op1
class TestLocalOptGroup:
def test_optimizer_verbose(self, capsys):
......
......@@ -96,6 +96,12 @@ class MyOpCastType2(MyOp):
return Apply(self, inputs, outputs)
class MyOpMultipleOutputs(MyOp):
def make_node(self, input):
outputs = [input.type(), input.type()]
return Apply(self, [input], outputs)
op1 = MyOp("Op1")
op2 = MyOp("Op2")
op3 = MyOp("Op3")
......@@ -108,6 +114,7 @@ op_y = MyOp("OpY", x=1)
op_z = MyOp("OpZ", x=1)
op_cast_type2 = MyOpCastType2("OpCastType2")
op_multiple_outputs = MyOpMultipleOutputs("OpMultipleOutputs")
class MyInnerGraphOp(Op, HasInnerGraph):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论