提交 9625e083 authored 作者: sentient07's avatar sentient07

added validate method to ReplaceValidate

上级 45753edd
......@@ -320,14 +320,13 @@ class FunctionGraph(utils.object2):
if output.clients or output in self.outputs]
# If the apply node is not used and is not an output
if not used_or_output:
if apply_node in self.apply_nodes:
#keeping track of removed apply node
self.apply_nodes.remove(apply_node)
self.variables.difference_update(apply_node.outputs)
self.execute_callbacks('on_prune', apply_node, reason)
for i, input in enumerate(apply_node.inputs):
self.__remove_clients__(input, [(apply_node, i)],
reason=reason)
self.apply_nodes.remove(apply_node)
self.variables.difference_update(apply_node.outputs)
self.execute_callbacks('on_prune', apply_node, reason)
for i, input in enumerate(apply_node.inputs):
self.__remove_clients__(input, [(apply_node, i)],
reason=reason)
# variable should not have any clients.
# assert not variable.clients
......
......@@ -267,6 +267,7 @@ class ReplaceValidate(History, Validator):
"replace_all_validate_remove"] +
History.pickle_rm_attr + Validator.pickle_rm_attr)
_nodes_removed = set()
fail_validate = False
def on_attach(self, fgraph):
for attr in ('replace_validate', 'replace_all_validate',
......@@ -373,7 +374,15 @@ class ReplaceValidate(History, Validator):
def on_import(self, fgraph, node, reason):
if node in self._nodes_removed:
raise theano.gof.InconsistencyError("Trying to introduce a removed node")
self.fail_validate = True
def validate(self, fgraph):
if not hasattr(fgraph, 'destroyers'):
return True
if self.fail_validate:
self.fail_validate = False
raise theano.gof.InconsistencyError("Trying to reintroduce a removed node")
class NodeFinder(Bookkeeper):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论