提交 ce3c7cda authored 作者: Reyhane Askari's avatar Reyhane Askari

changed self.fail_validate to a dictionary

上级 767f0d5a
...@@ -693,6 +693,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -693,6 +693,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
""" """
self.root_destroyer = OrderedDict() self.root_destroyer = OrderedDict()
self.fail_validate = OrderedDict()
def on_attach(self, fgraph): def on_attach(self, fgraph):
""" """
...@@ -740,7 +741,6 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -740,7 +741,6 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
# clients: how many times does an apply use a given variable # clients: how many times does an apply use a given variable
self.clients = OrderedDict() # variable -> apply -> ninputs self.clients = OrderedDict() # variable -> apply -> ninputs
self.stale_droot = True self.stale_droot = True
self.fail_validate = False
self.debug_all_apps = OrderedSet() self.debug_all_apps = OrderedSet()
if self.do_imports_on_attach: if self.do_imports_on_attach:
...@@ -799,7 +799,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -799,7 +799,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
inp = app.inputs[inp_idx] inp = app.inputs[inp_idx]
if inp.owner: if inp.owner:
if len(inp.clients) > 1: if len(inp.clients) > 1:
self.fail_validate = theano.gof.InconsistencyError( self.fail_validate[app] = theano.gof.InconsistencyError(
"Destroyed variable has more than one client. " + str(reason)) "Destroyed variable has more than one client. " + str(reason))
else: else:
app2 = inp.owner app2 = inp.owner
...@@ -809,7 +809,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -809,7 +809,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
dv = d + v dv = d + v
assert len(dv) <= 1 assert len(dv) <= 1
if len(dv) > 0: if len(dv) > 0:
self.fail_validate = theano.gof.InconsistencyError( self.fail_validate[app] = theano.gof.InconsistencyError(
"Destroyed variable has destroy_map or view_map. " + str(reason)) "Destroyed variable has destroy_map or view_map. " + str(reason))
def on_import(self, fgraph, app, reason): def on_import(self, fgraph, app, reason):
...@@ -945,9 +945,16 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -945,9 +945,16 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
if config.cycle_detection == 'fast': if config.cycle_detection == 'fast':
if self.fail_validate: if self.fail_validate:
err = self.fail_validate err = self.fail_validate
self.fail_validate = False self.fail_validate = {}
for n in fgraph.apply_nodes: # self.fail_validate can only be a hint that maybe/probably
self.fast_destroy(n, 'validate') # there is a cycle.This is because inside replace() we could
# record many reasons to not accept a change, but we don't
# know which one will fail first inside validate(). Thus,the
# graph might have already changed when we raise the
# self.fail_validate error. So before raising the error, we
# double check here.
for app in self.fail_validate:
self.fast_destroy(app, 'validate')
if self.fail_validate: if self.fail_validate:
raise err raise err
else: else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论