提交 9625e083 authored 作者: sentient07's avatar sentient07

added validate method to ReplaceValidate

上级 45753edd
...@@ -320,11 +320,10 @@ class FunctionGraph(utils.object2): ...@@ -320,11 +320,10 @@ class FunctionGraph(utils.object2):
if output.clients or output in self.outputs] if output.clients or output in self.outputs]
# If the apply node is not used and is not an output # If the apply node is not used and is not an output
if not used_or_output: if not used_or_output:
if apply_node in self.apply_nodes:
#keeping track of removed apply node
self.apply_nodes.remove(apply_node) self.apply_nodes.remove(apply_node)
self.variables.difference_update(apply_node.outputs) self.variables.difference_update(apply_node.outputs)
self.execute_callbacks('on_prune', apply_node, reason) self.execute_callbacks('on_prune', apply_node, reason)
for i, input in enumerate(apply_node.inputs): for i, input in enumerate(apply_node.inputs):
self.__remove_clients__(input, [(apply_node, i)], self.__remove_clients__(input, [(apply_node, i)],
reason=reason) reason=reason)
......
...@@ -267,6 +267,7 @@ class ReplaceValidate(History, Validator): ...@@ -267,6 +267,7 @@ class ReplaceValidate(History, Validator):
"replace_all_validate_remove"] + "replace_all_validate_remove"] +
History.pickle_rm_attr + Validator.pickle_rm_attr) History.pickle_rm_attr + Validator.pickle_rm_attr)
_nodes_removed = set() _nodes_removed = set()
fail_validate = False
def on_attach(self, fgraph): def on_attach(self, fgraph):
for attr in ('replace_validate', 'replace_all_validate', for attr in ('replace_validate', 'replace_all_validate',
...@@ -373,7 +374,15 @@ class ReplaceValidate(History, Validator): ...@@ -373,7 +374,15 @@ class ReplaceValidate(History, Validator):
def on_import(self, fgraph, node, reason): def on_import(self, fgraph, node, reason):
if node in self._nodes_removed: if node in self._nodes_removed:
raise theano.gof.InconsistencyError("Trying to introduce a removed node") self.fail_validate = True
def validate(self, fgraph):
if not hasattr(fgraph, 'destroyers'):
return True
if self.fail_validate:
self.fail_validate = False
raise theano.gof.InconsistencyError("Trying to reintroduce a removed node")
class NodeFinder(Bookkeeper): class NodeFinder(Bookkeeper):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论