提交 45753edd authored 作者: sentient07's avatar sentient07

Made the suggested changes

上级 74d2d21a
...@@ -153,7 +153,6 @@ class FunctionGraph(utils.object2): ...@@ -153,7 +153,6 @@ class FunctionGraph(utils.object2):
self.inputs = list(inputs) self.inputs = list(inputs)
self.outputs = outputs self.outputs = outputs
self._removed_nodes = set()
for f in features: for f in features:
self.attach_feature(f) self.attach_feature(f)
...@@ -326,7 +325,6 @@ class FunctionGraph(utils.object2): ...@@ -326,7 +325,6 @@ class FunctionGraph(utils.object2):
self.apply_nodes.remove(apply_node) self.apply_nodes.remove(apply_node)
self.variables.difference_update(apply_node.outputs) self.variables.difference_update(apply_node.outputs)
self.execute_callbacks('on_prune', apply_node, reason) self.execute_callbacks('on_prune', apply_node, reason)
self._removed_nodes.add(apply_node)
for i, input in enumerate(apply_node.inputs): for i, input in enumerate(apply_node.inputs):
self.__remove_clients__(input, [(apply_node, i)], self.__remove_clients__(input, [(apply_node, i)],
reason=reason) reason=reason)
...@@ -481,11 +479,6 @@ class FunctionGraph(utils.object2): ...@@ -481,11 +479,6 @@ class FunctionGraph(utils.object2):
for node in new_nodes: for node in new_nodes:
assert node not in self.apply_nodes assert node not in self.apply_nodes
prevent_addition = node in self._removed_nodes
if prevent_addition :
raise InconsistencyError
("Trying to reintroduce an old nodes in the graph. This should not happen")
self.__setup_node__(node) self.__setup_node__(node)
self.apply_nodes.add(node) self.apply_nodes.add(node)
for output in node.outputs: for output in node.outputs:
......
...@@ -266,6 +266,7 @@ class ReplaceValidate(History, Validator): ...@@ -266,6 +266,7 @@ class ReplaceValidate(History, Validator):
pickle_rm_attr = (["replace_validate", "replace_all_validate", pickle_rm_attr = (["replace_validate", "replace_all_validate",
"replace_all_validate_remove"] + "replace_all_validate_remove"] +
History.pickle_rm_attr + Validator.pickle_rm_attr) History.pickle_rm_attr + Validator.pickle_rm_attr)
_nodes_removed = set()
def on_attach(self, fgraph): def on_attach(self, fgraph):
for attr in ('replace_validate', 'replace_all_validate', for attr in ('replace_validate', 'replace_all_validate',
...@@ -347,6 +348,7 @@ class ReplaceValidate(History, Validator): ...@@ -347,6 +348,7 @@ class ReplaceValidate(History, Validator):
""" """
chk = fgraph.replace_all_validate(replacements, reason) chk = fgraph.replace_all_validate(replacements, reason)
self._nodes_removed.update(remove)
for rm in remove: for rm in remove:
if rm in fgraph.apply_nodes or rm in fgraph.variables: if rm in fgraph.apply_nodes or rm in fgraph.variables:
fgraph.revert(chk) fgraph.revert(chk)
...@@ -369,6 +371,9 @@ class ReplaceValidate(History, Validator): ...@@ -369,6 +371,9 @@ class ReplaceValidate(History, Validator):
del d["history"] del d["history"]
return d return d
def on_import(self, fgraph, node, reason):
if node in self._nodes_removed:
raise theano.gof.InconsistencyError("Trying to introduce a removed node")
class NodeFinder(Bookkeeper): class NodeFinder(Bookkeeper):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论