提交 5d8ceb75 authored 作者: Frederic's avatar Frederic

Make PatternSub simpler.

上级 e532ac67
...@@ -1047,10 +1047,6 @@ class PatternSub(LocalOptimizer): ...@@ -1047,10 +1047,6 @@ class PatternSub(LocalOptimizer):
self.__name__ = name self.__name__ = name
self.pdb = pdb self.pdb = pdb
def skip_identities(self, expr):
if self.skip_identities_fn:
return self.skip_identities_fn(expr)
def op_key(self): def op_key(self):
return self.op return self.op
...@@ -1064,10 +1060,13 @@ class PatternSub(LocalOptimizer): ...@@ -1064,10 +1060,13 @@ class PatternSub(LocalOptimizer):
""" """
if node.op != self.op: if node.op != self.op:
return False return False
#TODO: if we remove pdb, do this speed things up?
def match(pattern, expr, u, allow_multiple_clients=False, pdb=False): def match(pattern, expr, u, allow_multiple_clients=False, pdb=False):
#TODO move outside match
def retry_with_equiv(): def retry_with_equiv():
expr_equiv = self.skip_identities(expr) if not self.skip_identities_fn:
return False
expr_equiv = self.skip_identities_fn(expr)
if expr_equiv is None: if expr_equiv is None:
return False return False
#TODO: Not sure how to handle multiple_clients flag #TODO: Not sure how to handle multiple_clients flag
...@@ -1126,19 +1125,19 @@ class PatternSub(LocalOptimizer): ...@@ -1126,19 +1125,19 @@ class PatternSub(LocalOptimizer):
pdb.set_trace() pdb.set_trace()
return u return u
def build(pattern, u):
if isinstance(pattern, (list, tuple)):
args = [build(p, u) for p in pattern[1:]]
return pattern[0](*args)
elif isinstance(pattern, basestring):
return u[unify.Var(pattern)]
elif isinstance(pattern, (int, float)):
return pattern
else:
return pattern.clone()
u = match(self.in_pattern, node.out, unify.Unification(), True, u = match(self.in_pattern, node.out, unify.Unification(), True,
self.pdb) self.pdb)
if u: if u:
def build(pattern, u):
if isinstance(pattern, (list, tuple)):
args = [build(p, u) for p in pattern[1:]]
return pattern[0](*args)
elif isinstance(pattern, basestring):
return u[unify.Var(pattern)]
elif isinstance(pattern, (int, float)):
return pattern
else:
return pattern.clone()
p = self.out_pattern p = self.out_pattern
new = build(p, u) new = build(p, u)
####print "PatternSub matched:", new ####print "PatternSub matched:", new
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论