提交 c158857e authored 作者: Frederic Bastien's avatar Frederic Bastien

Speed up add to fgraph when there is many clients to nodes

上级 8b008d5e
......@@ -260,7 +260,7 @@ class FunctionGraph(utils.object2):
"""
return r.clients
def __add_clients__(self, r, new_clients):
def __add_client__(self, r, new_client, check=True):
"""
Updates the list of clients of r with new_clients.
......@@ -268,18 +268,19 @@ class FunctionGraph(utils.object2):
----------
r
Variable.
new_clients
List of (node, i) pairs such that node.inputs[i] is r.
new_client
(node, i) pairs such that node.inputs[i] is r.
"""
if set(r.clients).intersection(set(new_clients)):
test = new_client in r.clients
if test:
print('ERROR: clients intersect!', file=sys.stderr)
print(' RCLIENTS of', r, [(n, i, type(n), id(n))
for n, i in r.clients], file=sys.stderr)
print(' NCLIENTS of', r, [(n, i, type(n), id(n))
for n, i in new_clients], file=sys.stderr)
assert not set(r.clients).intersection(set(new_clients))
r.clients += new_clients
for n, i in [new_client]], file=sys.stderr)
assert not test
r.clients.append(new_client)
def __remove_clients__(self, r, clients_to_remove,
prune=True, reason=None):
......@@ -431,7 +432,7 @@ class FunctionGraph(utils.object2):
if input not in self.variables:
self.__setup_r__(input)
self.variables.add(input)
self.__add_clients__(input, [(node, i)])
self.__add_client__(input, (node, i))
assert node.fgraph is self
self.execute_callbacks('on_import', node, reason)
......@@ -470,7 +471,7 @@ class FunctionGraph(utils.object2):
return
self.__import_r__(new_r, reason=reason)
self.__add_clients__(new_r, [(node, i)])
self.__add_client__(new_r, (node, i))
prune = self.__remove_clients__(r, [(node, i)], False)
# Precondition: the substitution is semantically valid
# However it may introduce cycles to the graph, in which case the
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论