提交 085b71c8 authored 作者: Frederic Bastien's avatar Frederic Bastien

Convert recursion to while loop. Simplify and rename the code at the same time.

上级 a108477d
...@@ -1299,37 +1299,33 @@ class LocalOptGroup(LocalOptimizer): ...@@ -1299,37 +1299,33 @@ class LocalOptGroup(LocalOptimizer):
def transform(self, node): def transform(self, node):
if len(self.opts) == 0: if len(self.opts) == 0:
return return
fgraph = node.fgraph
def apply_mult_opts(node, fgraph, multiple_opts=False, prev_repl=None): repl = None
while True:
opts = self.track_map[type(node.op)] + self.track_map[node.op] + self.track_map[None] opts = self.track_map[type(node.op)] + self.track_map[node.op] + self.track_map[None]
repl = prev_repl
new_repl = None new_repl = None
for opt in opts: for opt in opts:
opt_start = time.time() opt_start = time.time()
repl = opt.transform(node) new_repl = opt.transform(node)
opt_finish = time.time() opt_finish = time.time()
if self.profile: if self.profile:
self.time_opts[opt] += opt_start - opt_finish self.time_opts[opt] += opt_start - opt_finish
self.process_count[opt] += 1 self.process_count[opt] += 1
if not repl: if not new_repl:
continue continue
else: else:
assert len(repl) == 1 assert len(new_repl) == 1
if self.profile: if self.profile:
self.node_created[opt] += len(graph.ops(fgraph.variables, repl)) self.node_created[opt] += len(graph.ops(fgraph.variables, new_repl))
self.applied_true[opt] += 1 self.applied_true[opt] += 1
if not multiple_opts or not repl[0].owner: break # break from the for loop over optimization.
return repl if not new_repl: # No optimization applied in the last iteration
# Ensuring not the input of graph return repl
new_node = repl[0].owner # only 1 iteration or we are at the start of the graph.
new_repl = apply_mult_opts(new_node, fgraph, True, repl) if not self.apply_all_opts or not new_repl[0].owner:
if new_repl: return new_repl
repl = new_repl repl = new_repl
return repl node = repl[0].owner
new_var = apply_mult_opts(node, node.fgraph, self.apply_all_opts)
return new_var
@staticmethod @staticmethod
def print_profile(stream, prof, level=0): def print_profile(stream, prof, level=0):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论