提交 4812ace3 authored 作者: Reyhane Askari's avatar Reyhane Askari

added flag disbale_cycle_detection

上级 5f417aaa
......@@ -1478,6 +1478,12 @@ AddConfigVar('compile.wait',
IntParam(5, lambda i: i > 0, allow_override=False),
in_c_key=False)
AddConfigVar('disbale_cycle_detection',
"""If true it disables the cycle detection in graph.
""",
BoolParam(False),
in_c_key=False)
def _timeout_default():
return theano.config.compile.wait * 24
......
......@@ -10,6 +10,7 @@ from collections import deque, OrderedDict
from six import iteritems
import theano
from theano import config
from . import toolbox
from . import graph
from theano.misc.ordered_set import OrderedSet
......@@ -738,6 +739,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
# clients: how many times does an apply use a given variable
self.clients = OrderedDict() # variable -> apply -> ninputs
self.stale_droot = True
self.fail_validate = False
self.debug_all_apps = OrderedSet()
if self.do_imports_on_attach:
......@@ -788,24 +790,27 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
- Allow sequence of view.
- But don't allow to destroy view
"""
return
dm = getattr(app.op, 'destroy_map', None)
if not dm:
return
inputs = sum(dm.values()) # list of app's destroyed inputs
# inputs = sum(dm.values()) # list of app's destroyed inputs
inputs = dm.values()[0]
for inp_idx in inputs:
inp = app.inputs[inp_idx]
if inp.owner:
if len(inp.clients() > 1):
raise InconsistencyError()
app2 = inp.owner
inp_idx2 = app2.outputs.index(inp)
d = getattr(app2, 'destroy_map', {}).get(inp_idx2, [])
v = getattr(app2, 'view_map', {}).get(inp_idx2, [])
dv = d+v
assert len(dv) <= 1
if len(v) > 0:
raise InconsistencyError()
if len(inp.clients) > 1:
self.fail_validate = theano.gof.InconsistencyError(
"Destroyed variable has too many clients")
else:
app2 = inp.owner
inp_idx2 = app2.outputs.index(inp)
d = getattr(app2, 'destroy_map', {}).get(inp_idx2, [])
v = getattr(app2, 'view_map', {}).get(inp_idx2, [])
dv = d + v
assert len(dv) <= 1
if len(v) > 0:
self.fail_validate = theano.gof.InconsistencyError(
"Destroyed variable has destroy_map or view_map")
def on_import(self, fgraph, app, reason):
"""
......@@ -822,7 +827,8 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
if getattr(app.op, 'destroy_map', {}):
# TODO: check here only one level of fast destroy_map.
self.destroyers.add(app)
self.fast_destroy(app)
if config.disbale_cycle_detection:
self.fast_destroy(app)
# add this symbol to the forward and backward maps
for o_idx, i_idx_list in iteritems(getattr(app.op, 'view_map', {})):
......@@ -924,7 +930,8 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
self.view_o.setdefault(new_r, OrderedSet()).add(output)
# TODO: check here only one level of fast destroy_map.
self.fast_destroy(app)
if config.disbale_cycle_detection:
self.fast_destroy(app)
self.stale_droot = True
def validate(self, fgraph):
......@@ -937,8 +944,11 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
"""
if self.destroyers:
if config.disbale_cycle_detection and self.fail_validate:
self.fail_validate = False
# raise self.fail_validate
InconsistencyError("error")
ords = self.orderings(fgraph)
if _contains_cycle(fgraph, ords):
raise InconsistencyError("Dependency graph contains cycles")
else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论