提交 1cef6dac authored 作者: Reyhane Askari's avatar Reyhane Askari

fix for outputGuard

上级 ad23d387
......@@ -136,8 +136,7 @@ class Supervisor:
if fgraph.has_destroyers(self.protected):
raise gof.InconsistencyError("Trying to destroy a protected"
"Variable.")
else:
return True
return True
if not hasattr(fgraph, 'destroyers'):
return True
for r in self.protected + list(fgraph.outputs):
......
......@@ -398,6 +398,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
visited_app_set = set()
def recursive_destroys_finder(protected_var):
# protected_var is the idx'th input of app.
for (app, idx) in protected_var.clients:
if app in visited_app_set:
continue
......@@ -406,16 +407,21 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
if app == 'output':
continue
destroy_maps = getattr(app.op, 'destroy_map', {}).values()
# If True means that the apply node, destroys the protected_var.
if idx in [dmap for sublist in destroy_maps for dmap in sublist]:
return True
for var_idx in getattr(app.op, 'view_map', {}).keys():
if idx in app.op.view_map[var_idx] and recursive_destroys_finder(app.outputs[var_idx]):
return True
if idx in app.op.view_map[var_idx]:
# We need to recursivly check the destroy_map of all the
# outputs that we have a view_map on.
if recursive_destroys_finder(app.outputs[var_idx]):
return True
return False
for protected_var in protected_list:
if recursive_destroys_finder(protected_var):
return True
return False
fgraph.has_destroyers = has_destroyers
......@@ -445,7 +451,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
delattr(self.fgraph, 'destroy_handler')
self.fgraph = None
def fast_destroy(self, app, reason):
def fast_destroy(self, app, reason, full=False):
"""
Do the check for only 1 level.
......@@ -456,6 +462,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
- But don't allow to destroy view
"""
dm = getattr(app.op, 'destroy_map', None)
vm = getattr(app.op, 'view_map', {})
if not dm:
return
inputs = set(itertools.chain.from_iterable(dm.values())) # list of app's destroyed inputs
......@@ -489,6 +496,17 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
# assert len(v) <= 1
# assert len(d) <= 1
# We don't want this to be a regular check. Full is only enabled when a node is
# attached to the graph and at the end of validation. We check a rare case where
# an apply node has a dmap and vmap on the same input. Right now the only case
# of this pattern is when app is the OutputGuard.
if full:
d_inputs = set(dmap for sublist in dm.values() for dmap in sublist)
v_inputs = set(vmap for sublist in vm.values() for vmap in sublist)
if d_inputs.intersection(v_inputs):
self.fail_validate[app] = theano.gof.InconsistencyError("\
Destroyed variable has both view_map and destroy_map. " + str(reason))
def on_import(self, fgraph, app, reason):
"""
Add Apply instance to set which must be computed.
......@@ -500,13 +518,15 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
# print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps)
# If it's a destructive op, add it to our watch list
if getattr(app.op, 'destroy_map', None):
dmap = getattr(app.op, 'destroy_map', None)
vmap = getattr(app.op, 'view_map', {})
if dmap:
self.destroyers.add(app)
if self.algo == 'fast':
self.fast_destroy(app, reason)
self.fast_destroy(app, reason, full=True)
# add this symbol to the forward and backward maps
for o_idx, i_idx_list in iteritems(getattr(app.op, 'view_map', {})):
for o_idx, i_idx_list in iteritems(vmap):
if len(i_idx_list) > 1:
raise NotImplementedError(
'destroying this output invalidates multiple inputs',
......@@ -635,7 +655,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
# double check here.
for app in app_err_pairs:
if app in fgraph.apply_nodes:
self.fast_destroy(app, 'validate')
self.fast_destroy(app, 'validate', full=True)
if self.fail_validate:
self.fail_validate = app_err_pairs
raise app_err_pairs[app]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论