提交 3f60d13f authored 作者: Reyhane Askari's avatar Reyhane Askari

changed flag name and minor change in fast_destroy

上级 4812ace3
......@@ -1478,10 +1478,10 @@ AddConfigVar('compile.wait',
IntParam(5, lambda i: i > 0, allow_override=False),
in_c_key=False)
AddConfigVar('disbale_cycle_detection',
AddConfigVar('cycle_detection',
"""If true it disables the cycle detection in graph.
""",
BoolParam(False),
StrParam('topo'),
in_c_key=False)
......
......@@ -8,6 +8,7 @@ from __future__ import absolute_import, print_function, division
from collections import deque, OrderedDict
from six import iteritems
import itertools
import theano
from theano import config
......@@ -793,14 +794,14 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
dm = getattr(app.op, 'destroy_map', None)
if not dm:
return
# inputs = sum(dm.values()) # list of app's destroyed inputs
inputs = dm.values()[0]
inputs = list(set(itertools.
chain.from_iterable(dm.values()))) # list of app's destroyed inputs
for inp_idx in inputs:
inp = app.inputs[inp_idx]
if inp.owner:
if len(inp.clients) > 1:
self.fail_validate = theano.gof.InconsistencyError(
"Destroyed variable has too many clients")
"Destroyed variable has more than one client")
else:
app2 = inp.owner
inp_idx2 = app2.outputs.index(inp)
......@@ -810,7 +811,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
assert len(dv) <= 1
if len(v) > 0:
self.fail_validate = theano.gof.InconsistencyError(
"Destroyed variable has destroy_map or view_map")
"Destroyed variable has destroy_map")
def on_import(self, fgraph, app, reason):
"""
......@@ -825,9 +826,8 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
# If it's a destructive op, add it to our watch list
if getattr(app.op, 'destroy_map', {}):
# TODO: check here only one level of fast destroy_map.
self.destroyers.add(app)
if config.disbale_cycle_detection:
if config.cycle_detection == 'fast':
self.fast_destroy(app)
# add this symbol to the forward and backward maps
......@@ -929,8 +929,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
self.view_o.setdefault(new_r, OrderedSet()).add(output)
# TODO: check here only one level of fast destroy_map.
if config.disbale_cycle_detection:
if config.cycle_detection == 'fast':
self.fast_destroy(app)
self.stale_droot = True
......@@ -944,10 +943,11 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
"""
if self.destroyers:
if config.disbale_cycle_detection and self.fail_validate:
self.fail_validate = False
# raise self.fail_validate
InconsistencyError("error")
if config.cycle_detection == 'fast':
if self.fail_validate:
err = self.fail_validate
self.fail_validate = False
raise err
ords = self.orderings(fgraph)
if _contains_cycle(fgraph, ords):
raise InconsistencyError("Dependency graph contains cycles")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论