Commit bb9042a7 authored by Frederic Bastien

Fix pep8 and simplify the code as we don't support one option.

Parent commit: 96f7cdf1
...@@ -1385,8 +1385,14 @@ class LocalOptGroup(LocalOptimizer): ...@@ -1385,8 +1385,14 @@ class LocalOptGroup(LocalOptimizer):
class GraphToGPULocalOptGroup(LocalOptGroup): class GraphToGPULocalOptGroup(LocalOptGroup):
""" """This is the equivalent of LocalOptGroup for GraphToGPU.
This is the equivalent of LocalOptGroup for GraphToGPU
The main different is the function signature of the local
optimizer that use the GraphToGPU signature and not the normal
LocalOptimizer signature.
apply_all_opts=True is not supported
""" """
def __init__(self, *optimizers, **kwargs): def __init__(self, *optimizers, **kwargs):
super(GraphToGPULocalOptGroup, self).__init__(*optimizers, **kwargs) super(GraphToGPULocalOptGroup, self).__init__(*optimizers, **kwargs)
...@@ -1396,34 +1402,21 @@ class GraphToGPULocalOptGroup(LocalOptGroup): ...@@ -1396,34 +1402,21 @@ class GraphToGPULocalOptGroup(LocalOptGroup):
if len(self.opts) == 0: if len(self.opts) == 0:
return return
fgraph = outputs[0].fgraph fgraph = outputs[0].fgraph
repl = None opts = self.track_map[type(op)] + self.track_map[op] + self.track_map[None]
while True: for opt in opts:
opts = self.track_map[type(op)] + self.track_map[op] + self.track_map[None] opt_start = time.time()
new_repl = None new_repl = opt.transform(op, context_name, inputs, outputs)
for opt in opts: opt_finish = time.time()
opt_start = time.time() if self.profile:
new_repl = opt.transform(op, context_name, inputs, outputs) self.time_opts[opt] += opt_start - opt_finish
opt_finish = time.time() self.process_count[opt] += 1
if self.profile: if not new_repl:
self.time_opts[opt] += opt_start - opt_finish continue
self.process_count[opt] += 1 if self.profile:
if not new_repl: self.node_created[opt] += len(graph.ops(fgraph.variables, new_repl))
continue self.applied_true[opt] += 1
else:
if self.profile: return new_repl
self.node_created[opt] += len(graph.ops(fgraph.variables, new_repl))
self.applied_true[opt] += 1
break # break from the for loop over optimization.
if not new_repl: # No optimization applied in the last iteration
return repl
# only 1 iteration or we are at the start of the graph.
if not self.apply_all_opts or not new_repl[0].owner:
return new_repl
if len(new_repl) > 1:
s = set([v.owner for v in new_repl])
assert len(s) == 1
repl = new_repl
node = repl[0].owner
class OpSub(LocalOptimizer): class OpSub(LocalOptimizer):
......
Markdown 格式
0%
您将 0 人添加到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论