提交 4812ace3 authored 作者: Reyhane Askari's avatar Reyhane Askari

added flag disbale_cycle_detection

上级 5f417aaa
...@@ -1478,6 +1478,12 @@ AddConfigVar('compile.wait', ...@@ -1478,6 +1478,12 @@ AddConfigVar('compile.wait',
IntParam(5, lambda i: i > 0, allow_override=False), IntParam(5, lambda i: i > 0, allow_override=False),
in_c_key=False) in_c_key=False)
AddConfigVar('disbale_cycle_detection',
"""If true it disables the cycle detection in graph.
""",
BoolParam(False),
in_c_key=False)
def _timeout_default(): def _timeout_default():
return theano.config.compile.wait * 24 return theano.config.compile.wait * 24
......
...@@ -10,6 +10,7 @@ from collections import deque, OrderedDict ...@@ -10,6 +10,7 @@ from collections import deque, OrderedDict
from six import iteritems from six import iteritems
import theano import theano
from theano import config
from . import toolbox from . import toolbox
from . import graph from . import graph
from theano.misc.ordered_set import OrderedSet from theano.misc.ordered_set import OrderedSet
...@@ -738,6 +739,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -738,6 +739,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
# clients: how many times does an apply use a given variable # clients: how many times does an apply use a given variable
self.clients = OrderedDict() # variable -> apply -> ninputs self.clients = OrderedDict() # variable -> apply -> ninputs
self.stale_droot = True self.stale_droot = True
self.fail_validate = False
self.debug_all_apps = OrderedSet() self.debug_all_apps = OrderedSet()
if self.do_imports_on_attach: if self.do_imports_on_attach:
...@@ -788,24 +790,27 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -788,24 +790,27 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
- Allow sequence of view. - Allow sequence of view.
- But don't allow to destroy view - But don't allow to destroy view
""" """
return
dm = getattr(app.op, 'destroy_map', None) dm = getattr(app.op, 'destroy_map', None)
if not dm: if not dm:
return return
inputs = sum(dm.values()) # list of app's destroyed inputs # inputs = sum(dm.values()) # list of app's destroyed inputs
inputs = dm.values()[0]
for inp_idx in inputs: for inp_idx in inputs:
inp = app.inputs[inp_idx] inp = app.inputs[inp_idx]
if inp.owner: if inp.owner:
if len(inp.clients() > 1): if len(inp.clients) > 1:
raise InconsistencyError() self.fail_validate = theano.gof.InconsistencyError(
app2 = inp.owner "Destroyed variable has too many clients")
inp_idx2 = app2.outputs.index(inp) else:
d = getattr(app2, 'destroy_map', {}).get(inp_idx2, []) app2 = inp.owner
v = getattr(app2, 'view_map', {}).get(inp_idx2, []) inp_idx2 = app2.outputs.index(inp)
dv = d+v d = getattr(app2, 'destroy_map', {}).get(inp_idx2, [])
assert len(dv) <= 1 v = getattr(app2, 'view_map', {}).get(inp_idx2, [])
if len(v) > 0: dv = d + v
raise InconsistencyError() assert len(dv) <= 1
if len(v) > 0:
self.fail_validate = theano.gof.InconsistencyError(
"Destroyed variable has destroy_map or view_map")
def on_import(self, fgraph, app, reason): def on_import(self, fgraph, app, reason):
""" """
...@@ -822,7 +827,8 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -822,7 +827,8 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
if getattr(app.op, 'destroy_map', {}): if getattr(app.op, 'destroy_map', {}):
# TODO: check here only one level of fast destroy_map. # TODO: check here only one level of fast destroy_map.
self.destroyers.add(app) self.destroyers.add(app)
self.fast_destroy(app) if config.disbale_cycle_detection:
self.fast_destroy(app)
# add this symbol to the forward and backward maps # add this symbol to the forward and backward maps
for o_idx, i_idx_list in iteritems(getattr(app.op, 'view_map', {})): for o_idx, i_idx_list in iteritems(getattr(app.op, 'view_map', {})):
...@@ -924,7 +930,8 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -924,7 +930,8 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
self.view_o.setdefault(new_r, OrderedSet()).add(output) self.view_o.setdefault(new_r, OrderedSet()).add(output)
# TODO: check here only one level of fast destroy_map. # TODO: check here only one level of fast destroy_map.
self.fast_destroy(app) if config.disbale_cycle_detection:
self.fast_destroy(app)
self.stale_droot = True self.stale_droot = True
def validate(self, fgraph): def validate(self, fgraph):
...@@ -937,8 +944,11 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -937,8 +944,11 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
""" """
if self.destroyers: if self.destroyers:
if config.disbale_cycle_detection and self.fail_validate:
self.fail_validate = False
# raise self.fail_validate
InconsistencyError("error")
ords = self.orderings(fgraph) ords = self.orderings(fgraph)
if _contains_cycle(fgraph, ords): if _contains_cycle(fgraph, ords):
raise InconsistencyError("Dependency graph contains cycles") raise InconsistencyError("Dependency graph contains cycles")
else: else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论