提交 8c94ba0c authored 作者: James Bergstra's avatar James Bergstra

added skip_identities code to PatternSub

上级 a7e855b6
......@@ -525,7 +525,8 @@ class PatternSub(LocalOptimizer):
(scrabble, 'x'))
"""
def __init__(self, in_pattern, out_pattern, allow_multiple_clients = False):
def __init__(self, in_pattern, out_pattern, allow_multiple_clients = False,
skip_identities_fn=None):
"""
Creates a PatternSub that replaces occurrences of
in_pattern by occurrences of out_pattern.
......@@ -543,7 +544,12 @@ class PatternSub(LocalOptimizer):
raise TypeError("The pattern to search for must start with a specific Op instance.")
self.__doc__ = self.__class__.__doc__ + "\n\nThis instance does: " + str(self) + "\n"
self.allow_multiple_clients = allow_multiple_clients
self.skip_identities_fn = skip_identities_fn
def skip_identities(self, expr):
if self.skip_identities_fn:
return self.skip_identities_fn(expr)
def op_key(self):
return self.op
......@@ -568,13 +574,22 @@ class PatternSub(LocalOptimizer):
if node.op != self.op:
return False
def match(pattern, expr, u, allow_multiple_clients = False):
def retry_with_equiv():
expr_equiv = self.skip_identities(expr)
if expr_equiv is None:
return False
#TODO: Not sure how to handle multiple_clients flag
###print 'retrying match', pattern, expr_equiv
return match(pattern, expr_equiv, u,
allow_multiple_clients=allow_multiple_clients)
if isinstance(pattern, (list, tuple)):
if expr.owner is None:
return False
if not (expr.owner.op == pattern[0]) or (not allow_multiple_clients and len(expr.clients) > 1):
return False
return retry_with_equiv()
if len(pattern) - 1 != len(expr.owner.inputs):
return False
return retry_with_equiv()
for p, v in zip(pattern[1:], expr.owner.inputs):
u = match(p, v, u, self.allow_multiple_clients)
if not u:
......@@ -588,17 +603,17 @@ class PatternSub(LocalOptimizer):
if constraint(expr):
return match(real_pattern, expr, u, pattern.get('allow_multiple_clients', False))
else:
return False
return retry_with_equiv()
elif isinstance(pattern, str):
v = unify.Var(pattern)
if u[v] is not v and u[v] is not expr:
return False
return retry_with_equiv()
else:
u = u.merge(expr, v)
elif isinstance(pattern, graph.Constant) and isinstance(expr, graph.Constant) and pattern.equals(expr):
return u
else:
return False
return retry_with_equiv()
return u
def build(pattern, u):
......@@ -614,6 +629,7 @@ class PatternSub(LocalOptimizer):
if u:
p = self.out_pattern
new = build(p, u)
####print "PatternSub matched:", new
return [new]
else:
return False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论