提交 8c94ba0c authored 作者: James Bergstra's avatar James Bergstra

added skip_identities code to PatternSub

上级 a7e855b6
...@@ -525,7 +525,8 @@ class PatternSub(LocalOptimizer): ...@@ -525,7 +525,8 @@ class PatternSub(LocalOptimizer):
(scrabble, 'x')) (scrabble, 'x'))
""" """
def __init__(self, in_pattern, out_pattern, allow_multiple_clients = False): def __init__(self, in_pattern, out_pattern, allow_multiple_clients = False,
skip_identities_fn=None):
""" """
Creates a PatternSub that replaces occurrences of Creates a PatternSub that replaces occurrences of
in_pattern by occurrences of out_pattern. in_pattern by occurrences of out_pattern.
...@@ -543,7 +544,12 @@ class PatternSub(LocalOptimizer): ...@@ -543,7 +544,12 @@ class PatternSub(LocalOptimizer):
raise TypeError("The pattern to search for must start with a specific Op instance.") raise TypeError("The pattern to search for must start with a specific Op instance.")
self.__doc__ = self.__class__.__doc__ + "\n\nThis instance does: " + str(self) + "\n" self.__doc__ = self.__class__.__doc__ + "\n\nThis instance does: " + str(self) + "\n"
self.allow_multiple_clients = allow_multiple_clients self.allow_multiple_clients = allow_multiple_clients
self.skip_identities_fn = skip_identities_fn
def skip_identities(self, expr):
if self.skip_identities_fn:
return self.skip_identities_fn(expr)
def op_key(self): def op_key(self):
return self.op return self.op
...@@ -568,13 +574,22 @@ class PatternSub(LocalOptimizer): ...@@ -568,13 +574,22 @@ class PatternSub(LocalOptimizer):
if node.op != self.op: if node.op != self.op:
return False return False
def match(pattern, expr, u, allow_multiple_clients = False): def match(pattern, expr, u, allow_multiple_clients = False):
def retry_with_equiv():
expr_equiv = self.skip_identities(expr)
if expr_equiv is None:
return False
#TODO: Not sure how to handle multiple_clients flag
###print 'retrying match', pattern, expr_equiv
return match(pattern, expr_equiv, u,
allow_multiple_clients=allow_multiple_clients)
if isinstance(pattern, (list, tuple)): if isinstance(pattern, (list, tuple)):
if expr.owner is None: if expr.owner is None:
return False return False
if not (expr.owner.op == pattern[0]) or (not allow_multiple_clients and len(expr.clients) > 1): if not (expr.owner.op == pattern[0]) or (not allow_multiple_clients and len(expr.clients) > 1):
return False return retry_with_equiv()
if len(pattern) - 1 != len(expr.owner.inputs): if len(pattern) - 1 != len(expr.owner.inputs):
return False return retry_with_equiv()
for p, v in zip(pattern[1:], expr.owner.inputs): for p, v in zip(pattern[1:], expr.owner.inputs):
u = match(p, v, u, self.allow_multiple_clients) u = match(p, v, u, self.allow_multiple_clients)
if not u: if not u:
...@@ -588,17 +603,17 @@ class PatternSub(LocalOptimizer): ...@@ -588,17 +603,17 @@ class PatternSub(LocalOptimizer):
if constraint(expr): if constraint(expr):
return match(real_pattern, expr, u, pattern.get('allow_multiple_clients', False)) return match(real_pattern, expr, u, pattern.get('allow_multiple_clients', False))
else: else:
return False return retry_with_equiv()
elif isinstance(pattern, str): elif isinstance(pattern, str):
v = unify.Var(pattern) v = unify.Var(pattern)
if u[v] is not v and u[v] is not expr: if u[v] is not v and u[v] is not expr:
return False return retry_with_equiv()
else: else:
u = u.merge(expr, v) u = u.merge(expr, v)
elif isinstance(pattern, graph.Constant) and isinstance(expr, graph.Constant) and pattern.equals(expr): elif isinstance(pattern, graph.Constant) and isinstance(expr, graph.Constant) and pattern.equals(expr):
return u return u
else: else:
return False return retry_with_equiv()
return u return u
def build(pattern, u): def build(pattern, u):
...@@ -614,6 +629,7 @@ class PatternSub(LocalOptimizer): ...@@ -614,6 +629,7 @@ class PatternSub(LocalOptimizer):
if u: if u:
p = self.out_pattern p = self.out_pattern
new = build(p, u) new = build(p, u)
####print "PatternSub matched:", new
return [new] return [new]
else: else:
return False return False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论