提交 0cf98513 authored 作者: sentient07's avatar sentient07

Fixed a bug in multiple optimization

上级 aab78e7e
...@@ -1292,9 +1292,9 @@ class LocalOptGroup(LocalOptimizer): ...@@ -1292,9 +1292,9 @@ class LocalOptGroup(LocalOptimizer):
if len(self.opts) == 0: if len(self.opts) == 0:
return return
def apply_mult_opts(node, fgraph, multiple_opts=False): def apply_mult_opts(node, fgraph, multiple_opts=False, prev_repl=None):
repl = False
opts = self.track_map[type(node.op)] + self.track_map[node.op] + self.track_map[None] opts = self.track_map[type(node.op)] + self.track_map[node.op] + self.track_map[None]
repl = prev_repl
for opt in opts: for opt in opts:
opt_start = time.time() opt_start = time.time()
repl = opt.transform(node) repl = opt.transform(node)
...@@ -1312,8 +1312,9 @@ class LocalOptGroup(LocalOptimizer): ...@@ -1312,8 +1312,9 @@ class LocalOptGroup(LocalOptimizer):
# Ensuring not the input of graph # Ensuring not the input of graph
assert repl[0].owner assert repl[0].owner
new_node = repl[0].owner new_node = repl[0].owner
apply_mult_opts(new_node, fgraph, True) apply_mult_opts(new_node, fgraph, True, repl)
return repl return repl
node_start = time.time() node_start = time.time()
new_var = apply_mult_opts(node, node.fgraph, self.apply_all_opts) new_var = apply_mult_opts(node, node.fgraph, self.apply_all_opts)
node_finish = time.time() node_finish = time.time()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论