提交 fff84f9f authored 作者: Frederic's avatar Frederic

Make NodeFinder hashable.

This is needed for the new profiling of features.
上级 0b528507
...@@ -248,10 +248,11 @@ class ReplaceValidate(History, Validator): ...@@ -248,10 +248,11 @@ class ReplaceValidate(History, Validator):
raise ReplacementDidntRemovedError() raise ReplacementDidntRemovedError()
class NodeFinder(dict, Bookkeeper): class NodeFinder(Bookkeeper):
def __init__(self): def __init__(self):
self.fgraph = None self.fgraph = None
self.d = {}
def on_attach(self, fgraph): def on_attach(self, fgraph):
if self.fgraph is not None: if self.fgraph is not None:
...@@ -273,7 +274,7 @@ class NodeFinder(dict, Bookkeeper): ...@@ -273,7 +274,7 @@ class NodeFinder(dict, Bookkeeper):
def on_import(self, fgraph, node, reason): def on_import(self, fgraph, node, reason):
try: try:
self.setdefault(node.op, []).append(node) self.d.setdefault(node.op, []).append(node)
except TypeError: # node.op is unhashable except TypeError: # node.op is unhashable
return return
except Exception, e: except Exception, e:
...@@ -286,16 +287,16 @@ class NodeFinder(dict, Bookkeeper): ...@@ -286,16 +287,16 @@ class NodeFinder(dict, Bookkeeper):
def on_prune(self, fgraph, node, reason): def on_prune(self, fgraph, node, reason):
try: try:
nodes = self[node.op] nodes = self.d[node.op]
except TypeError: # node.op is unhashable except TypeError: # node.op is unhashable
return return
nodes.remove(node) nodes.remove(node)
if not nodes: if not nodes:
del self[node.op] del self.d[node.op]
def query(self, fgraph, op): def query(self, fgraph, op):
try: try:
all = self.get(op, []) all = self.d.get(op, [])
except TypeError: except TypeError:
raise TypeError("%s in unhashable and cannot be queried by the" raise TypeError("%s in unhashable and cannot be queried by the"
" optimizer" % op) " optimizer" % op)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论