Commit 28aedba1 authored by James Bergstra

Several related fixes:

- added allow_gc option to PerformLinker and OpWiseCLinker, which results in temporary results being freed as soon as possible. This can have a negative impact on performance (we have to re-allocate outputs) but allows us to evaluate graphs piecewise, even if not all temporaries fit in memory at once.
- modified Function so that internal references to `required` input arguments and computed output arguments are removed after a call. This means that these values are not stored by pickle.
- modified MergeOptimizer to use a more efficient algorithm. The previous one was at best quadratic in depth, and at worst exponential in depth. The new one should be worst-case linear.
Parent 169a80b1
......@@ -265,7 +265,22 @@ class Function(object):
raise TypeError("Multiple values for input: %s" % getattr(self.inv_finder[c], 'result', self.inv_finder[c]))
# Do the actual work
self.fn()
# Retrieve the values that were computed
outputs = [x.data for x in self.output_storage]
#remove internal references to required inputs
#these can't be re-used anyway
for x in self.input_storage:
if c.required:
c.storage[0] = None
# if we are allowing garbage collection, remove the input and output reference from the internal
# storage cells
if getattr(self.fn, 'allow_gc', False):
for x in self.output_storage:
x.storage[0] = None #WARNING: This circumvents the 'readonly' attribute in x
# Update the inputs that have an update function
for input, storage in reversed(zip(self.maker.expanded_inputs, self.input_storage)):
if input.update:
......
......@@ -773,10 +773,12 @@ class OpWiseCLinker(link.LocalLinker):
def __init__(self, fallback_on_perform = True, allow_gc = True, nice_errors = True):
    """Record this linker's configuration flags.

    :param fallback_on_perform: presumably falls back to the op's python
        `perform` implementation when no C code is available — TODO confirm
        against accept()/make_thunk.
    :param allow_gc: if True, intermediate storage is freed as soon as its
        last consumer has run.
    :param nice_errors: if True, thunk execution is wrapped so errors are
        re-raised together with the offending node.
    """
    # No env is associated yet; accept() binds one later.
    self.env = None
    self.fallback_on_perform = fallback_on_perform
    self.allow_gc = allow_gc
    self.nice_errors = nice_errors
