提交 4ed05072 authored 作者: Frederic Bastien's avatar Frederic Bastien 提交者: Adam Becker

Allow useless opt to return a dict.

上级 60a88f29
...@@ -1366,21 +1366,27 @@ class LocalOptGroup(LocalOptimizer): ...@@ -1366,21 +1366,27 @@ class LocalOptGroup(LocalOptimizer):
self.process_count[opt] += 1 self.process_count[opt] += 1
if not new_repl: if not new_repl:
continue continue
else: if isinstance(new_repl, (tuple, list)):
new_vars = new_repl
else: # It must be a dict
new_vars = new_repl.values()
if self.profile: if self.profile:
self.node_created[opt] += len(graph.ops(fgraph.variables, new_repl)) self.node_created[opt] += len(graph.ops(fgraph.variables, new_vars))
self.applied_true[opt] += 1 self.applied_true[opt] += 1
break # break from the for loop over optimization. break # break from the for loop over optimization.
if not new_repl: # No optimization applied in the last iteration if not new_repl: # No optimization applied in the last iteration
return repl return repl
# only 1 iteration or we are at the start of the graph. # only 1 iteration
if not self.apply_all_opts or not new_repl[0].owner: if not self.apply_all_opts:
return new_repl
if not new_vars[0].owner:
# We are at the start of the graph.
return new_repl return new_repl
if len(new_repl) > 1: if len(new_repl) > 1:
s = set([v.owner for v in new_repl]) s = set([v.owner for v in new_repl])
assert len(s) == 1 assert len(s) == 1
repl = new_repl repl = new_repl
node = repl[0].owner node = new_vars[0].owner
@staticmethod @staticmethod
def print_profile(stream, prof, level=0): def print_profile(stream, prof, level=0):
......
...@@ -7583,6 +7583,5 @@ def local_useless_topk(node): ...@@ -7583,6 +7583,5 @@ def local_useless_topk(node):
axis=op.axis, axis=op.axis,
idx_dtype=op.idx_dtype, idx_dtype=op.idx_dtype,
return_values=ret_val, return_values=ret_val,
return_indices=ret_idx)(x, k)[0] return_indices=ret_idx)(x, k)
return {old_output:new_output} return {old_output:new_output}
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论