提交 7a6c444d authored 作者: Olivier Breuleux's avatar Olivier Breuleux

more tests for PatternOptimizer, fixed bug

上级 2514d77e
......@@ -201,6 +201,44 @@ class _test_PatternOptimizer(unittest.TestCase):
'constraint': constraint}),
(Op3, '1')).optimize(g)
assert str(g) == "[Op4(Op3(Op2(x, y)), Op1(Op1(x, y)))]"
def test_match_same(self):
x, y, z = inputs()
e = op1(x, x)
g = env([x, y, z], [e])
PatternOptimizer((Op1, 'x', 'y'),
(Op3, 'x', 'y')).optimize(g)
assert str(g) == "[Op3(x, x)]"
def test_match_same_illegal(self):
x, y, z = inputs()
e = op2(op1(x, x), op1(x, y))
g = env([x, y, z], [e])
def constraint(env, r):
# Only replacing if the input is an instance of Op2
return r.owner.inputs[0] is not r.owner.inputs[1]
PatternOptimizer({'pattern': (Op1, 'x', 'y'),
'constraint': constraint},
(Op3, 'x', 'y')).optimize(g)
assert str(g) == "[Op2(Op1(x, x), Op3(x, y))]"
def test_multi(self):
x, y, z = inputs()
e0 = op1(x, y)
e = op3(op4(e0), e0)
g = env([x, y, z], [e])
PatternOptimizer((Op4, (Op1, 'x', 'y')),
(Op3, 'x', 'y')).optimize(g)
assert str(g) == "[Op3(Op4(*1 -> Op1(x, y)), *1)]"
def test_multi_ingraph(self):
x, y, z = inputs()
e0 = op1(x, y)
e = op4(e0, e0)
g = env([x, y, z], [e])
PatternOptimizer((Op4, (Op1, 'x', 'y'), (Op1, 'x', 'y')),
(Op3, 'x', 'y')).optimize(g)
assert str(g) == "[Op3(x, y)]"
class _test_PatternDescOptimizer(unittest.TestCase):
......
......@@ -284,7 +284,7 @@ class PatternOptimizer(OpSpecificOptimizer):
"""
def match(pattern, expr, u, first = False):
if isinstance(pattern, (list, tuple)):
if not issubclass(expr.owner.__class__, pattern[0]) or (self.allow_multiple_clients and not first and env.nclients(expr.owner) > 1):
if not issubclass(expr.owner.__class__, pattern[0]) or (not self.allow_multiple_clients and not first and env.nclients(expr) > 1):
return False
if len(pattern) - 1 != len(expr.owner.inputs):
return False
......@@ -393,7 +393,7 @@ class PatternDescOptimizer(LocalOptimizer):
"""
def match(pattern, expr, u, first = False):
if isinstance(pattern, (list, tuple)):
if not expr.owner.desc() == pattern[0] or (self.allow_multiple_clients and not first and env.nclients(expr.owner) > 1):
if not expr.owner or not expr.owner.desc() == pattern[0] or (self.allow_multiple_clients and not first and env.nclients(expr.owner) > 1):
return False
if len(pattern) - 1 != len(expr.owner.inputs):
return False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论