def accept(self, env, no_recycling = []):
if self.env is not None and self.env is not env:
......@@ -792,6 +794,11 @@ class OpWiseCLinker(link.LocalLinker):
no_recycling = self.no_recycling
input_storage, output_storage, storage_map = link.map_storage(env, order, input_storage, output_storage)
if self.allow_gc:
computed, last_user = link.gc_helper(order)
post_thunk_old_storage = []
else:
post_thunk_old_storage = None
thunks = []
for node in order:
......@@ -840,6 +847,11 @@ class OpWiseCLinker(link.LocalLinker):
else:
raise
if self.allow_gc:
post_thunk_old_storage.append([storage_map[input]
for input in node.inputs
if (input in computed) and (input not in env.outputs) and node == last_user[input]])
if no_recycling is True:
no_recycling = storage_map.values()
no_recycling = utils.difference(no_recycling, input_storage)
......@@ -847,9 +859,12 @@ class OpWiseCLinker(link.LocalLinker):
no_recycling = [storage_map[r] for r in no_recycling if r not in env.inputs]
f = link.streamline(env, thunks, order,
post_thunk_old_storage,
no_recycling = no_recycling,
nice_errors = self.nice_errors)
f.allow_gc = self.allow_gc
return f, [link.Container(input, storage) for input, storage in zip(env.inputs, input_storage)], \
[link.Container(output, storage, True) for output, storage in zip(env.outputs, output_storage)], \
thunks, order
......
......@@ -111,18 +111,20 @@ class Linker(object):
return execute
#TODO: Move this class to the compile module, where it is used (and for which it exists).
class Container(object):
"""WRITEME[Fred: fill the WRITEME I need it! Event if is partial.]
"""This class joins a result with its computed value.
It is used in linkers, especially for the inputs and outputs of a Function.
"""
def __init__(self, r, storage, readonly = False, strict = False, name = None):
"""WRITEME
:Parameters:
`r`: a result
`storage`:
`readonly`:
`storage`: a list of length 1, whose element is the value for `r`
`readonly`: True indicates that this should not be setable by Function[r] = val
`strict`: if True, we don't allow type casting.
`name`:
`name`: A string (for pretty-printing?)
"""
if not isinstance(storage, list) or not len(storage) >= 1:
......@@ -226,7 +228,7 @@ def clear_storage_thunk(stg):
thunk.inputs = [stg]
return thunk
def streamline(env, thunks, order, no_recycling = [], profiler = None, nice_errors = True):
def streamline(env, thunks, order, post_thunk_old_storage = None, no_recycling = [], profiler = None, nice_errors = True):
"""WRITEME
:param env:
......@@ -235,6 +237,10 @@ def streamline(env, thunks, order, no_recycling = [], profiler = None, nice_erro
:param order: the list of apply instances that gave rise to the thunks (same order as thunks)
:param post_thunk_old_storage: a list (corresponding to thunks, order) whose elements are
lists of storage cells, that should be cleared after running the corresponding thunk. A
value of None disables this functionality
:param no_recycling: storage elements that cannot be 'recycled' by repeatedly executing the
program. These storage elements are cleared before re-running.
......@@ -246,8 +252,28 @@ def streamline(env, thunks, order, no_recycling = [], profiler = None, nice_erro
if profiler is not None:
raise NotImplementedError()
if nice_errors:
def f():
if len(thunks) != len(order):
raise ValueError('Length of thunks and order must match',
(len(thunks), len(order)))
if post_thunk_old_storage:
if len(thunks) != len(post_thunk_old_storage):
raise ValueError('Length of thunks and post_thunk_old_storage must match',
(len(thunks), len(post_thunk_old_storage)))
def streamline_default_f():
for x in no_recycling:
x[0] = None
try:
for thunk, node, old_storage in zip(thunks, order, post_thunk_old_storage):
thunk()
for old_s in old_storage:
old_s[0] = None
except:
raise_with_op(node)
f = streamline_default_f
elif nice_errors:
def streamline_nice_errors_f():
for x in no_recycling:
x[0] = None
try:
......@@ -255,14 +281,16 @@ def streamline(env, thunks, order, no_recycling = [], profiler = None, nice_erro
thunk()
except:
raise_with_op(node)
f = streamline_nice_errors_f
else:
# don't worry about raise_with_op, just go a little faster.
#there is a mix of python and c thunks
def f():
def streamline_fast_f():
for x in no_recycling:
x[0] = None
for thunk in thunks:
thunk()
f = streamline_fast_f
return f
class LocalLinker(Linker):
......@@ -287,7 +315,26 @@ class LocalLinker(Linker):
# 5. order: list of nodes, in the order they will be run by the function in (1)
raise AbstractFunctionError
def gc_helper(node_list):
    """Scan a linearized program and collect garbage-collection bookkeeping.

    :param node_list: list of Apply instances in program execution order

    :rtype: a 2-tuple

    :returns: FIRST, the set of Result instances that are computed by
        node_list, and SECOND, a dictionary that maps each Result instance
        used as an input to the last Apply node in `node_list` that uses it
        as an input.

    This is used to allow garbage collection within graphs: once the last
    user of a result has run, that result's storage may be freed.
    """
    last_user = {}
    computed = set()
    for node in node_list:
        # Later nodes overwrite earlier entries, so after the loop each
        # input maps to its final consumer in execution order.
        for inp in node.inputs:
            last_user[inp] = node
        for out in node.outputs:
            computed.add(out)
    return computed, last_user
class PerformLinker(LocalLinker):
"""WRITEME
......@@ -295,7 +342,7 @@ class PerformLinker(LocalLinker):
the L{Env} in the order given by L{Env.toposort}.
"""
def __init__(self, allow_gc=True):
    """
    :param allow_gc: if True, each intermediate result's storage is
        cleared as soon as its last consumer has run, so graphs whose
        temporaries do not all fit in memory at once can still execute.
    """
    # No env is associated yet; accept() binds one later.
    self.env = None
    self.allow_gc = allow_gc
......@@ -325,24 +372,19 @@ class PerformLinker(LocalLinker):
"""
env = self.env
order = env.toposort()
order = list(env.toposort())
no_recycling = self.no_recycling
thunks = []
new_order = []
input_storage, output_storage, storage_map = map_storage(env, order, input_storage, output_storage)
#for freeing memory
if self.allow_gc:
last_user = {}
computed = set()
for node in order:
for idx, input in enumerate(node.inputs):
last_user[input] = (node, idx)
for output in node.outputs:
computed.add(output)
computed, last_user = gc_helper(order)
post_thunk_old_storage = []
else:
post_thunk_old_storage = None
for node in order:
node_input_storage = tuple(storage_map[input] for input in node.inputs)
node_output_storage = tuple(storage_map[output] for output in node.outputs)
......@@ -355,37 +397,40 @@ class PerformLinker(LocalLinker):
thunk.outputs = node_output_storage
thunk.perform = p
thunks.append(thunk)
new_order.append(node)
if self.allow_gc:
for idx, input in enumerate(node.inputs):
if input not in computed:
continue
if input in env.outputs:
continue
if (node, idx) == last_user[input]:
#print '... zeroing', id(storage_map[input])
thunks.append(clear_storage_thunk(storage_map[input]))
new_order.append(node)
post_thunk_old_storage.append([storage_map[input]
for input in node.inputs
if (input in computed) and (input not in env.outputs) and node == last_user[input]])
if 0: # -JB 20081202
if self.allow_gc:
for idx, input in enumerate(node.inputs):
if input not in computed:
continue
if input in env.outputs:
continue
if (node, idx) == last_user[input]:
#print '... zeroing', id(storage_map[input])
thunks.append(clear_storage_thunk(storage_map[input]))
new_order.append(node)
if no_recycling is True:
#True is like some special code for *everything*.
#FunctionMaker always passes a list I think -JB
# True seems like some special code for *everything*?? -JB
# FunctionMaker always passes a list I think -JB
no_recycling = storage_map.values()
no_recycling = utils.difference(no_recycling, input_storage)
else:
no_recycling = [storage_map[r] for r in no_recycling if r not in env.inputs]
# The function that actually runs your program is one of the f's in streamline.
f = streamline(env, thunks, new_order, no_recycling = no_recycling, profiler = profiler)
f = streamline(env, thunks, order, post_thunk_old_storage, no_recycling = no_recycling, profiler = profiler)
f.allow_gc = self.allow_gc #HACK: this is a way of passing an arg to Function.__call__
return f, [Container(input, storage) for input, storage in zip(env.inputs, input_storage)], \
[Container(output, storage, True) for output, storage in zip(env.outputs, output_storage)], \
thunks, new_order
thunks, order
class WrapLinker(Linker):
""" WRITEME
......
......@@ -18,7 +18,7 @@ import sys
_optimizer_idx = [0]
def _list_of_nodes(env):
return graph.io_toposort(env.inputs, env.outputs)
return list(graph.io_toposort(env.inputs, env.outputs))
class Optimizer(object):
"""WRITEME
......@@ -195,7 +195,7 @@ class MergeOptimizer(Optimizer):
const_sig[c] = sig
const_sig_inv[sig] = c
def apply_node_merge(self, env):
def exptime_apply_node_merge(self, env):
# we clear the dicts because the Constants signatures are not necessarily hashable
# and it's more efficient to give them an integer like the other Results
......@@ -209,6 +209,7 @@ class MergeOptimizer(Optimizer):
for node in _list_of_nodes(env):
node_cid = (node.op, tuple([symbol_idx[input] for input in node.inputs]))
print 'NODE', node, node_cid
dup = symbol_idx_inv.get(node_cid, None)
success = False
if dup is not None:
......@@ -228,6 +229,52 @@ class MergeOptimizer(Optimizer):
ref = (i, node_cid)
symbol_idx[output] = ref
symbol_idx_inv[ref] = output
def apply_node_merge(self, env):
    """Merge duplicate computations in `env` in a single pass.

    A node is merged into a previously-visited node when both apply an
    equal op to identical (same-object) inputs; the duplicate's outputs
    are swapped for the earlier node's via env.replace_all_validate, so
    invalid merges are rolled back.  Candidates are found through the
    clients list of each node's first input, which (per the commit notes)
    should make the pass roughly linear, replacing the older
    signature-based scheme that was worst-case exponential in depth.
    """
    # Only nodes already visited (in topological order) can be merge
    # targets, so each duplicate is merged into an EARLIER equivalent.
    nodes_seen = set()
    for node in _list_of_nodes(env):
        #
        # these asserts ensure that the env has set the clients field properly: the clients
        # should at least contain `node` itself!
        #
        assert len(node.inputs[0].clients) > 0
        assert (node,0) in node.inputs[0].clients
        # Any previously-seen node that also consumes our first input is a
        # potential duplicate; op and remaining-input equality are checked
        # below.
        merge_candidates = [c for (c,i) in node.inputs[0].clients if c in nodes_seen]
        nodes_seen.add(node)
        #print 'NODE', node, merge_candidates, node.inputs[0].clients
        for candidate in merge_candidates:
            # inputs must be the very same Result objects (identity, not
            # equality)
            inputs_match = all(node_in is cand_in for node_in, cand_in in zip(node.inputs, candidate.inputs))
            if inputs_match and node.op == candidate.op:
                assert node is not candidate
                #
                # transfer clients from node to candidate
                #
                success = True
                assert len(node.outputs) == len(candidate.outputs)
                pairs = zip(node.outputs, candidate.outputs)
                # transfer names: clobber the candidate's name with ours;
                # it's arbitrary... one of the names has to go
                for node_output, cand_output in pairs:
                    if node_output.name:
                        cand_output.name = node_output.name
                try:
                    env.replace_all_validate(pairs)
                except InconsistencyError, e:
                    # the merge broke a graph invariant; it was rolled
                    # back, so try the next candidate
                    success = False
                if success:
                    #break out of the candidate loop
                    break
                else:
                    #try the next candidate
                    pass
#TODO: Consider splitting this into a separate optimizer (SeqOptimizer)
def apply(self, env):
......
......@@ -48,10 +48,13 @@ class MyOp(Op):
return self.name
def __eq__(self, other):
    """Two MyOps are equal when identical, or when both carry the same
    non-None `x` payload."""
    # An op whose x is None only ever equals itself.
    rval = (self is other) or (isinstance(other, MyOp) and self.x is not None and self.x == other.x)
    return rval
def __hash__(self):
    # Must be consistent with __eq__: instances sharing the same non-None
    # x compare equal, so they must hash alike; otherwise fall back to
    # per-instance identity.
    return hash(self.x if self.x is not None else id(self))
op1 = MyOp('Op1')
......@@ -238,7 +241,8 @@ class TestPatternOptimizer:
g = Env([x, y, z], [e])
PatternOptimizer((op1, (op_z, '1', '2'), '3'),
(op4, '3', '2')).optimize(g)
assert str(g) == "[Op4(z, y)]"
str_g = str(g)
assert str_g == "[Op4(z, y)]"
# def test_multi_ingraph(self):
# # known to fail
......
......@@ -358,8 +358,8 @@ def res_is_a(node, op, maxclients=None):
and (len(node.clients) <= maxclients if maxclients is not None else True)
class GemmLocalOptimizer(LocalOptimizer):
"""This is a massive beast for recognizing all the ways that a subtraction could be
replaced by a GEMM
"""This is a massive beast for recognizing all the ways that a subtraction or addition
could be replaced by a GEMM
It depends on `local_transposed_dot` to canonicalize the graph a bit by swapping
dot(a,b).T -> dot(b.T, a.T)
......
......@@ -123,6 +123,9 @@ class DimShuffle(Op):
if self.inplace:
self.view_map = {0: [0]}
self._hashval = hash(type(self)) ^ hash(self.inplace) \
^ hash(self.new_order) ^ hash(self.input_broadcastable)
def make_node(self, input):
ib = tuple(input.type.broadcastable)
if not ib == self.input_broadcastable:
......@@ -146,8 +149,7 @@ class DimShuffle(Op):
and self.input_broadcastable == other.input_broadcastable
def __hash__(self):
    # The hash is precomputed once in __init__ (self._hashval) so repeated
    # hashing during graph optimization does not redo the work.
    return self._hashval
def __str__(self):
if self.inplace:
......@@ -327,6 +329,12 @@ class Elemwise(Op):
else:
self.ufunc = None
#precompute the hash of this node
items = self.inplace_pattern.items()
items.sort()
tuple_items = tuple([k for k,v in items] + [(tuple(v) if isinstance(v, (tuple, list)) else v) for k,v in items])
self._hashval = hash(self.scalar_op) ^ hash(tuple_items)
def __getstate__(self):
d = copy(self.__dict__)
d.pop('ufunc')
......@@ -399,9 +407,7 @@ class Elemwise(Op):
return False
def __hash__(self):
    # The hash is precomputed in __init__ (self._hashval) from scalar_op
    # and the sorted inplace_pattern, avoiding the sort on every call.
    return self._hashval
def __str__(self):
if self.name is None:
......
......@@ -832,66 +832,6 @@ def constant_folding(node):
register_canonicalize(constant_folding)
from blas import _dot22
@gof.local_optimizer([T.sub])
def local_sub_to_gemm(node):
    """Rewrite a subtraction whose right operand contains a matrix product
    into GEMM calls (beta * C + alpha * dot(A, B)).

    Matches several shapes of ``subleft - subright`` where ``subright`` is
    a ``_dot22`` product, optionally scaled by a scalar constant and
    optionally transposed.  Returns a one-element replacement list on
    success, or False when no pattern applies.

    NOTE(review): `GemmLocalOptimizer` appears to supersede this
    optimizer and also handles additions — TODO confirm.
    """
    if node.op == T.sub:
        subleft, subright = node.inputs
        #EXPRESSION: subleft - subright
        if subright.owner and (subright.owner.op == _dot22):
            # subleft - dot(a, b)  ->  gemm(subleft, -1, a, b, 1)
            dotleft, dotright = subright.owner.inputs
            return [T.gemm(subleft, -1.0, dotleft, dotright, 1.0)]
        if subright.owner and (subright.owner.op == T.mul):
            mulleft, mulright = subright.owner.inputs
            #EXPRESSION: subleft - (mulleft * mulright)
            #TODO: we actually want to get any scalar here, not necessarily a constant
            mulleft_const = local_mul_canonizer.get_constant(mulleft)
            if mulleft_const is not None and mulleft_const.size == 1:
                # the scalar multiplier is on the left of the mul
                mulleft_const = mulleft_const.flatten()[0]
                #EXPRESSION: subleft - (mulleft_const * ?)
                if mulright.owner and (mulright.owner.op == T.add):
                    #EXPRESSION: subleft - (mulleft_const * (? + ?))
                    addleft, addright = mulright.owner.inputs
                    if addright.owner and addright.owner.op == T.DimShuffle([False,False], [1,0]):
                        #EXPRESSION: subleft - (mulleft_const * (? + ?.T))
                        # the non-inplace transpose form is deliberately
                        # not handled here
                        #raise NotImplementedError()
                        return False
                    if addright.owner and addright.owner.op == T.DimShuffle([False,False], [1,0], inplace=True):
                        #EXPRESSION: subleft - (mulleft_const * (? + ?.T))
                        transposed = addright.owner.inputs[0]
                        if transposed.owner and transposed.owner.op == _dot22:
                            x, y = transposed.owner.inputs
                            #EXPRESSION: subleft - (mulleft_const * (addleft + dot(x, y).T))
                            if addleft.owner and addleft.owner.op == _dot22:
                                u, v = addleft.owner.inputs
                                #EXPRESSION: subleft - (mulleft_const * (dot(u,v) + dot(x, y).T))
                                # dot(x, y).T == dot(y.T, x.T), so both
                                # products fold into two nested gemm calls
                                return [T.gemm(
                                    T.gemm(subleft, -mulleft_const, y.T, x.T, 1.0),
                                    -mulleft_const, u, v, 1.0)]
                if mulright.owner and (mulright.owner.op == _dot22):
                    dotleft, dotright = mulright.owner.inputs
                    #EXPRESSION: subleft - (mulleft_const * dot(dotleft, dotright))
                    return [T.gemm(subleft, -mulleft_const, dotleft, dotright, 1.0)]
            mulright_const = local_mul_canonizer.get_constant(mulright)
            if mulright_const is not None and mulright_const.size == 1:
                # the scalar multiplier is on the right of the mul
                mulright_const = mulright_const.flatten()[0]
                #EXPRESSION: subleft - (? * mulright_const)
                if mulleft.owner and (mulleft.owner.op == _dot22):
                    dotleft, dotright = mulleft.owner.inputs
                    #EXPRESSION: subleft - (dot(dotleft, dotright) * mulright_const)
                    return [T.gemm(subleft, -mulright_const, dotleft, dotright, 1.0)]
    return False
register_specialize(local_sub_to_gemm)
inplace_matrix_transpose = T.DimShuffle([False,False], [1,0], inplace=True)
local_transposed_dot = gof.PatternSub((inplace_matrix_transpose, (T.dot, 'x', 'y')),
(T.dot, (inplace_matrix_transpose, 'y'), (inplace_matrix_transpose, 'x')))
......
Markdown formatting is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Finish editing this comment first!
Register or sign in to post a comment