提交 7db04c9c authored 作者: Ricardo's avatar Ricardo 提交者: Ricardo Vieira

Do not allow `PatternSub` to return node with different type

上级 4a6a10b1
......@@ -1793,6 +1793,17 @@ class PatternSub(LocalOptimizer):
if self.values_eq_approx:
ret.tag.values_eq_approx = self.values_eq_approx
if ret.owner:
if [out.type for out in ret.owner.outputs] != [
out.type for out in node.outputs
]:
return False
else:
# ret is just an input variable
assert len(node.outputs) == 1
if ret.type != node.outputs[0].type:
return False
return [ret]
def __str__(self):
......
......@@ -31,6 +31,7 @@ from tests.graph.utils import (
op4,
op5,
op6,
op_cast_type2,
op_y,
op_z,
)
......@@ -677,3 +678,23 @@ def test_patternsub_values_eq_approx(out_pattern, tracks):
else:
assert isinstance(output, Constant)
assert not hasattr(output.tag, "value_eq_approx")
@pytest.mark.parametrize("out_pattern", [(op1, "x"), "x"])
def test_patternsub_invalid_dtype(out_pattern):
# PatternSub would wrongly return output of different dtype as the original node
x = MyVariable("x")
e = op_cast_type2(x)
fg = FunctionGraph([x], [e])
opt = EquilibriumOptimizer(
[
PatternSub(
(op_cast_type2, "x"),
out_pattern,
)
],
max_use_ratio=1,
)
opt.optimize(fg)
assert fg.apply_nodes.pop().op == op_cast_type2
......@@ -85,6 +85,17 @@ class MyOp(Op):
return id(self)
class MyOpCastType2(MyOp):
def make_node(self, *inputs):
inputs = list(map(is_variable, inputs))
for input in inputs:
if not isinstance(input.type, MyType):
raise Exception("Error 1")
outputs = [MyType2()()]
return Apply(self, inputs, outputs)
op1 = MyOp("Op1")
op2 = MyOp("Op2")
op3 = MyOp("Op3")
......@@ -95,3 +106,5 @@ op_d = MyOp("OpD", {0: [0]})
op_y = MyOp("OpY", x=1)
op_z = MyOp("OpZ", x=1)
op_cast_type2 = MyOpCastType2("OpCastType2")
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论