moved compute and DestroyHandler to lib, compute() uses history set now

上级 8f0c2cfe
......@@ -11,7 +11,7 @@ import graph
#TODO: move mark_outputs_as_destroyed to the place that uses this function
#TODO: move Return to where it is used.
__all__ = ['DestroyHandler', 'IONames', 'mark_outputs_as_destroyed']
__all__ = ['IONames', 'mark_outputs_as_destroyed']
class IONames:
......@@ -78,253 +78,6 @@ class IONames:
class DestroyHandler(Listener, Constraint, Orderings):
def __init__(self, env):
self.parent = {}
self.children = {}
self.destroyers = {}
self.paths = {}
self.dups = set()
self.cycles = set()
self.env = env
for input in env.inputs:
# self.parent[input] = None
self.children[input] = set()
def __path__(self, r):
path = self.paths.get(r, None)
if path:
return path
rval = [r]
r = self.parent.get(r, None) ### ???
while r:
rval.append(r)
r = self.parent.get(r, None)
rval.reverse()
for i, x in enumerate(rval):
self.paths[x] = rval[0:i+1]
return rval
def __views__(self, r):
children = self.children[r]
if not children:
return set([r])
else:
rval = set([r])
for child in children:
rval.update(self.__views__(child))
return rval
def __users__(self, r):
views = self.__views__(r)
rval = set()
for view in views:
for op, i in self.env.clients(view):
rval.update(op.outputs)
return rval
def __pre__(self, op):
rval = set()
if op is None:
return rval
keep_going = False
for input in op.inputs:
foundation = self.__path__(input)[0]
destroyers = self.destroyers.get(foundation, set())
if destroyers:
keep_going = True
if op in destroyers:
users = self.__users__(foundation)
rval.update(users)
# if not keep_going:
# return set()
rval.update(op.inputs)
rval.difference_update(op.outputs)
return rval
def __detect_cycles_helper__(self, r, seq):
# print "!! ", r, seq
if r in seq:
self.cycles.add(tuple(seq[seq.index(r):]))
return
pre = self.__pre__(r.owner)
for r2 in pre:
self.__detect_cycles_helper__(r2, seq + [r])
def __detect_cycles__(self, start, just_remove=False):
# print "!!! ", start
users = self.__users__(start)
users.add(start)
for user in users:
for cycle in copy(self.cycles):
if user in cycle:
self.cycles.remove(cycle)
if just_remove:
return
for user in users:
self.__detect_cycles_helper__(user, [])
def get_maps(self, op):
vmap = getattr(op, 'view_map',{})
dmap = getattr(op, 'destoy_map', {})
return vmap, dmap
def on_import(self, op):
view_map, destroy_map = self.get_maps(op)
# for input in op.inputs:
# self.parent.setdefault(input, None)
for i, output in enumerate(op.outputs):
views = view_map.get(output, None)
destroyed = destroy_map.get(output, None)
if destroyed:
# self.parent[output] = None
if isinstance(destroyed, Result):
destroyed = [destroyed]
for input in destroyed:
path = self.__path__(input)
self.__add_destroyer__(path + [output])
elif views:
if isinstance(views, Result):
views = [views]
if len(views) > 1: #views was inputs before?
raise Exception("Output is a view of too many inputs.")
self.parent[output] = views[0]
for input in views:
self.children[input].add(output)
# else:
# self.parent[output] = None
self.children[output] = set()
for output in op.outputs:
self.__detect_cycles__(output)
# if destroy_map:
# print "op: ", op
# print "ord: ", [str(x) for x in self.orderings()[op]]
# print
def on_prune(self, op):
view_map, destroy_map = self.get_maps(op)
if destroy_map:
destroyers = []
for i, input in enumerate(op.inputs):
destroyers.append(self.destroyers.get(self.__path__(input)[0], {}))
for destroyer in destroyers:
path = destroyer.get(op, [])
if path:
self.__remove_destroyer__(path)
if view_map:
for i, input in enumerate(op.inputs):
self.children[input].difference_update(op.outputs)
for output in op.outputs:
try:
del self.paths[output]
except:
pass
self.__detect_cycles__(output, True)
for i, output in enumerate(op.outputs):
try:
del self.parent[output]
except:
pass
del self.children[output]
def __add_destroyer__(self, path):
foundation = path[0]
target = path[-1]
op = target.owner
destroyers = self.destroyers.setdefault(foundation, {})
path = destroyers.setdefault(op, path)
if len(destroyers) > 1:
self.dups.add(foundation)
def __remove_destroyer__(self, path):
foundation = path[0]
target = path[-1]
op = target.owner
destroyers = self.destroyers[foundation]
del destroyers[op]
if not destroyers:
del self.destroyers[foundation]
elif len(destroyers) == 1 and foundation in self.dups:
self.dups.remove(foundation)
def on_rewire(self, clients, r_1, r_2):
path_1 = self.__path__(r_1)
path_2 = self.__path__(r_2)
prev = set()
for op, i in clients:
prev.update(op.outputs)
foundation = path_1[0]
destroyers = self.destroyers.get(foundation, {}).items()
for op, path in destroyers:
if r_1 in path:
idx = path.index(r_1)
self.__remove_destroyer__(path)
if not (idx > 0 and path[idx - 1] in prev):
continue
index = path.index(r_1)
new_path = path_2 + path[index+1:]
self.__add_destroyer__(new_path)
for op, i in clients:
view_map, _ = self.get_maps(op)
for output, inputs in view_map.items():
if r_2 in inputs:
assert self.parent.get(output, None) == r_1
self.parent[output] = r_2
self.children[r_1].remove(output)
self.children[r_2].add(output)
for view in self.__views__(r_1):
try:
del self.paths[view]
except:
pass
for view in self.__views__(r_2):
try:
del self.paths[view]
except:
pass
self.__detect_cycles__(r_1)
self.__detect_cycles__(r_2)
def validate(self):
if self.dups:
raise InconsistencyError("The following values are destroyed more than once: %s" % self.dups)
elif self.cycles:
raise InconsistencyError("There are cycles: %s" % self.cycles)
else:
return True
def orderings(self):
ords = {}
for foundation, destroyers in self.destroyers.items():
for op in destroyers.keys():
ords.setdefault(op, set()).update([user.owner for user in self.__users__(foundation) if user not in op.outputs])
return ords
class Return(DummyOp):
"""
......
......@@ -13,17 +13,10 @@ __all__ = ['Feature',
'Constraint',
'Orderings',
'Tool',
# 'Preprocessor',
'EquivTool',
'InstanceFinder',
'PrintListener',
'ChangeListener',
# 'DestroyPreprocessor',
# 'DestroyHandler'
]
......
......@@ -33,6 +33,31 @@ def make_static(cls, fname):
f = f.im_func
setattr(cls, fname, staticmethod(f))
def compute_from(nodes, history):
"""Recursively evaluate each node (in a quick & dirty way).
history (aka inputs) is a set of nodes that need not be [re]computed.
TODO: make this more correct by building a little graph and executing it.
The current implementation doesn't take into account any ordering
constraints imposed by destructors, for example.
"""
def compute_recursive(node):
if node and (node not in history):
if hasattr(node, 'owner'): #node is storage
compute_recursive(node.owner)
else: #node is op
for input in node.inputs:
compute_recursive(input)
node.perform()
history.add(node)
for n in nodes:
compute_recursive(n)
def compute(*nodes):
"""Recursively evaluate each node (in a quick & dirty way)."""
compute_from(nodes, set())
class ForbidConstantOverwrite(features.Listener, features.Constraint):
......@@ -158,11 +183,260 @@ class ResultValue(Result):
def alloc(self): raise AbstractFunctionError()
class DestroyHandler(features.Listener, features.Constraint, features.Orderings):
def __init__(self, env):
self.parent = {}
self.children = {}
self.destroyers = {}
self.paths = {}
self.dups = set()
self.cycles = set()
self.env = env
for input in env.inputs:
# self.parent[input] = None
self.children[input] = set()
def __path__(self, r):
path = self.paths.get(r, None)
if path:
return path
rval = [r]
r = self.parent.get(r, None) ### ???
while r:
rval.append(r)
r = self.parent.get(r, None)
rval.reverse()
for i, x in enumerate(rval):
self.paths[x] = rval[0:i+1]
return rval
def __views__(self, r):
children = self.children[r]
if not children:
return set([r])
else:
rval = set([r])
for child in children:
rval.update(self.__views__(child))
return rval
def __users__(self, r):
views = self.__views__(r)
rval = set()
for view in views:
for op, i in self.env.clients(view):
rval.update(op.outputs)
return rval
def __pre__(self, op):
rval = set()
if op is None:
return rval
keep_going = False
for input in op.inputs:
foundation = self.__path__(input)[0]
destroyers = self.destroyers.get(foundation, set())
if destroyers:
keep_going = True
if op in destroyers:
users = self.__users__(foundation)
rval.update(users)
# if not keep_going:
# return set()
rval.update(op.inputs)
rval.difference_update(op.outputs)
return rval
def __detect_cycles_helper__(self, r, seq):
# print "!! ", r, seq
if r in seq:
self.cycles.add(tuple(seq[seq.index(r):]))
return
pre = self.__pre__(r.owner)
for r2 in pre:
self.__detect_cycles_helper__(r2, seq + [r])
def __detect_cycles__(self, start, just_remove=False):
# print "!!! ", start
users = self.__users__(start)
users.add(start)
for user in users:
for cycle in copy(self.cycles):
if user in cycle:
self.cycles.remove(cycle)
if just_remove:
return
for user in users:
self.__detect_cycles_helper__(user, [])
def get_maps(self, op):
vmap = getattr(op, 'view_map',{})
dmap = getattr(op, 'destoy_map', {})
return vmap, dmap
def on_import(self, op):
view_map, destroy_map = self.get_maps(op)
# for input in op.inputs:
# self.parent.setdefault(input, None)
for i, output in enumerate(op.outputs):
views = view_map.get(output, None)
destroyed = destroy_map.get(output, None)
if destroyed:
# self.parent[output] = None
if isinstance(destroyed, Result):
destroyed = [destroyed]
for input in destroyed:
path = self.__path__(input)
self.__add_destroyer__(path + [output])
elif views:
if isinstance(views, Result):
views = [views]
if len(views) > 1: #views was inputs before?
raise Exception("Output is a view of too many inputs.")
self.parent[output] = views[0]
for input in views:
self.children[input].add(output)
# else:
# self.parent[output] = None
self.children[output] = set()
for output in op.outputs:
self.__detect_cycles__(output)
# if destroy_map:
# print "op: ", op
# print "ord: ", [str(x) for x in self.orderings()[op]]
# print
def on_prune(self, op):
view_map, destroy_map = self.get_maps(op)
if destroy_map:
destroyers = []
for i, input in enumerate(op.inputs):
destroyers.append(self.destroyers.get(self.__path__(input)[0], {}))
for destroyer in destroyers:
path = destroyer.get(op, [])
if path:
self.__remove_destroyer__(path)
if view_map:
for i, input in enumerate(op.inputs):
self.children[input].difference_update(op.outputs)
for output in op.outputs:
try:
del self.paths[output]
except:
pass
self.__detect_cycles__(output, True)
for i, output in enumerate(op.outputs):
try:
del self.parent[output]
except:
pass
del self.children[output]
def __add_destroyer__(self, path):
foundation = path[0]
target = path[-1]
op = target.owner
destroyers = self.destroyers.setdefault(foundation, {})
path = destroyers.setdefault(op, path)
if len(destroyers) > 1:
self.dups.add(foundation)
def __remove_destroyer__(self, path):
foundation = path[0]
target = path[-1]
op = target.owner
destroyers = self.destroyers[foundation]
del destroyers[op]
if not destroyers:
del self.destroyers[foundation]
elif len(destroyers) == 1 and foundation in self.dups:
self.dups.remove(foundation)
def on_rewire(self, clients, r_1, r_2):
path_1 = self.__path__(r_1)
path_2 = self.__path__(r_2)
prev = set()
for op, i in clients:
prev.update(op.outputs)
foundation = path_1[0]
destroyers = self.destroyers.get(foundation, {}).items()
for op, path in destroyers:
if r_1 in path:
idx = path.index(r_1)
self.__remove_destroyer__(path)
if not (idx > 0 and path[idx - 1] in prev):
continue
index = path.index(r_1)
new_path = path_2 + path[index+1:]
self.__add_destroyer__(new_path)
for op, i in clients:
view_map, _ = self.get_maps(op)
for output, inputs in view_map.items():
if r_2 in inputs:
assert self.parent.get(output, None) == r_1
self.parent[output] = r_2
self.children[r_1].remove(output)
self.children[r_2].add(output)
for view in self.__views__(r_1):
try:
del self.paths[view]
except:
pass
for view in self.__views__(r_2):
try:
del self.paths[view]
except:
pass
self.__detect_cycles__(r_1)
self.__detect_cycles__(r_2)
def validate(self):
if self.dups:
raise InconsistencyError("The following values are destroyed more than once: %s" % self.dups)
elif self.cycles:
raise InconsistencyError("There are cycles: %s" % self.cycles)
else:
return True
def orderings(self):
ords = {}
for foundation, destroyers in self.destroyers.items():
for op in destroyers.keys():
ords.setdefault(op, set()).update([user.owner for user in self.__users__(foundation) if user not in op.outputs])
return ords
class PythonOp(Op):
__metaclass__ = ClsInit
__require__ = DestroyHandler
nout = 1
@staticmethod
......
import gof
from gof.lib import compute_from
import core
class Grad(object):
......@@ -19,9 +20,10 @@ class Grad(object):
def __init__(self, dct={}):
self.map = {}
self.outputs = []
self._compute_history = set([])
self.did_bprop = False
for key,val in dct.items():
self.add_output(key,val)
self.did_bprop = False
def __contains__(self, item):
return item in self.map
......@@ -60,8 +62,8 @@ class Grad(object):
pass
else:
if r.data is core.UNCOMPUTED or dr.data is core.UNCOMPUTED:
pass
else: # try some hacky checks to catch obvious mistakes
pass # no sanity checking
else: # some sanity checking to catch obvious mistakes
if not hasattr(r.data, 'shape'):
raise ValueError(('Grad::add r lacks shape: type=',
type(r.data)))
......@@ -72,6 +74,10 @@ class Grad(object):
raise ValueError(('Grad::add r, dr shape mismatch',
r.data.shape, dr.data.shape))
# prevent 'r' from being re-calculated by self.__call__ in 'build_eval' mode
if r.data is not core.UNCOMPUTED:
self._compute_history.add(r)
# add dr to self[r]
if r in self:
self[r] = self[r] + dr
......@@ -120,7 +126,7 @@ class Grad(object):
rval = self[item]
if rval is not core.UNDEFINED \
and core.current_mode() == 'build_eval':
rval.compute()
compute_from([rval], self._compute_history)
return rval
def grad(cost, param=None, cost_grad = 1.0):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论