提交 cf353889 authored 作者: Reyhane Askari's avatar Reyhane Askari

fix bugs of fast_destroy

上级 6100f425
......@@ -783,6 +783,9 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
delattr(self.fgraph, 'destroy_handler')
self.fgraph = None
def on_revert(self, fgraph):
self.fail_validate = {}
def fast_destroy(self, app, reason):
"""
Do the check for only 1 level.
......@@ -800,15 +803,19 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
inputs = set(itertools.chain.from_iterable(dm.values())) # list of app's destroyed inputs
for inp_idx in inputs:
inp = app.inputs[inp_idx]
if inp.owner:
if getattr(inp.tag, 'indestructible', False):
self.fail_validate[app] = InconsistencyError(
"Attempting to destroy indestructible variables: %s" %
inp)
else:
if len(inp.clients) > 1:
self.fail_validate[app] = theano.gof.InconsistencyError(
"Destroyed variable has more than one client. " + str(reason))
else:
elif inp.owner:
app2 = inp.owner
inp_idx2 = app2.outputs.index(inp)
d = getattr(app2, 'destroy_map', {}).get(inp_idx2, [])
v = getattr(app2, 'view_map', {}).get(inp_idx2, [])
d = getattr(app2.op, 'destroy_map', {}).get(inp_idx2, [])
v = getattr(app2.op, 'view_map', {}).get(inp_idx2, [])
dv = d + v
assert len(dv) <= 1
if len(dv) > 0:
......@@ -820,17 +827,18 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
Add Apply instance to set which must be computed.
"""
if app in self.debug_all_apps:
raise ProtocolError("double import")
self.debug_all_apps.add(app)
# print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps)
# If it's a destructive op, add it to our watch list
if getattr(app.op, 'destroy_map', {}):
if getattr(app.op, 'destroy_map', None):
self.destroyers.add(app)
if self.algo == 'fast':
self.fast_destroy(app, reason)
elif getattr(app.op, 'view_map', None) and self.algo == 'fast':
self.fast_destroy(app, reason)
# add this symbol to the forward and backward maps
for o_idx, i_idx_list in iteritems(getattr(app.op, 'view_map', {})):
......@@ -889,6 +897,8 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
del self.view_o[i]
self.stale_droot = True
if app in self.fail_validate:
del self.fail_validate[app]
def on_change_input(self, fgraph, app, i, old_r, new_r, reason):
"""
......@@ -932,6 +942,8 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
self.view_o.setdefault(new_r, OrderedSet()).add(output)
if self.algo == 'fast':
if app in self.fail_validate:
del self.fail_validate[app]
self.fast_destroy(app, reason)
self.stale_droot = True
......@@ -956,9 +968,12 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
# graph might have already changed when we raise the
# self.fail_validate error. So before raising the error, we
# double check here.
for app in app_err_pairs:
if app in fgraph.apply_nodes:
self.fast_destroy(app, 'validate')
# for app in app_err_pairs:
# if app in fgraph.apply_nodes:
# self.fast_destroy(app, 'validate')
for app in fgraph.apply_nodes:
self.fast_destroy(app, 'validate')
self.fail_validate = app_err_pairs
if self.fail_validate:
err = app_err_pairs.values()[0]
raise err
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论