提交 54354b10 authored 作者: sentient07's avatar sentient07

Changed normal dict to ordered dict

上级 fc4b41a7
...@@ -5,7 +5,7 @@ amount of useful generic optimization tools. ...@@ -5,7 +5,7 @@ amount of useful generic optimization tools.
""" """
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
from collections import deque from collections import deque, defaultdict
import copy import copy
import inspect import inspect
import logging import logging
...@@ -1264,7 +1264,7 @@ class LocalOptGroup(LocalOptimizer): ...@@ -1264,7 +1264,7 @@ class LocalOptGroup(LocalOptimizer):
for opt in optimizers) for opt in optimizers)
self.apply_all_opts = kwargs.pop('apply_all_opts', False) self.apply_all_opts = kwargs.pop('apply_all_opts', False)
self.track_map = OrderedDict() self.track_map = defaultdict(lambda: [])
assert len(kwargs) == 0 assert len(kwargs) == 0
self.time_opts = {} self.time_opts = {}
self.time_nodes = {} self.time_nodes = {}
...@@ -1281,7 +1281,7 @@ class LocalOptGroup(LocalOptimizer): ...@@ -1281,7 +1281,7 @@ class LocalOptGroup(LocalOptimizer):
self.node_created.setdefault(o, 0) self.node_created.setdefault(o, 0)
for c in o.tracks(): for c in o.tracks():
self.track_map.setdefault(c, []).append(o) self.track_map[c].append(o)
def __str__(self): def __str__(self):
return getattr(self, '__name__', return getattr(self, '__name__',
...@@ -1300,9 +1300,9 @@ class LocalOptGroup(LocalOptimizer): ...@@ -1300,9 +1300,9 @@ class LocalOptGroup(LocalOptimizer):
if len(self.opts) == 0: if len(self.opts) == 0:
return return
def apply_mult_opts(node, fgraph=False, multiple_opts=False): def apply_mult_opts(node, fgraph, multiple_opts=False):
repl = False repl = False
opts = self.track_map.get(type(node.op), []) + self.track_map.get(node.op, []) + self.track_map.get(None, []) opts = self.track_map[type(node.op)] + self.track_map[node.op] + self.track_map[None]
for opt in opts: for opt in opts:
opt_start = time.time() opt_start = time.time()
...@@ -1323,7 +1323,7 @@ class LocalOptGroup(LocalOptimizer): ...@@ -1323,7 +1323,7 @@ class LocalOptGroup(LocalOptimizer):
new_node = repl[0].owner new_node = repl[0].owner
if hasattr(new_node, 'fgraph'): if hasattr(new_node, 'fgraph'):
apply_mult_opts(new_node, new_node.fgraph, True) apply_mult_opts(new_node, new_node.fgraph, True)
apply_mult_opts(new_node, False, True) apply_mult_opts(new_node, fgraph, True)
return repl return repl
node_start = time.time() node_start = time.time()
new_var = apply_mult_opts(node, node.fgraph, self.apply_all_opts) new_var = apply_mult_opts(node, node.fgraph, self.apply_all_opts)
......
...@@ -322,7 +322,7 @@ class SequenceDB(DB): ...@@ -322,7 +322,7 @@ class SequenceDB(DB):
def register(self, name, obj, position, *tags): def register(self, name, obj, position, *tags):
super(SequenceDB, self).register(name, obj, *tags) super(SequenceDB, self).register(name, obj, *tags)
if position == 'last': if position == 'last':
self.position[name] = max(self.position.values()) self.__position__[name] = max(self.__position__.values())
else: else:
assert isinstance(position, (integer_types, float)) assert isinstance(position, (integer_types, float))
self.__position__[name] = position self.__position__[name] = position
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论