提交 9ed13246 authored 作者: sentient07's avatar sentient07

Applying Multiple optimization to a node

上级 995c9c19
...@@ -1237,7 +1237,7 @@ class LocalOptGroup(LocalOptimizer): ...@@ -1237,7 +1237,7 @@ class LocalOptGroup(LocalOptimizer):
""" """
def __init__(self, *optimizers): def __init__(self, apply_all_opts=False, *optimizers):
if len(optimizers) == 1 and isinstance(optimizers[0], list): if len(optimizers) == 1 and isinstance(optimizers[0], list):
# This happen when created by LocalGroupDB. # This happen when created by LocalGroupDB.
optimizers = tuple(optimizers[0]) optimizers = tuple(optimizers[0])
...@@ -1246,6 +1246,7 @@ class LocalOptGroup(LocalOptimizer): ...@@ -1246,6 +1246,7 @@ class LocalOptGroup(LocalOptimizer):
for opt in optimizers) for opt in optimizers)
self.retains_inputs = all(getattr(opt, 'retains_inputs', False) self.retains_inputs = all(getattr(opt, 'retains_inputs', False)
for opt in optimizers) for opt in optimizers)
self.apply_all_opts = apply_all_opts
def __str__(self): def __str__(self):
return getattr(self, '__name__', return getattr(self, '__name__',
...@@ -1261,11 +1262,17 @@ class LocalOptGroup(LocalOptimizer): ...@@ -1261,11 +1262,17 @@ class LocalOptGroup(LocalOptimizer):
return t return t
def transform(self, node): def transform(self, node):
repl = None
for opt in self.opts: for opt in self.opts:
repl = opt.transform(node) repl = opt.transform(node)
if repl: if repl:
if self.apply_all_opts is True:
node.outputs = repl
continue
return repl return repl
return repl
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
print("%s%s id=%i" % ( print("%s%s id=%i" % (
(' ' * level), self.__class__.__name__, id(self)), file=stream) (' ' * level), self.__class__.__name__, id(self)), file=stream)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论