提交 0c299535 authored 作者: Reyhane Askari's avatar Reyhane Askari

fix for revert operation of validation with minor fixes

上级 3f60d13f
...@@ -106,7 +106,7 @@ class AddDestroyHandler(gof.Optimizer): ...@@ -106,7 +106,7 @@ class AddDestroyHandler(gof.Optimizer):
"how was this output left unprotected against " "how was this output left unprotected against "
"destructive operations?" "destructive operations?"
% o) % o)
except gof.InconsistencyError: except gof.InconsistencyError as e:
# This output is already impossible to destroy. # This output is already impossible to destroy.
# No guard necessary # No guard necessary
pass pass
......
...@@ -1481,7 +1481,7 @@ AddConfigVar('compile.wait', ...@@ -1481,7 +1481,7 @@ AddConfigVar('compile.wait',
AddConfigVar('cycle_detection', AddConfigVar('cycle_detection',
"""If true it disables the cycle detection in graph. """If true it disables the cycle detection in graph.
""", """,
StrParam('topo'), StrParam(['topo', 'fast']),
in_c_key=False) in_c_key=False)
......
...@@ -780,7 +780,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -780,7 +780,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
delattr(self.fgraph, 'destroy_handler') delattr(self.fgraph, 'destroy_handler')
self.fgraph = None self.fgraph = None
def fast_destroy(self, app): def fast_destroy(self, app, reason):
""" """
Do the check for only 1 level. Do the check for only 1 level.
...@@ -794,14 +794,13 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -794,14 +794,13 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
dm = getattr(app.op, 'destroy_map', None) dm = getattr(app.op, 'destroy_map', None)
if not dm: if not dm:
return return
inputs = list(set(itertools. inputs = set(itertools.chain.from_iterable(dm.values())) # list of app's destroyed inputs
chain.from_iterable(dm.values()))) # list of app's destroyed inputs
for inp_idx in inputs: for inp_idx in inputs:
inp = app.inputs[inp_idx] inp = app.inputs[inp_idx]
if inp.owner: if inp.owner:
if len(inp.clients) > 1: if len(inp.clients) > 1:
self.fail_validate = theano.gof.InconsistencyError( self.fail_validate = theano.gof.InconsistencyError(
"Destroyed variable has more than one client") "Destroyed variable has more than one client. " + str(reason))
else: else:
app2 = inp.owner app2 = inp.owner
inp_idx2 = app2.outputs.index(inp) inp_idx2 = app2.outputs.index(inp)
...@@ -809,9 +808,9 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -809,9 +808,9 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
v = getattr(app2, 'view_map', {}).get(inp_idx2, []) v = getattr(app2, 'view_map', {}).get(inp_idx2, [])
dv = d + v dv = d + v
assert len(dv) <= 1 assert len(dv) <= 1
if len(v) > 0: if len(dv) > 0:
self.fail_validate = theano.gof.InconsistencyError( self.fail_validate = theano.gof.InconsistencyError(
"Destroyed variable has destroy_map") "Destroyed variable has destroy_map or view_map. " + str(reason))
def on_import(self, fgraph, app, reason): def on_import(self, fgraph, app, reason):
""" """
...@@ -828,7 +827,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -828,7 +827,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
if getattr(app.op, 'destroy_map', {}): if getattr(app.op, 'destroy_map', {}):
self.destroyers.add(app) self.destroyers.add(app)
if config.cycle_detection == 'fast': if config.cycle_detection == 'fast':
self.fast_destroy(app) self.fast_destroy(app, reason)
# add this symbol to the forward and backward maps # add this symbol to the forward and backward maps
for o_idx, i_idx_list in iteritems(getattr(app.op, 'view_map', {})): for o_idx, i_idx_list in iteritems(getattr(app.op, 'view_map', {})):
...@@ -930,7 +929,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -930,7 +929,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
self.view_o.setdefault(new_r, OrderedSet()).add(output) self.view_o.setdefault(new_r, OrderedSet()).add(output)
if config.cycle_detection == 'fast': if config.cycle_detection == 'fast':
self.fast_destroy(app) self.fast_destroy(app, reason)
self.stale_droot = True self.stale_droot = True
def validate(self, fgraph): def validate(self, fgraph):
...@@ -947,7 +946,11 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -947,7 +946,11 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
if self.fail_validate: if self.fail_validate:
err = self.fail_validate err = self.fail_validate
self.fail_validate = False self.fail_validate = False
for n in fgraph.apply_nodes:
self.fast_destroy(n, 'validate')
if self.fail_validate:
raise err raise err
else:
ords = self.orderings(fgraph) ords = self.orderings(fgraph)
if _contains_cycle(fgraph, ords): if _contains_cycle(fgraph, ords):
raise InconsistencyError("Dependency graph contains cycles") raise InconsistencyError("Dependency graph contains cycles")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论