提交 3f60d13f authored 作者: Reyhane Askari's avatar Reyhane Askari

changed flag name and minor change in fast_destroy

上级 4812ace3
...@@ -1478,10 +1478,10 @@ AddConfigVar('compile.wait', ...@@ -1478,10 +1478,10 @@ AddConfigVar('compile.wait',
IntParam(5, lambda i: i > 0, allow_override=False), IntParam(5, lambda i: i > 0, allow_override=False),
in_c_key=False) in_c_key=False)
AddConfigVar('disbale_cycle_detection', AddConfigVar('cycle_detection',
"""If true it disables the cycle detection in graph. """If true it disables the cycle detection in graph.
""", """,
BoolParam(False), StrParam('topo'),
in_c_key=False) in_c_key=False)
......
...@@ -8,6 +8,7 @@ from __future__ import absolute_import, print_function, division ...@@ -8,6 +8,7 @@ from __future__ import absolute_import, print_function, division
from collections import deque, OrderedDict from collections import deque, OrderedDict
from six import iteritems from six import iteritems
import itertools
import theano import theano
from theano import config from theano import config
...@@ -793,14 +794,14 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -793,14 +794,14 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
dm = getattr(app.op, 'destroy_map', None) dm = getattr(app.op, 'destroy_map', None)
if not dm: if not dm:
return return
# inputs = sum(dm.values()) # list of app's destroyed inputs inputs = list(set(itertools.
inputs = dm.values()[0] chain.from_iterable(dm.values()))) # list of app's destroyed inputs
for inp_idx in inputs: for inp_idx in inputs:
inp = app.inputs[inp_idx] inp = app.inputs[inp_idx]
if inp.owner: if inp.owner:
if len(inp.clients) > 1: if len(inp.clients) > 1:
self.fail_validate = theano.gof.InconsistencyError( self.fail_validate = theano.gof.InconsistencyError(
"Destroyed variable has too many clients") "Destroyed variable has more than one client")
else: else:
app2 = inp.owner app2 = inp.owner
inp_idx2 = app2.outputs.index(inp) inp_idx2 = app2.outputs.index(inp)
...@@ -810,7 +811,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -810,7 +811,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
assert len(dv) <= 1 assert len(dv) <= 1
if len(v) > 0: if len(v) > 0:
self.fail_validate = theano.gof.InconsistencyError( self.fail_validate = theano.gof.InconsistencyError(
"Destroyed variable has destroy_map or view_map") "Destroyed variable has destroy_map")
def on_import(self, fgraph, app, reason): def on_import(self, fgraph, app, reason):
""" """
...@@ -825,9 +826,8 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -825,9 +826,8 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
# If it's a destructive op, add it to our watch list # If it's a destructive op, add it to our watch list
if getattr(app.op, 'destroy_map', {}): if getattr(app.op, 'destroy_map', {}):
# TODO: check here only one level of fast destroy_map.
self.destroyers.add(app) self.destroyers.add(app)
if config.disbale_cycle_detection: if config.cycle_detection == 'fast':
self.fast_destroy(app) self.fast_destroy(app)
# add this symbol to the forward and backward maps # add this symbol to the forward and backward maps
...@@ -929,8 +929,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -929,8 +929,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
self.view_o.setdefault(new_r, OrderedSet()).add(output) self.view_o.setdefault(new_r, OrderedSet()).add(output)
# TODO: check here only one level of fast destroy_map. if config.cycle_detection == 'fast':
if config.disbale_cycle_detection:
self.fast_destroy(app) self.fast_destroy(app)
self.stale_droot = True self.stale_droot = True
...@@ -944,10 +943,11 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -944,10 +943,11 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
""" """
if self.destroyers: if self.destroyers:
if config.disbale_cycle_detection and self.fail_validate: if config.cycle_detection == 'fast':
if self.fail_validate:
err = self.fail_validate
self.fail_validate = False self.fail_validate = False
# raise self.fail_validate raise err
InconsistencyError("error")
ords = self.orderings(fgraph) ords = self.orderings(fgraph)
if _contains_cycle(fgraph, ords): if _contains_cycle(fgraph, ords):
raise InconsistencyError("Dependency graph contains cycles") raise InconsistencyError("Dependency graph contains cycles")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论