Commit 492653e9 authored by Ian Goodfellow

started IncrementalDestroyHandler class

Parent cda264aa
...@@ -241,7 +241,368 @@ def fast_inplace_check(inputs):
                and i not in protected_inputs]
    return inputs
class DestroyHandler(toolbox.Bookkeeper): if 0:
# old, non-incremental version of the DestroyHandler
class DestroyHandler(toolbox.Bookkeeper):
    """
    The DestroyHandler class detects when a graph is impossible to evaluate
    because of aliasing and destructive operations.

    Several data structures are used to do this.

    When an Op uses its view_map property to declare that an output may be
    aliased to an input, then if that output is destroyed, the input is also
    considered to be destroyed. The view_maps of several Ops can feed into
    one another and form a directed graph. The consequence of destroying any
    variable in such a graph is that all variables in the graph must be
    considered to be destroyed, because they could all be referring to the
    same underlying storage. In the current implementation, that graph is a
    tree, and the root of that tree is called the foundation. The `droot`
    property of this class maps from every graph variable to its foundation.
    The `impact` property maps backward from the foundation to all of the
    variables that depend on it. When any variable is destroyed, this class
    marks the foundation of that variable as being destroyed, with the
    `root_destroyer` property.
    """

    # destroyed view + nonview variables -> foundation
    droot = {}

    # destroyed nonview variable -> it + all views of it
    impact = {}

    # root -> destroyer apply
    root_destroyer = {}

    def __init__(self, do_imports_on_attach=True):
        self.fgraph = None
        self.do_imports_on_attach = do_imports_on_attach

    def on_attach(self, fgraph):
        """
        When attaching to a new fgraph, check that

        1) This DestroyHandler wasn't already attached to some fgraph
           (its data structures are only set up to serve one).
        2) The FunctionGraph doesn't already have a DestroyHandler.
           This would result in it validating everything twice, causing
           compilation to be slower.

        Besides the checks, this installs `fgraph.destroyers` and
        `fgraph.destroy_handler`, and initializes the bookkeeping structures
        (destroyers set, view maps, client counts) that droot/impact are
        lazily rebuilt from.
        """
        ####### Do the checking ###########
        already_there = False
        if self.fgraph is fgraph:
            already_there = True
        if self.fgraph is not None:
            raise Exception("A DestroyHandler instance can only serve one FunctionGraph. (Matthew 6:24)")
        for attr in ('destroyers', 'destroy_handler'):
            if hasattr(fgraph, attr):
                already_there = True

        if already_there:
            # FunctionGraph.attach_feature catches AlreadyThere and cancels the attachment
            raise toolbox.AlreadyThere("DestroyHandler feature is already present or in conflict with another plugin.")
        ####### end of checking ############

        def get_destroyers_of(r):
            """Return a one-element list with the Apply that destroys the
            foundation of `r`, or [] if `r` is not destroyed."""
            droot, impact, root_destroyer = self.refresh_droot_impact()
            try:
                return [root_destroyer[droot[r]]]
            except Exception:
                # NOTE(review): the broad `except Exception` (rather than
                # KeyError) also swallows InconsistencyError raised while
                # rebuilding droot — preserved as-is to keep behavior.
                return []

        fgraph.destroyers = get_destroyers_of
        fgraph.destroy_handler = self

        self.fgraph = fgraph
        self.destroyers = set()  # set of Apply instances with non-null destroy_map
        self.view_i = {}  # variable -> variable used in calculation
        self.view_o = {}  # variable -> set of variables that use this one as a direct input
        # clients: how many times does an apply use a given variable
        self.clients = {}  # variable -> apply -> ninputs
        self.stale_droot = True

        # IG: It's unclear if this is meant to be included in deployed code.
        # It looks like it is unnecessary if FunctionGraph is working
        # correctly, so I am commenting uses of it (for speed) but leaving the
        # commented code in place so it is easy to restore for debugging
        # purposes.
        # self.debug_all_apps = set()

        if self.do_imports_on_attach:
            toolbox.Bookkeeper.on_attach(self, fgraph)

    def refresh_droot_impact(self):
        """Return (droot, impact, root_destroyer), rebuilding them from the
        current view/destroyer bookkeeping only when marked stale."""
        if self.stale_droot:
            self.droot, self.impact, self.root_destroyer = self._build_droot_impact()
            self.stale_droot = False
        return self.droot, self.impact, self.root_destroyer

    def _build_droot_impact(self):
        """Recompute the full droot/impact/root_destroyer maps from scratch.

        Raises InconsistencyError if two Apply nodes destroy variables
        sharing the same foundation.
        """
        droot = {}  # destroyed view + nonview variables -> foundation
        impact = {}  # destroyed nonview variable -> it + all views of it
        root_destroyer = {}  # root -> destroyer apply

        for app in self.destroyers:
            for output_idx, input_idx_list in app.op.destroy_map.items():
                if len(input_idx_list) != 1:
                    # multiple inputs destroyed for one output is unsupported
                    raise NotImplementedError()
                input_idx = input_idx_list[0]
                input = app.inputs[input_idx]
                # walk the view chain back to the underlying storage owner
                input_root = getroot(input, self.view_i)
                if input_root in droot:
                    raise InconsistencyError("Multiple destroyers of %s" % input_root)
                droot[input_root] = input_root
                root_destroyer[input_root] = app
                # everything viewing the root shares its fate
                input_impact = get_impact(input_root, self.view_o)
                for v in input_impact:
                    assert v not in droot
                    droot[v] = input_root
                impact[input_root] = input_impact
                impact[input_root].add(input_root)
        return droot, impact, root_destroyer

    def on_detach(self, fgraph):
        """Undo on_attach: drop bookkeeping and remove the fgraph hooks."""
        if fgraph is not self.fgraph:
            raise Exception("detaching wrong fgraph", fgraph)
        del self.destroyers
        del self.view_i
        del self.view_o
        del self.clients
        del self.stale_droot
        # BUGFIX: on_attach installs the attribute as 'destroy_handler'; the
        # old spelling 'destroyer_handler' made this assert raise
        # AttributeError on every detach.
        assert self.fgraph.destroy_handler is self
        delattr(self.fgraph, 'destroyers')
        delattr(self.fgraph, 'destroy_handler')
        self.fgraph = None

    def on_import(self, fgraph, app):
        """Add Apply instance to set which must be computed."""
        #if app in self.debug_all_apps: raise ProtocolError("double import")
        #self.debug_all_apps.add(app)

        # If it's a destructive op, add it to our watch list
        if getattr(app.op, 'destroy_map', {}):
            self.destroyers.add(app)

        # add this symbol to the forward and backward view maps
        for o_idx, i_idx_list in getattr(app.op, 'view_map', {}).items():
            if len(i_idx_list) > 1:
                raise NotImplementedError('destroying this output invalidates multiple inputs', (app.op))
            o = app.outputs[o_idx]
            i = app.inputs[i_idx_list[0]]
            self.view_i[o] = i
            self.view_o.setdefault(i, set()).add(o)

        # update self.clients: count how many input slots of app each
        # variable occupies
        for input in app.inputs:
            self.clients.setdefault(input, {}).setdefault(app, 0)
            self.clients[input][app] += 1

        # make sure every output has a (possibly empty) client dict
        for output in app.outputs:
            self.clients.setdefault(output, {})
        self.stale_droot = True

    def on_prune(self, fgraph, app):
        """Remove Apply instance from set which must be computed."""
        #if app not in self.debug_all_apps: raise ProtocolError("prune without import")
        #self.debug_all_apps.remove(app)

        # UPDATE self.clients (each distinct input once)
        for input in set(app.inputs):
            del self.clients[input][app]

        if getattr(app.op, 'destroy_map', {}):
            self.destroyers.remove(app)

        # Note: leaving empty client dictionaries in the struct.
        # Why? It's a pain to remove them. I think they aren't doing any
        # harm, they will be deleted on_detach().

        # UPDATE self.view_i, self.view_o
        for o_idx, i_idx_list in getattr(app.op, 'view_map', {}).items():
            if len(i_idx_list) > 1:
                # destroying this output invalidates multiple inputs
                raise NotImplementedError()
            o = app.outputs[o_idx]
            i = app.inputs[i_idx_list[0]]
            del self.view_i[o]
            self.view_o[i].remove(o)
            if not self.view_o[i]:
                del self.view_o[i]
        self.stale_droot = True

    def on_change_input(self, fgraph, app, i, old_r, new_r):
        """app.inputs[i] changed from old_r to new_r."""
        if app == 'output':
            # app == 'output' is a special key that means FunctionGraph is
            # redefining which nodes are being considered 'outputs' of the
            # graph.
            pass
        else:
            #if app not in self.debug_all_apps: raise ProtocolError("change without import")

            # UPDATE self.clients: move one count from old_r to new_r
            self.clients[old_r][app] -= 1
            if self.clients[old_r][app] == 0:
                del self.clients[old_r][app]

            self.clients.setdefault(new_r, {}).setdefault(app, 0)
            self.clients[new_r][app] += 1

            # UPDATE self.view_i, self.view_o for any view whose source was
            # the changed input slot
            for o_idx, i_idx_list in getattr(app.op, 'view_map', {}).items():
                if len(i_idx_list) > 1:
                    # destroying this output invalidates multiple inputs
                    raise NotImplementedError()
                i_idx = i_idx_list[0]
                output = app.outputs[o_idx]
                if i_idx == i:
                    if app.inputs[i_idx] is not new_r:
                        raise ProtocolError("wrong new_r on change")
                    self.view_i[output] = new_r
                    self.view_o[old_r].remove(output)
                    if not self.view_o[old_r]:
                        del self.view_o[old_r]
                    self.view_o.setdefault(new_r, set()).add(output)
        self.stale_droot = True

    def validate(self, fgraph):
        """Return True if the graph is consistent.

        Raise InconsistencyError when
        a) orderings() raises an error
        b) orderings cannot be topologically sorted.
        """
        if self.destroyers:
            ords = self.orderings(fgraph)
            if _contains_cycle(fgraph, ords):
                raise InconsistencyError("Dependency graph contains cycles")
        else:
            # James's Conjecture:
            # If there are no destructive ops, then there can be no cycles.
            pass
        return True

    def orderings(self, fgraph):
        """Return orderings induced by destructive operations.

        Raise InconsistencyError when
        a) attempting to destroy an indestructible variable, or
        b) attempting to destroy a value multiple times, or
        c) an Apply destroys (illegally) one of its own inputs by aliasing
        """
        rval = {}

        if self.destroyers:
            # BUILD DATA STRUCTURES
            # CHECK for multiple destructions during construction of variables
            droot, impact, __ignore = self.refresh_droot_impact()

            # check for destruction of constants / protected variables
            illegal_destroy = [r for r in droot if
                               getattr(r.tag, 'indestructible', False) or
                               isinstance(r, graph.Constant)]
            if illegal_destroy:
                raise InconsistencyError("Attempting to destroy indestructible variables: %s" %
                                         illegal_destroy)

            # add destroyed variable clients as computational dependencies
            for app in self.destroyers:
                # for each destroyed input...
                for output_idx, input_idx_list in app.op.destroy_map.items():
                    destroyed_idx = input_idx_list[0]
                    destroyed_variable = app.inputs[destroyed_idx]
                    root = droot[destroyed_variable]
                    root_impact = impact[root]
                    # we generally want to put all clients of things which
                    # depend on root as pre-requisites of app.
                    # But, app is itself one such client!
                    # App will always be a client of the node we're destroying
                    # (destroyed_variable); the tricky thing is when it is
                    # also a client of *another variable* viewing on the root.
                    # Generally this is illegal, (e.g., add_inplace(x, x.T)).
                    # In some special cases though, the in-place op will
                    # actually be able to work properly with multiple
                    # destroyed inputs (e.g, add_inplace(x, x)). An Op that
                    # can still work in this case should declare so via the
                    # 'destroyhandler_tolerate_same' attribute or
                    # 'destroyhandler_tolerate_aliased' attribute.
                    #
                    # destroyhandler_tolerate_same should be a list of pairs
                    # of the form [(idx0, idx1), (idx0, idx2), ...]
                    # The first element of each pair is the input index of a
                    # destroyed variable.
                    # The second element of each pair is the index of a
                    # different input where we will permit exactly the same
                    # variable to appear.
                    # For example, add_inplace.tolerate_same might be [(0,1)]
                    # if the destroyed input is also allowed to appear as the
                    # second argument.
                    #
                    # destroyhandler_tolerate_aliased is the same sort of list
                    # of pairs.
                    # op.destroyhandler_tolerate_aliased = [(idx0, idx1)]
                    # tells the destroyhandler to IGNORE an aliasing between a
                    # destroyed input idx0 and another input idx1.
                    # This is generally a bad idea, but it is safe in some
                    # cases, such as
                    # - the op reads from the aliased idx1 before modifying
                    #   idx0
                    # - the idx0 and idx1 are guaranteed not to overlap (e.g.
                    #   they are pointed at different rows of a matrix).

                    # CHECK FOR INPUT ALIASING
                    # OPT: pre-compute this on import
                    tolerate_same = getattr(app.op, 'destroyhandler_tolerate_same', [])
                    tolerated = set(idx1 for idx0, idx1 in tolerate_same
                                    if idx0 == destroyed_idx)
                    tolerated.add(destroyed_idx)
                    tolerate_aliased = getattr(app.op, 'destroyhandler_tolerate_aliased', [])
                    ignored = set(idx1 for idx0, idx1 in tolerate_aliased
                                  if idx0 == destroyed_idx)
                    for i, input in enumerate(app.inputs):
                        if i in ignored:
                            continue
                        if input in root_impact \
                                and (i not in tolerated or input is not destroyed_variable):
                            raise InconsistencyError("Input aliasing: %s (%i, %i)"
                                                     % (app, destroyed_idx, i))

                    # add the rule: app must be preceded by all other Apply
                    # instances that depend on destroyed_input
                    root_clients = set()
                    for r in root_impact:
                        assert not [a for a, c in self.clients[r].items() if not c]
                        root_clients.update([a for a, c in self.clients[r].items() if c])
                    root_clients.remove(app)
                    if root_clients:
                        rval[app] = root_clients

        return rval
class IncrementalDestroyHandler(toolbox.Bookkeeper):
""" """
The DestroyHandler class detects when a graph is impossible to evaluate because of The DestroyHandler class detects when a graph is impossible to evaluate because of
aliasing and destructive operations. aliasing and destructive operations.
...@@ -578,3 +939,4 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -578,3 +939,4 @@ class DestroyHandler(toolbox.Bookkeeper):
return rval return rval
DestroyHandler = IncrementalDestroyHandler
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register or sign in to post a comment