提交 1860ab2e authored 作者: Reyhane Askari's avatar Reyhane Askari

change of names and fixes for recursive_destroys_finder

上级 5917ead5
...@@ -132,11 +132,12 @@ class Supervisor: ...@@ -132,11 +132,12 @@ class Supervisor:
self.protected = list(protected) self.protected = list(protected)
def validate(self, fgraph): def validate(self, fgraph):
if config.cycle_detection == 'fast' and hasattr(fgraph, 'fast_destroyers_check'): if config.cycle_detection == 'fast' and hasattr(fgraph, 'has_destroyers'):
if fgraph.fast_destroyers_check(self.protected): if fgraph.has_destroyers(self.protected):
raise gof.InconsistencyError("Trying to destroy a protected" raise gof.InconsistencyError("Trying to destroy a protected"
"Variable.") "Variable.")
else:
return True
if not hasattr(fgraph, 'destroyers'): if not hasattr(fgraph, 'destroyers'):
return True return True
for r in self.protected + list(fgraph.outputs): for r in self.protected + list(fgraph.outputs):
...@@ -1090,7 +1091,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs): ...@@ -1090,7 +1091,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
# We can't use fgraph.inputs as this don't include Constant Value. # We can't use fgraph.inputs as this don't include Constant Value.
all_graph_inputs = gof.graph.inputs(fgraph.outputs) all_graph_inputs = gof.graph.inputs(fgraph.outputs)
has_destroyers = hasattr(fgraph, 'get_destroyers_of') has_get_destroyers = hasattr(fgraph, 'get_destroyers_of')
for i in xrange(len(fgraph.outputs)): for i in xrange(len(fgraph.outputs)):
views_of_output_i = set() views_of_output_i = set()
...@@ -1121,7 +1122,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs): ...@@ -1121,7 +1122,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
# being updated # being updated
if input_j in updated_fgraph_inputs: if input_j in updated_fgraph_inputs:
continue continue
if input_j in views_of_output_i and not (has_destroyers and fgraph.get_destroyers_of(input_j)): if input_j in views_of_output_i and not (has_get_destroyers and fgraph.get_destroyers_of(input_j)):
# We don't put deep_copy_op if the input and the # We don't put deep_copy_op if the input and the
# output have borrow==True # output have borrow==True
if input_j in fgraph.inputs: if input_j in fgraph.inputs:
......
...@@ -297,7 +297,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -297,7 +297,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
<unknown> <unknown>
""" """
pickle_rm_attr = ["destroyers", "fast_destroyers_check"] pickle_rm_attr = ["destroyers", "has_destroyers"]
def __init__(self, do_imports_on_attach=True, algo=None): def __init__(self, do_imports_on_attach=True, algo=None):
self.fgraph = None self.fgraph = None
...@@ -395,24 +395,23 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -395,24 +395,23 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
fgraph.destroyers = get_destroyers_of fgraph.destroyers = get_destroyers_of
def recursive_destroys_finder(clients_list): def recursive_destroys_finder(clients_list):
for client in clients_list: for (app, idx) in clients_list:
# client is a tuple (I don't know if its size is always one) if app == 'output':
for item in client: continue
if item.op.destroy_map: destroy_maps = getattr(app.op, 'destroy_map', {}).values()
if idx in [dmap for sublist in destroy_maps for dmap in sublist]:
return True
for var in getattr(app.op, 'view_map', {}).keys():
if recursive_destroys_finder(app.outputs[var].clients):
return True return True
if len(item.outputs) == 0:
return False
for output in item.outputs:
if recursive_destroys_finder(output.clients):
return True
return False return False
def fast_destroyers_check(protected_list): def has_destroyers(protected_list):
for protected_var in protected_list: for protected_var in protected_list:
if recursive_destroys_finder(protected_var.clients): if recursive_destroys_finder(protected_var.clients):
return True return True
fgraph.fast_destroyers_check = fast_destroyers_check fgraph.has_destroyers = has_destroyers
def refresh_droot_impact(self): def refresh_droot_impact(self):
""" """
...@@ -436,7 +435,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -436,7 +435,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
del self.stale_droot del self.stale_droot
assert self.fgraph.destroyer_handler is self assert self.fgraph.destroyer_handler is self
delattr(self.fgraph, 'destroyers') delattr(self.fgraph, 'destroyers')
delattr(self.fgraph, 'fast_destroyers_check') delattr(self.fgraph, 'has_destroyers')
delattr(self.fgraph, 'destroy_handler') delattr(self.fgraph, 'destroy_handler')
self.fgraph = None self.fgraph = None
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论