提交 8ad02e51 authored 作者: Reyhane Askari's avatar Reyhane Askari

removed outputguard when cycle_detection is fast(commit mostly removed after a rebase)

上级 2f8e387a
...@@ -120,6 +120,7 @@ class OutputGuard(ViewOp): ...@@ -120,6 +120,7 @@ class OutputGuard(ViewOp):
""" """
destroy_map = {0: [0]} destroy_map = {0: [0]}
view_map = {}
check_input = False check_input = False
......
...@@ -446,7 +446,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -446,7 +446,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
delattr(self.fgraph, 'destroy_handler') delattr(self.fgraph, 'destroy_handler')
self.fgraph = None self.fgraph = None
def fast_destroy(self, app, reason, full=False): def fast_destroy(self, app, reason):
""" """
Do the check for only 1 level. Do the check for only 1 level.
...@@ -457,7 +457,6 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -457,7 +457,6 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
- But don't allow to destroy view - But don't allow to destroy view
""" """
dm = getattr(app.op, 'destroy_map', None) dm = getattr(app.op, 'destroy_map', None)
vm = getattr(app.op, 'view_map', {})
if not dm: if not dm:
return return
inputs = set(itertools.chain.from_iterable(dm.values())) # list of app's destroyed inputs inputs = set(itertools.chain.from_iterable(dm.values())) # list of app's destroyed inputs
...@@ -491,17 +490,6 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -491,17 +490,6 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
# assert len(v) <= 1 # assert len(v) <= 1
# assert len(d) <= 1 # assert len(d) <= 1
# We don't want this to be a regular check. Full is only enabled when a node is
# attached to the graph and at the end of validation. We check a rare case where
# an apply node has a dmap and vmap on the same input. Right now the only case
# of this pattern is when app is the OutputGuard.
if full:
d_inputs = set(dmap for sublist in dm.values() for dmap in sublist)
v_inputs = set(vmap for sublist in vm.values() for vmap in sublist)
if d_inputs.intersection(v_inputs):
self.fail_validate[app] = theano.gof.InconsistencyError("\
Destroyed variable has both view_map and destroy_map. " + str(reason))
def on_import(self, fgraph, app, reason): def on_import(self, fgraph, app, reason):
""" """
Add Apply instance to set which must be computed. Add Apply instance to set which must be computed.
...@@ -518,7 +506,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -518,7 +506,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
if dmap: if dmap:
self.destroyers.add(app) self.destroyers.add(app)
if self.algo == 'fast': if self.algo == 'fast':
self.fast_destroy(app, reason, full=True) self.fast_destroy(app, reason)
# add this symbol to the forward and backward maps # add this symbol to the forward and backward maps
for o_idx, i_idx_list in iteritems(vmap): for o_idx, i_idx_list in iteritems(vmap):
...@@ -650,7 +638,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -650,7 +638,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
# double check here. # double check here.
for app in app_err_pairs: for app in app_err_pairs:
if app in fgraph.apply_nodes: if app in fgraph.apply_nodes:
self.fast_destroy(app, 'validate', full=True) self.fast_destroy(app, 'validate')
if self.fail_validate: if self.fail_validate:
self.fail_validate = app_err_pairs self.fail_validate = app_err_pairs
raise app_err_pairs[app] raise app_err_pairs[app]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论