提交 da86d351 authored 作者: Ricardo's avatar Ricardo 提交者: Brandon T. Willard

Fix bug where PatternSub could incorrectly replace nodes with different number of outputs

Bug was introduced in b920bd29
上级 a0604e52
...@@ -1762,9 +1762,12 @@ class PatternSub(LocalOptimizer): ...@@ -1762,9 +1762,12 @@ class PatternSub(LocalOptimizer):
ret.tag.values_eq_approx = self.values_eq_approx ret.tag.values_eq_approx = self.values_eq_approx
if ret.owner: if ret.owner:
if not all( if not (
o.type.is_super(new_o.type) len(node.outputs) == len(ret.owner.outputs)
for o, new_o in zip(node.outputs, ret.owner.outputs) and all(
o.type.is_super(new_o.type)
for o, new_o in zip(node.outputs, ret.owner.outputs)
)
): ):
return False return False
else: else:
......
...@@ -14,6 +14,7 @@ from aesara.graph.opt import ( ...@@ -14,6 +14,7 @@ from aesara.graph.opt import (
OpSub, OpSub,
PatternSub, PatternSub,
TopoOptimizer, TopoOptimizer,
in2out,
local_optimizer, local_optimizer,
logging, logging,
pre_constant_merge, pre_constant_merge,
...@@ -36,6 +37,7 @@ from tests.graph.utils import ( ...@@ -36,6 +37,7 @@ from tests.graph.utils import (
op5, op5,
op6, op6,
op_cast_type2, op_cast_type2,
op_multiple_outputs,
op_y, op_y,
op_z, op_z,
) )
...@@ -652,6 +654,24 @@ def test_patternsub_invalid_dtype(out_pattern): ...@@ -652,6 +654,24 @@ def test_patternsub_invalid_dtype(out_pattern):
assert e.type.is_super(fg.outputs[0].type) assert e.type.is_super(fg.outputs[0].type)
def test_patternsub_different_output_lengths():
# Test that PatternSub won't replace nodes with different numbers of outputs
ps = PatternSub(
(op1, "x"),
("x"),
name="ps",
)
opt = in2out(ps)
x = MyVariable("x")
e1, e2 = op_multiple_outputs(x)
o = op1(e1)
fgraph = FunctionGraph(inputs=[x], outputs=[o])
opt.optimize(fgraph)
assert fgraph.outputs[0].owner.op == op1
class TestLocalOptGroup: class TestLocalOptGroup:
def test_optimizer_verbose(self, capsys): def test_optimizer_verbose(self, capsys):
......
...@@ -96,6 +96,12 @@ class MyOpCastType2(MyOp): ...@@ -96,6 +96,12 @@ class MyOpCastType2(MyOp):
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
class MyOpMultipleOutputs(MyOp):
def make_node(self, input):
outputs = [input.type(), input.type()]
return Apply(self, [input], outputs)
op1 = MyOp("Op1") op1 = MyOp("Op1")
op2 = MyOp("Op2") op2 = MyOp("Op2")
op3 = MyOp("Op3") op3 = MyOp("Op3")
...@@ -108,6 +114,7 @@ op_y = MyOp("OpY", x=1) ...@@ -108,6 +114,7 @@ op_y = MyOp("OpY", x=1)
op_z = MyOp("OpZ", x=1) op_z = MyOp("OpZ", x=1)
op_cast_type2 = MyOpCastType2("OpCastType2") op_cast_type2 = MyOpCastType2("OpCastType2")
op_multiple_outputs = MyOpMultipleOutputs("OpMultipleOutputs")
class MyInnerGraphOp(Op, HasInnerGraph): class MyInnerGraphOp(Op, HasInnerGraph):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论