提交 3cab3849 authored 作者: Reyhane Askari's avatar Reyhane Askari

minor speedup in fast_destroy

上级 cf353889
...@@ -783,9 +783,6 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -783,9 +783,6 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
delattr(self.fgraph, 'destroy_handler') delattr(self.fgraph, 'destroy_handler')
self.fgraph = None self.fgraph = None
def on_revert(self, fgraph):
self.fail_validate = {}
def fast_destroy(self, app, reason): def fast_destroy(self, app, reason):
""" """
Do the check for only 1 level. Do the check for only 1 level.
...@@ -807,20 +804,27 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -807,20 +804,27 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
self.fail_validate[app] = InconsistencyError( self.fail_validate[app] = InconsistencyError(
"Attempting to destroy indestructible variables: %s" % "Attempting to destroy indestructible variables: %s" %
inp) inp)
else: elif len(inp.clients) > 1:
if len(inp.clients) > 1: self.fail_validate[app] = theano.gof.InconsistencyError(
"Destroyed variable has more than one client. " + str(reason))
elif inp.owner:
app2 = inp.owner
inp_idx2 = app2.outputs.index(inp)
v = getattr(app2.op, 'view_map', {})
if v:
v = v.get(inp_idx2, [])
if len(v) > 0:
self.fail_validate[app] = theano.gof.InconsistencyError( self.fail_validate[app] = theano.gof.InconsistencyError(
"Destroyed variable has more than one client. " + str(reason)) "Destroyed variable has view_map. " + str(reason))
elif inp.owner: d = getattr(app2.op, 'destroy_map', {})
app2 = inp.owner if d:
inp_idx2 = app2.outputs.index(inp) d = d.get(inp_idx2, [])
d = getattr(app2.op, 'destroy_map', {}).get(inp_idx2, []) if len(d) > 0:
v = getattr(app2.op, 'view_map', {}).get(inp_idx2, []) self.fail_validate[app] = theano.gof.InconsistencyError(
dv = d + v "Destroyed variable has destroy_map. " + str(reason))
assert len(dv) <= 1
if len(dv) > 0: assert len(v) <= 1
self.fail_validate[app] = theano.gof.InconsistencyError( assert len(d) <= 1
"Destroyed variable has destroy_map or view_map. " + str(reason))
def on_import(self, fgraph, app, reason): def on_import(self, fgraph, app, reason):
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论