提交 8ad02e51 authored 作者: Reyhane Askari's avatar Reyhane Askari

removed outputguard when cycle_detection is fast(commit mostly removed after a rebase)

上级 2f8e387a
......@@ -120,6 +120,7 @@ class OutputGuard(ViewOp):
"""
destroy_map = {0: [0]}
view_map = {}
check_input = False
......
......@@ -446,7 +446,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
delattr(self.fgraph, 'destroy_handler')
self.fgraph = None
def fast_destroy(self, app, reason, full=False):
def fast_destroy(self, app, reason):
"""
Do the check for only 1 level.
......@@ -457,7 +457,6 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
- But don't allow to destroy view
"""
dm = getattr(app.op, 'destroy_map', None)
vm = getattr(app.op, 'view_map', {})
if not dm:
return
inputs = set(itertools.chain.from_iterable(dm.values())) # list of app's destroyed inputs
......@@ -491,17 +490,6 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
# assert len(v) <= 1
# assert len(d) <= 1
# We don't want this to be a regular check. Full is only enabled when a node is
# attached to the graph and at the end of validation. We check a rare case where
# an apply node has a dmap and vmap on the same input. Right now the only case
# of this pattern is when app is the OutputGuard.
if full:
d_inputs = set(dmap for sublist in dm.values() for dmap in sublist)
v_inputs = set(vmap for sublist in vm.values() for vmap in sublist)
if d_inputs.intersection(v_inputs):
self.fail_validate[app] = theano.gof.InconsistencyError("\
Destroyed variable has both view_map and destroy_map. " + str(reason))
def on_import(self, fgraph, app, reason):
"""
Add Apply instance to set which must be computed.
......@@ -518,7 +506,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
if dmap:
self.destroyers.add(app)
if self.algo == 'fast':
self.fast_destroy(app, reason, full=True)
self.fast_destroy(app, reason)
# add this symbol to the forward and backward maps
for o_idx, i_idx_list in iteritems(vmap):
......@@ -650,7 +638,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
# double check here.
for app in app_err_pairs:
if app in fgraph.apply_nodes:
self.fast_destroy(app, 'validate', full=True)
self.fast_destroy(app, 'validate')
if self.fail_validate:
self.fail_validate = app_err_pairs
raise app_err_pairs[app]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论