提交 cf353889 authored 作者: Reyhane Askari's avatar Reyhane Askari

fix bugs of fast_destroy

上级 6100f425
...@@ -783,6 +783,9 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -783,6 +783,9 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
delattr(self.fgraph, 'destroy_handler') delattr(self.fgraph, 'destroy_handler')
self.fgraph = None self.fgraph = None
def on_revert(self, fgraph):
self.fail_validate = {}
def fast_destroy(self, app, reason): def fast_destroy(self, app, reason):
""" """
Do the check for only 1 level. Do the check for only 1 level.
...@@ -800,15 +803,19 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -800,15 +803,19 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
inputs = set(itertools.chain.from_iterable(dm.values())) # list of app's destroyed inputs inputs = set(itertools.chain.from_iterable(dm.values())) # list of app's destroyed inputs
for inp_idx in inputs: for inp_idx in inputs:
inp = app.inputs[inp_idx] inp = app.inputs[inp_idx]
if inp.owner: if getattr(inp.tag, 'indestructible', False):
self.fail_validate[app] = InconsistencyError(
"Attempting to destroy indestructible variables: %s" %
inp)
else:
if len(inp.clients) > 1: if len(inp.clients) > 1:
self.fail_validate[app] = theano.gof.InconsistencyError( self.fail_validate[app] = theano.gof.InconsistencyError(
"Destroyed variable has more than one client. " + str(reason)) "Destroyed variable has more than one client. " + str(reason))
else: elif inp.owner:
app2 = inp.owner app2 = inp.owner
inp_idx2 = app2.outputs.index(inp) inp_idx2 = app2.outputs.index(inp)
d = getattr(app2, 'destroy_map', {}).get(inp_idx2, []) d = getattr(app2.op, 'destroy_map', {}).get(inp_idx2, [])
v = getattr(app2, 'view_map', {}).get(inp_idx2, []) v = getattr(app2.op, 'view_map', {}).get(inp_idx2, [])
dv = d + v dv = d + v
assert len(dv) <= 1 assert len(dv) <= 1
if len(dv) > 0: if len(dv) > 0:
...@@ -820,17 +827,18 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -820,17 +827,18 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
Add Apply instance to set which must be computed. Add Apply instance to set which must be computed.
""" """
if app in self.debug_all_apps: if app in self.debug_all_apps:
raise ProtocolError("double import") raise ProtocolError("double import")
self.debug_all_apps.add(app) self.debug_all_apps.add(app)
# print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps) # print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps)
# If it's a destructive op, add it to our watch list # If it's a destructive op, add it to our watch list
if getattr(app.op, 'destroy_map', {}): if getattr(app.op, 'destroy_map', None):
self.destroyers.add(app) self.destroyers.add(app)
if self.algo == 'fast': if self.algo == 'fast':
self.fast_destroy(app, reason) self.fast_destroy(app, reason)
elif getattr(app.op, 'view_map', None) and self.algo == 'fast':
self.fast_destroy(app, reason)
# add this symbol to the forward and backward maps # add this symbol to the forward and backward maps
for o_idx, i_idx_list in iteritems(getattr(app.op, 'view_map', {})): for o_idx, i_idx_list in iteritems(getattr(app.op, 'view_map', {})):
...@@ -889,6 +897,8 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -889,6 +897,8 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
del self.view_o[i] del self.view_o[i]
self.stale_droot = True self.stale_droot = True
if app in self.fail_validate:
del self.fail_validate[app]
def on_change_input(self, fgraph, app, i, old_r, new_r, reason): def on_change_input(self, fgraph, app, i, old_r, new_r, reason):
""" """
...@@ -932,6 +942,8 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -932,6 +942,8 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
self.view_o.setdefault(new_r, OrderedSet()).add(output) self.view_o.setdefault(new_r, OrderedSet()).add(output)
if self.algo == 'fast': if self.algo == 'fast':
if app in self.fail_validate:
del self.fail_validate[app]
self.fast_destroy(app, reason) self.fast_destroy(app, reason)
self.stale_droot = True self.stale_droot = True
...@@ -956,9 +968,12 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -956,9 +968,12 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
# graph might have already changed when we raise the # graph might have already changed when we raise the
# self.fail_validate error. So before raising the error, we # self.fail_validate error. So before raising the error, we
# double check here. # double check here.
for app in app_err_pairs: # for app in app_err_pairs:
if app in fgraph.apply_nodes: # if app in fgraph.apply_nodes:
self.fast_destroy(app, 'validate') # self.fast_destroy(app, 'validate')
for app in fgraph.apply_nodes:
self.fast_destroy(app, 'validate')
self.fail_validate = app_err_pairs
if self.fail_validate: if self.fail_validate:
err = app_err_pairs.values()[0] err = app_err_pairs.values()[0]
raise err raise err
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论