提交 bc6770cc authored 作者: Ramana.S's avatar Ramana.S

initial commit

上级 9627228c
...@@ -153,6 +153,7 @@ class FunctionGraph(utils.object2): ...@@ -153,6 +153,7 @@ class FunctionGraph(utils.object2):
self.inputs = list(inputs) self.inputs = list(inputs)
self.outputs = outputs self.outputs = outputs
self._removed_nodes = set()
for f in features: for f in features:
self.attach_feature(f) self.attach_feature(f)
...@@ -320,13 +321,15 @@ class FunctionGraph(utils.object2): ...@@ -320,13 +321,15 @@ class FunctionGraph(utils.object2):
if output.clients or output in self.outputs] if output.clients or output in self.outputs]
# If the apply node is not used and is not an output # If the apply node is not used and is not an output
if not used_or_output: if not used_or_output:
self.apply_nodes.remove(apply_node) if apply_node in self.apply_nodes:
self.variables.difference_update(apply_node.outputs) #keeping track of removed apply node
self.execute_callbacks('on_prune', apply_node, reason) self.apply_nodes.remove(apply_node)
self.variables.difference_update(apply_node.outputs)
for i, input in enumerate(apply_node.inputs): self.execute_callbacks('on_prune', apply_node, reason)
self.__remove_clients__(input, [(apply_node, i)], self._removed_nodes.add(apply_node)
reason=reason) for i, input in enumerate(apply_node.inputs):
self.__remove_clients__(input, [(apply_node, i)],
reason=reason)
# variable should not have any clients. # variable should not have any clients.
# assert not variable.clients # assert not variable.clients
...@@ -478,8 +481,14 @@ class FunctionGraph(utils.object2): ...@@ -478,8 +481,14 @@ class FunctionGraph(utils.object2):
for node in new_nodes: for node in new_nodes:
assert node not in self.apply_nodes assert node not in self.apply_nodes
self.__setup_node__(node) prevent_addition = False
self.apply_nodes.add(node) for n in self._removed_nodes :
if node is n :
prevent_addition = True
if not prevent_addition :
self.__setup_node__(node)
self.apply_nodes.add(node)
for output in node.outputs: for output in node.outputs:
self.__setup_r__(output) self.__setup_r__(output)
self.variables.add(output) self.variables.add(output)
...@@ -488,8 +497,9 @@ class FunctionGraph(utils.object2): ...@@ -488,8 +497,9 @@ class FunctionGraph(utils.object2):
self.__setup_r__(input) self.__setup_r__(input)
self.variables.add(input) self.variables.add(input)
self.__add_clients__(input, [(node, i)]) self.__add_clients__(input, [(node, i)])
assert node.fgraph is self if not prevent_addition :
self.execute_callbacks('on_import', node, reason) assert node.fgraph is self
self.execute_callbacks('on_import', node, reason)
# change input # # change input #
def change_input(self, node, i, new_r, reason=None): def change_input(self, node, i, new_r, reason=None):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论