提交 4812ace3 authored 作者: Reyhane Askari's avatar Reyhane Askari

added flag disbale_cycle_detection

上级 5f417aaa
...@@ -1478,6 +1478,12 @@ AddConfigVar('compile.wait', ...@@ -1478,6 +1478,12 @@ AddConfigVar('compile.wait',
IntParam(5, lambda i: i > 0, allow_override=False), IntParam(5, lambda i: i > 0, allow_override=False),
in_c_key=False) in_c_key=False)
AddConfigVar('disbale_cycle_detection',
"""If true it disables the cycle detection in graph.
""",
BoolParam(False),
in_c_key=False)
def _timeout_default(): def _timeout_default():
return theano.config.compile.wait * 24 return theano.config.compile.wait * 24
......
...@@ -10,6 +10,7 @@ from collections import deque, OrderedDict ...@@ -10,6 +10,7 @@ from collections import deque, OrderedDict
from six import iteritems from six import iteritems
import theano import theano
from theano import config
from . import toolbox from . import toolbox
from . import graph from . import graph
from theano.misc.ordered_set import OrderedSet from theano.misc.ordered_set import OrderedSet
...@@ -738,6 +739,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -738,6 +739,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
# clients: how many times does an apply use a given variable # clients: how many times does an apply use a given variable
self.clients = OrderedDict() # variable -> apply -> ninputs self.clients = OrderedDict() # variable -> apply -> ninputs
self.stale_droot = True self.stale_droot = True
self.fail_validate = False
self.debug_all_apps = OrderedSet() self.debug_all_apps = OrderedSet()
if self.do_imports_on_attach: if self.do_imports_on_attach:
...@@ -788,24 +790,27 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -788,24 +790,27 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
- Allow sequence of view. - Allow sequence of view.
- But don't allow to destroy view - But don't allow to destroy view
""" """
return
dm = getattr(app.op, 'destroy_map', None) dm = getattr(app.op, 'destroy_map', None)
if not dm: if not dm:
return return
inputs = sum(dm.values()) # list of app's destroyed inputs # inputs = sum(dm.values()) # list of app's destroyed inputs
inputs = dm.values()[0]
for inp_idx in inputs: for inp_idx in inputs:
inp = app.inputs[inp_idx] inp = app.inputs[inp_idx]
if inp.owner: if inp.owner:
if len(inp.clients() > 1): if len(inp.clients) > 1:
raise InconsistencyError() self.fail_validate = theano.gof.InconsistencyError(
"Destroyed variable has too many clients")
else:
app2 = inp.owner app2 = inp.owner
inp_idx2 = app2.outputs.index(inp) inp_idx2 = app2.outputs.index(inp)
d = getattr(app2, 'destroy_map', {}).get(inp_idx2, []) d = getattr(app2, 'destroy_map', {}).get(inp_idx2, [])
v = getattr(app2, 'view_map', {}).get(inp_idx2, []) v = getattr(app2, 'view_map', {}).get(inp_idx2, [])
dv = d+v dv = d + v
assert len(dv) <= 1 assert len(dv) <= 1
if len(v) > 0: if len(v) > 0:
raise InconsistencyError() self.fail_validate = theano.gof.InconsistencyError(
"Destroyed variable has destroy_map or view_map")
def on_import(self, fgraph, app, reason): def on_import(self, fgraph, app, reason):
""" """
...@@ -822,6 +827,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -822,6 +827,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
if getattr(app.op, 'destroy_map', {}): if getattr(app.op, 'destroy_map', {}):
# TODO: check here only one level of fast destroy_map. # TODO: check here only one level of fast destroy_map.
self.destroyers.add(app) self.destroyers.add(app)
if config.disbale_cycle_detection:
self.fast_destroy(app) self.fast_destroy(app)
# add this symbol to the forward and backward maps # add this symbol to the forward and backward maps
...@@ -924,6 +930,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -924,6 +930,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
self.view_o.setdefault(new_r, OrderedSet()).add(output) self.view_o.setdefault(new_r, OrderedSet()).add(output)
# TODO: check here only one level of fast destroy_map. # TODO: check here only one level of fast destroy_map.
if config.disbale_cycle_detection:
self.fast_destroy(app) self.fast_destroy(app)
self.stale_droot = True self.stale_droot = True
...@@ -937,8 +944,11 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -937,8 +944,11 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
""" """
if self.destroyers: if self.destroyers:
if config.disbale_cycle_detection and self.fail_validate:
self.fail_validate = False
# raise self.fail_validate
InconsistencyError("error")
ords = self.orderings(fgraph) ords = self.orderings(fgraph)
if _contains_cycle(fgraph, ords): if _contains_cycle(fgraph, ords):
raise InconsistencyError("Dependency graph contains cycles") raise InconsistencyError("Dependency graph contains cycles")
else: else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论