提交 1cef6dac authored 作者: Reyhane Askari's avatar Reyhane Askari

fix for outputGuard

上级 ad23d387
...@@ -136,8 +136,7 @@ class Supervisor: ...@@ -136,8 +136,7 @@ class Supervisor:
if fgraph.has_destroyers(self.protected): if fgraph.has_destroyers(self.protected):
raise gof.InconsistencyError("Trying to destroy a protected" raise gof.InconsistencyError("Trying to destroy a protected"
"Variable.") "Variable.")
else: return True
return True
if not hasattr(fgraph, 'destroyers'): if not hasattr(fgraph, 'destroyers'):
return True return True
for r in self.protected + list(fgraph.outputs): for r in self.protected + list(fgraph.outputs):
......
...@@ -398,6 +398,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -398,6 +398,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
visited_app_set = set() visited_app_set = set()
def recursive_destroys_finder(protected_var): def recursive_destroys_finder(protected_var):
# protected_var is the idx'th input of app.
for (app, idx) in protected_var.clients: for (app, idx) in protected_var.clients:
if app in visited_app_set: if app in visited_app_set:
continue continue
...@@ -406,16 +407,21 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -406,16 +407,21 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
if app == 'output': if app == 'output':
continue continue
destroy_maps = getattr(app.op, 'destroy_map', {}).values() destroy_maps = getattr(app.op, 'destroy_map', {}).values()
# If True means that the apply node, destroys the protected_var.
if idx in [dmap for sublist in destroy_maps for dmap in sublist]: if idx in [dmap for sublist in destroy_maps for dmap in sublist]:
return True return True
for var_idx in getattr(app.op, 'view_map', {}).keys(): for var_idx in getattr(app.op, 'view_map', {}).keys():
if idx in app.op.view_map[var_idx] and recursive_destroys_finder(app.outputs[var_idx]): if idx in app.op.view_map[var_idx]:
return True # We need to recursivly check the destroy_map of all the
# outputs that we have a view_map on.
if recursive_destroys_finder(app.outputs[var_idx]):
return True
return False return False
for protected_var in protected_list: for protected_var in protected_list:
if recursive_destroys_finder(protected_var): if recursive_destroys_finder(protected_var):
return True return True
return False
fgraph.has_destroyers = has_destroyers fgraph.has_destroyers = has_destroyers
...@@ -445,7 +451,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -445,7 +451,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
delattr(self.fgraph, 'destroy_handler') delattr(self.fgraph, 'destroy_handler')
self.fgraph = None self.fgraph = None
def fast_destroy(self, app, reason): def fast_destroy(self, app, reason, full=False):
""" """
Do the check for only 1 level. Do the check for only 1 level.
...@@ -456,6 +462,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -456,6 +462,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
- But don't allow to destroy view - But don't allow to destroy view
""" """
dm = getattr(app.op, 'destroy_map', None) dm = getattr(app.op, 'destroy_map', None)
vm = getattr(app.op, 'view_map', {})
if not dm: if not dm:
return return
inputs = set(itertools.chain.from_iterable(dm.values())) # list of app's destroyed inputs inputs = set(itertools.chain.from_iterable(dm.values())) # list of app's destroyed inputs
...@@ -489,6 +496,17 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -489,6 +496,17 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
# assert len(v) <= 1 # assert len(v) <= 1
# assert len(d) <= 1 # assert len(d) <= 1
# We don't want this to be a regular check. Full is only enabled when a node is
# attached to the graph and at the end of validation. We check a rare case where
# an apply node has a dmap and vmap on the same input. Right now the only case
# of this pattern is when app is the OutputGuard.
if full:
d_inputs = set(dmap for sublist in dm.values() for dmap in sublist)
v_inputs = set(vmap for sublist in vm.values() for vmap in sublist)
if d_inputs.intersection(v_inputs):
self.fail_validate[app] = theano.gof.InconsistencyError("\
Destroyed variable has both view_map and destroy_map. " + str(reason))
def on_import(self, fgraph, app, reason): def on_import(self, fgraph, app, reason):
""" """
Add Apply instance to set which must be computed. Add Apply instance to set which must be computed.
...@@ -500,13 +518,15 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -500,13 +518,15 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
# print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps) # print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps)
# If it's a destructive op, add it to our watch list # If it's a destructive op, add it to our watch list
if getattr(app.op, 'destroy_map', None): dmap = getattr(app.op, 'destroy_map', None)
vmap = getattr(app.op, 'view_map', {})
if dmap:
self.destroyers.add(app) self.destroyers.add(app)
if self.algo == 'fast': if self.algo == 'fast':
self.fast_destroy(app, reason) self.fast_destroy(app, reason, full=True)
# add this symbol to the forward and backward maps # add this symbol to the forward and backward maps
for o_idx, i_idx_list in iteritems(getattr(app.op, 'view_map', {})): for o_idx, i_idx_list in iteritems(vmap):
if len(i_idx_list) > 1: if len(i_idx_list) > 1:
raise NotImplementedError( raise NotImplementedError(
'destroying this output invalidates multiple inputs', 'destroying this output invalidates multiple inputs',
...@@ -635,7 +655,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -635,7 +655,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
# double check here. # double check here.
for app in app_err_pairs: for app in app_err_pairs:
if app in fgraph.apply_nodes: if app in fgraph.apply_nodes:
self.fast_destroy(app, 'validate') self.fast_destroy(app, 'validate', full=True)
if self.fail_validate: if self.fail_validate:
self.fail_validate = app_err_pairs self.fail_validate = app_err_pairs
raise app_err_pairs[app] raise app_err_pairs[app]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论