提交 c158857e authored 作者: Frederic Bastien's avatar Frederic Bastien

Speed up add to fgraph when there is many clients to nodes

上级 8b008d5e
...@@ -260,7 +260,7 @@ class FunctionGraph(utils.object2): ...@@ -260,7 +260,7 @@ class FunctionGraph(utils.object2):
""" """
return r.clients return r.clients
def __add_clients__(self, r, new_clients): def __add_client__(self, r, new_client, check=True):
""" """
Updates the list of clients of r with new_clients. Updates the list of clients of r with new_clients.
...@@ -268,18 +268,19 @@ class FunctionGraph(utils.object2): ...@@ -268,18 +268,19 @@ class FunctionGraph(utils.object2):
---------- ----------
r r
Variable. Variable.
new_clients new_client
List of (node, i) pairs such that node.inputs[i] is r. (node, i) pairs such that node.inputs[i] is r.
""" """
if set(r.clients).intersection(set(new_clients)): test = new_client in r.clients
if test:
print('ERROR: clients intersect!', file=sys.stderr) print('ERROR: clients intersect!', file=sys.stderr)
print(' RCLIENTS of', r, [(n, i, type(n), id(n)) print(' RCLIENTS of', r, [(n, i, type(n), id(n))
for n, i in r.clients], file=sys.stderr) for n, i in r.clients], file=sys.stderr)
print(' NCLIENTS of', r, [(n, i, type(n), id(n)) print(' NCLIENTS of', r, [(n, i, type(n), id(n))
for n, i in new_clients], file=sys.stderr) for n, i in [new_client]], file=sys.stderr)
assert not set(r.clients).intersection(set(new_clients)) assert not test
r.clients += new_clients r.clients.append(new_client)
def __remove_clients__(self, r, clients_to_remove, def __remove_clients__(self, r, clients_to_remove,
prune=True, reason=None): prune=True, reason=None):
...@@ -431,7 +432,7 @@ class FunctionGraph(utils.object2): ...@@ -431,7 +432,7 @@ class FunctionGraph(utils.object2):
if input not in self.variables: if input not in self.variables:
self.__setup_r__(input) self.__setup_r__(input)
self.variables.add(input) self.variables.add(input)
self.__add_clients__(input, [(node, i)]) self.__add_client__(input, (node, i))
assert node.fgraph is self assert node.fgraph is self
self.execute_callbacks('on_import', node, reason) self.execute_callbacks('on_import', node, reason)
...@@ -470,7 +471,7 @@ class FunctionGraph(utils.object2): ...@@ -470,7 +471,7 @@ class FunctionGraph(utils.object2):
return return
self.__import_r__(new_r, reason=reason) self.__import_r__(new_r, reason=reason)
self.__add_clients__(new_r, [(node, i)]) self.__add_client__(new_r, (node, i))
prune = self.__remove_clients__(r, [(node, i)], False) prune = self.__remove_clients__(r, [(node, i)], False)
# Precondition: the substitution is semantically valid # Precondition: the substitution is semantically valid
# However it may introduce cycles to the graph, in which case the # However it may introduce cycles to the graph, in which case the
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论