提交 fff84f9f authored 作者: Frederic's avatar Frederic

Make NodeFinder hashable.

This is needed for the new profiling of features.
上级 0b528507
......@@ -248,10 +248,11 @@ class ReplaceValidate(History, Validator):
raise ReplacementDidntRemovedError()
class NodeFinder(dict, Bookkeeper):
class NodeFinder(Bookkeeper):
def __init__(self):
self.fgraph = None
self.d = {}
def on_attach(self, fgraph):
if self.fgraph is not None:
......@@ -273,7 +274,7 @@ class NodeFinder(dict, Bookkeeper):
def on_import(self, fgraph, node, reason):
try:
self.setdefault(node.op, []).append(node)
self.d.setdefault(node.op, []).append(node)
except TypeError: # node.op is unhashable
return
except Exception, e:
......@@ -286,16 +287,16 @@ class NodeFinder(dict, Bookkeeper):
def on_prune(self, fgraph, node, reason):
try:
nodes = self[node.op]
nodes = self.d[node.op]
except TypeError: # node.op is unhashable
return
nodes.remove(node)
if not nodes:
del self[node.op]
del self.d[node.op]
def query(self, fgraph, op):
try:
all = self.get(op, [])
all = self.d.get(op, [])
except TypeError:
raise TypeError("%s in unhashable and cannot be queried by the"
" optimizer" % op)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论