提交 d7d722fa authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #3200 from t13m/inline_opt_destroyhandler

Inline optimization for destroyhandler.py
...@@ -13,6 +13,7 @@ from theano.compat import OrderedDict ...@@ -13,6 +13,7 @@ from theano.compat import OrderedDict
from theano.misc.ordered_set import OrderedSet from theano.misc.ordered_set import OrderedSet
from .fg import InconsistencyError from .fg import InconsistencyError
from six.moves.queue import Queue
class ProtocolError(Exception): class ProtocolError(Exception):
...@@ -178,45 +179,51 @@ def _contains_cycle(fgraph, orderings): ...@@ -178,45 +179,51 @@ def _contains_cycle(fgraph, orderings):
return visited != len(parent_counts) return visited != len(parent_counts)
def getroot(r, view_i): def _build_droot_impact(destroy_handler):
""" droot = {} # destroyed view + nonview variables -> foundation
TODO: what is view_i ? based on add_impact's docstring, IG is guessing impact = {} # destroyed nonview variable -> it + all views of it
it might be a dictionary mapping variables to views, but what is root_destroyer = {} # root -> destroyer apply
a view? In these old docstrings I'm not sure if "view" always
means "view variable" or if it also sometimes means "viewing
pattern."
For views: Return non-view variable which is ultimatly viewed by r.
For non-views: return self.
"""
try:
return getroot(view_i[r], view_i)
except KeyError:
return r
for app in destroy_handler.destroyers:
for output_idx, input_idx_list in app.op.destroy_map.items():
if len(input_idx_list) != 1:
raise NotImplementedError()
input_idx = input_idx_list[0]
input = app.inputs[input_idx]
# Find non-view variable which is ultimatly viewed by input.
view_i = destroy_handler.view_i
_r = input
while _r is not None:
r = _r
_r = view_i.get(r)
input_root = r
if input_root in droot:
raise InconsistencyError(
"Multiple destroyers of %s" % input_root)
droot[input_root] = input_root
root_destroyer[input_root] = app
def add_impact(r, view_o, impact): # The code here add all the variables that are views of r into
""" # an OrderedSet input_impact
In opposition to getroot, which finds the variable that is viewed *by* r, this function input_impact = OrderedSet()
returns all the variables that are views of r. queue = Queue()
queue.put(input_root)
:param impact: is a set of variables that are views of r while not queue.empty():
:param droot: a dictionary mapping views -> r v = queue.get()
for n in destroy_handler.view_o.get(v, []):
input_impact.add(n)
queue.put(n)
TODO: this docstring is hideously wrong, the function doesn't return anything. for v in input_impact:
has droot been renamed to view_o? assert v not in droot
does it add things to the impact argument instead of returning them? droot[v] = input_root
IG thinks so, based on reading the code. It looks like get_impact
does what this docstring said this function does.
"""
for v in view_o.get(r, []):
impact.add(v)
add_impact(v, view_o, impact)
impact[input_root] = input_impact
impact[input_root].add(input_root)
def get_impact(root, view_o): return droot, impact, root_destroyer
impact = OrderedSet()
add_impact(root, view_o, impact)
return impact
def fast_inplace_check(inputs): def fast_inplace_check(inputs):
...@@ -338,39 +345,10 @@ if 0: ...@@ -338,39 +345,10 @@ if 0:
def refresh_droot_impact(self): def refresh_droot_impact(self):
if self.stale_droot: if self.stale_droot:
self.droot, self.impact, self.root_destroyer = self._build_droot_impact() self.droot, self.impact, self.root_destroyer = _build_droot_impact(self)
self.stale_droot = False self.stale_droot = False
return self.droot, self.impact, self.root_destroyer return self.droot, self.impact, self.root_destroyer
def _build_droot_impact(self):
droot = {} # destroyed view + nonview variables -> foundation
impact = {} # destroyed nonview variable -> it + all views of it
root_destroyer = {} # root -> destroyer apply
for app in self.destroyers:
for output_idx, input_idx_list in iteritems(app.op.destroy_map):
if len(input_idx_list) != 1:
raise NotImplementedError()
input_idx = input_idx_list[0]
input = app.inputs[input_idx]
input_root = getroot(input, self.view_i)
if input_root in droot:
raise InconsistencyError(
"Multiple destroyers of %s" % input_root)
droot[input_root] = input_root
root_destroyer[input_root] = app
# input_impact = set([input_root])
# add_impact(input_root, self.view_o, input_impact)
input_impact = get_impact(input_root, self.view_o)
for v in input_impact:
assert v not in droot
droot[v] = input_root
impact[input_root] = input_impact
impact[input_root].add(input_root)
return droot, impact, root_destroyer
def on_detach(self, fgraph): def on_detach(self, fgraph):
if fgraph is not self.fgraph: if fgraph is not self.fgraph:
raise Exception("detaching wrong fgraph", fgraph) raise Exception("detaching wrong fgraph", fgraph)
...@@ -750,30 +728,8 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -750,30 +728,8 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
(see docstrings for these properties above) (see docstrings for these properties above)
""" """
if self.stale_droot: if self.stale_droot:
droot = OrderedDict() # destroyed view + nonview variables -> foundation self.droot, self.impact, self.root_destroyer =\
impact = OrderedDict() # destroyed nonview variable -> it + all views of it _build_droot_impact(self)
root_destroyer = OrderedDict() # root -> destroyer apply
for app in self.destroyers:
for output_idx, input_idx_list in iteritems(app.op.destroy_map):
if len(input_idx_list) != 1:
raise NotImplementedError()
input_idx = input_idx_list[0]
input = app.inputs[input_idx]
input_root = getroot(input, self.view_i)
if input_root in droot:
raise InconsistencyError(
"Multiple destroyers of %s" % input_root)
droot[input_root] = input_root
root_destroyer[input_root] = app
input_impact = get_impact(input_root, self.view_o)
for v in input_impact:
assert v not in droot
droot[v] = input_root
impact[input_root] = input_impact
impact[input_root].add(input_root)
self.droot, self.impact, self.root_destroyer = droot, impact, root_destroyer
self.stale_droot = False self.stale_droot = False
return self.droot, self.impact, self.root_destroyer return self.droot, self.impact, self.root_destroyer
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论