提交 b65faccd authored 作者: Frederic Bastien's avatar Frederic Bastien

add a pdb parameter to PatternSub to ease debugging.

上级 e3f79f19
...@@ -558,13 +558,17 @@ class PatternSub(LocalOptimizer): ...@@ -558,13 +558,17 @@ class PatternSub(LocalOptimizer):
""" """
def __init__(self, in_pattern, out_pattern, allow_multiple_clients = False, def __init__(self, in_pattern, out_pattern, allow_multiple_clients = False,
skip_identities_fn = None, name = None): skip_identities_fn = None, name = None, pdb = False):
""" """
Creates a PatternSub that replaces occurrences of Creates a PatternSub that replaces occurrences of
in_pattern by occurrences of out_pattern. in_pattern by occurrences of out_pattern.
If allow_multiple_clients is False, the pattern matching will :param in_pattern: the input pattern that we want to replace
fail if one of the subpatterns has more than one client. :param out_pattern: the replacement pattern
:param allow_multiple_clients: if False, the pattern matching will fail
if one of the subpatterns has more than
one client.
:param pdb: if True, we invoke pdb when the first node in the pattern match.
""" """
self.in_pattern = in_pattern self.in_pattern = in_pattern
self.out_pattern = out_pattern self.out_pattern = out_pattern
...@@ -579,6 +583,7 @@ class PatternSub(LocalOptimizer): ...@@ -579,6 +583,7 @@ class PatternSub(LocalOptimizer):
self.skip_identities_fn = skip_identities_fn self.skip_identities_fn = skip_identities_fn
if name: if name:
self.__name__ = name self.__name__ = name
self.pdb = pdb
def skip_identities(self, expr): def skip_identities(self, expr):
if self.skip_identities_fn: if self.skip_identities_fn:
...@@ -606,7 +611,8 @@ class PatternSub(LocalOptimizer): ...@@ -606,7 +611,8 @@ class PatternSub(LocalOptimizer):
""" """
if node.op != self.op: if node.op != self.op:
return False return False
def match(pattern, expr, u, allow_multiple_clients = False):
def match(pattern, expr, u, allow_multiple_clients = False, pdb = False):
def retry_with_equiv(): def retry_with_equiv():
expr_equiv = self.skip_identities(expr) expr_equiv = self.skip_identities(expr)
if expr_equiv is None: if expr_equiv is None:
...@@ -652,6 +658,8 @@ class PatternSub(LocalOptimizer): ...@@ -652,6 +658,8 @@ class PatternSub(LocalOptimizer):
return u return u
else: else:
return retry_with_equiv() return retry_with_equiv()
if pdb:
import pdb;pdb.set_trace()
return u return u
def build(pattern, u): def build(pattern, u):
...@@ -664,8 +672,7 @@ class PatternSub(LocalOptimizer): ...@@ -664,8 +672,7 @@ class PatternSub(LocalOptimizer):
return pattern return pattern
else: else:
return pattern.clone() return pattern.clone()
u = match(self.in_pattern, node.out, unify.Unification(), True, self.pdb)
u = match(self.in_pattern, node.out, unify.Unification(), True)
if u: if u:
p = self.out_pattern p = self.out_pattern
new = build(p, u) new = build(p, u)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论