提交 3b02a1f4 authored 作者: James Bergstra's avatar James Bergstra

many changes related to optimizations, deterministic ordering of optimizations, and benchmarking

上级 3f2298bd
aa.x : aa.cc aa.x : aa.cc
g++ -O3 -ffast-math aa.cc -o aa.x -L${PUB_PREFIX}/lib -lgsl -lcblas -lgoto -lgfortran -lm g++ -O3 -ffast-math -ftree-vectorize aa.cc -o aa.x -L${PUB_PREFIX}/lib -lgsl -lmkl
#g++ aa.cc -o aa.x -L${PUB_PREFIX}/lib -lgsl -lmkl
clean : clean :
rm aa.x rm aa.x
...@@ -28,6 +28,7 @@ int main(int argc, char **argv) ...@@ -28,6 +28,7 @@ int main(int argc, char **argv)
int neg = strtol(argv[1], 0, 0); int neg = strtol(argv[1], 0, 0);
int nout = strtol(argv[2], 0, 0); int nout = strtol(argv[2], 0, 0);
int nin = nout;
int nhid = strtol(argv[3], 0, 0); int nhid = strtol(argv[3], 0, 0);
int niter = strtol(argv[4], 0, 0); int niter = strtol(argv[4], 0, 0);
double lr = 0.01; double lr = 0.01;
...@@ -35,8 +36,8 @@ int main(int argc, char **argv) ...@@ -35,8 +36,8 @@ int main(int argc, char **argv)
gsl_rng_set(rng, 234); gsl_rng_set(rng, 234);
gsl_matrix * x = gsl_matrix_alloc(neg, nout); gsl_matrix * x = gsl_matrix_alloc(neg, nin);
gsl_matrix * w = gsl_matrix_alloc(nout, nhid); gsl_matrix * w = gsl_matrix_alloc(nin, nhid);
gsl_vector * a = gsl_vector_alloc(nhid); gsl_vector * a = gsl_vector_alloc(nhid);
gsl_vector * b = gsl_vector_alloc(nout); gsl_vector * b = gsl_vector_alloc(nout);
gsl_matrix * xw = gsl_matrix_alloc(neg, nhid); gsl_matrix * xw = gsl_matrix_alloc(neg, nhid);
...@@ -59,11 +60,17 @@ int main(int argc, char **argv) ...@@ -59,11 +60,17 @@ int main(int argc, char **argv)
struct timeval tv0, tv1; struct timeval tv0, tv1;
struct timeval tdot0, tdot1;
double time_of_dot = 0.0;
gettimeofday(&tv0, 0); gettimeofday(&tv0, 0);
double err = 0.0; double err = 0.0;
for (int iter = 0; iter < niter; ++iter) for (int iter = 0; iter < niter; ++iter)
{ {
gettimeofday(&tdot0, 0);
gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, x, w, 0.0, xw); gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, x, w, 0.0, xw);
gettimeofday(&tdot1, 0);
time_of_dot += pytime(&tdot1) - pytime(&tdot0);
for (int i = 0; i < neg; ++i) for (int i = 0; i < neg; ++i)
for (int j = 0; j < nhid; ++j) for (int j = 0; j < nhid; ++j)
...@@ -72,7 +79,10 @@ int main(int argc, char **argv) ...@@ -72,7 +79,10 @@ int main(int argc, char **argv)
hid->data[i*nhid+j] = tanh(act); hid->data[i*nhid+j] = tanh(act);
} }
gettimeofday(&tdot0, 0);
gsl_blas_dgemm(CblasNoTrans, CblasTrans, 1.0, hid, w, 0.0, hidwt); gsl_blas_dgemm(CblasNoTrans, CblasTrans, 1.0, hid, w, 0.0, hidwt);
gettimeofday(&tdot1, 0);
time_of_dot += pytime(&tdot1) - pytime(&tdot0);
for (int i = 0; i < nout; ++i) g_b->data[i] = 0.0; for (int i = 0; i < nout; ++i) g_b->data[i] = 0.0;
err = 0.0; err = 0.0;
...@@ -90,8 +100,11 @@ int main(int argc, char **argv) ...@@ -90,8 +100,11 @@ int main(int argc, char **argv)
if (1) if (1)
{ {
gettimeofday(&tdot0, 0);
gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, g_hidwt, w, 0.0, g_hid); gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, g_hidwt, w, 0.0, g_hid);
gsl_blas_dgemm(CblasTrans, CblasNoTrans, 1.0, g_hidwt, hid, 0.0, g_w); gsl_blas_dgemm(CblasTrans, CblasNoTrans, 1.0, g_hidwt, hid, 0.0, g_w);
gettimeofday(&tdot1, 0);
time_of_dot += pytime(&tdot1) - pytime(&tdot0);
for (int i = 0; i < neg; ++i) for (int i = 0; i < neg; ++i)
...@@ -101,14 +114,19 @@ int main(int argc, char **argv) ...@@ -101,14 +114,19 @@ int main(int argc, char **argv)
a->data[j] -= lr * g_hid->data[i*nhid+j]; a->data[j] -= lr * g_hid->data[i*nhid+j];
} }
gettimeofday(&tdot0, 0);
gsl_blas_dgemm(CblasTrans, CblasNoTrans, -lr, x, g_hid, 1.0, w); gsl_blas_dgemm(CblasTrans, CblasNoTrans, -lr, x, g_hid, 1.0, w);
gettimeofday(&tdot1, 0);
time_of_dot += pytime(&tdot1) - pytime(&tdot0);
for (int i = 0; i < nout*nhid; ++i) w->data[i] -= lr * g_w->data[i]; for (int i = 0; i < nout*nhid; ++i) w->data[i] -= lr * g_w->data[i];
} }
} }
gettimeofday(&tv1, 0); gettimeofday(&tv1, 0);
fprintf(stdout, "took = %lfs to get err %lf\n", pytime(&tv1) - pytime(&tv0), 0.5 * err); double total_time = pytime(&tv1) - pytime(&tv0);
fprintf(stdout, "took = %lfs to get err %lf\n", total_time, 0.5 * err);
fprintf(stdout, "... of which %.2lfs was spent in dgemm (fraction: %.2lf)\n", time_of_dot, time_of_dot / total_time);
//skip freeing //skip freeing
return 0; return 0;
} }
......
...@@ -11,6 +11,9 @@ import theano.sandbox.wraplinker ...@@ -11,6 +11,9 @@ import theano.sandbox.wraplinker
from theano.compile import module, Mode from theano.compile import module, Mode
from theano.sandbox.wraplinker import ProfileMode from theano.sandbox.wraplinker import ProfileMode
# numpy: aa_numpy.py
# c : aa.cc
if 0: if 0:
class Opt(object): class Opt(object):
merge = theano.gof.MergeOptimizer() merge = theano.gof.MergeOptimizer()
...@@ -131,7 +134,7 @@ if 0: ...@@ -131,7 +134,7 @@ if 0:
self.merge(env) self.merge(env)
def linker(print_prog=True): def print_graph_linker(print_prog=True):
if 1: if 1:
imap = {None:'-'} imap = {None:'-'}
def blah(i, node, thunk): def blah(i, node, thunk):
...@@ -146,7 +149,6 @@ def linker(print_prog=True): ...@@ -146,7 +149,6 @@ def linker(print_prog=True):
print 'node ', i, node, print 'node ', i, node,
print ':'.join([imap[inp.owner] for inp in node.inputs]) print ':'.join([imap[inp.owner] for inp in node.inputs])
#print theano.sandbox.pprint.pp.process_graph(inputs, outputs) #print theano.sandbox.pprint.pp.process_graph(inputs, outputs)
return theano.sandbox.wraplinker.WrapLinkerMany( return theano.sandbox.wraplinker.WrapLinkerMany(
[theano.gof.OpWiseCLinker()], [theano.gof.OpWiseCLinker()],
[theano.sandbox.wraplinker.run_all [theano.sandbox.wraplinker.run_all
...@@ -184,8 +186,9 @@ class M(module.Module): ...@@ -184,8 +186,9 @@ class M(module.Module):
self.step = module.Method([x], err, updates=dict(updates)) self.step = module.Method([x], err, updates=dict(updates))
mod = M() mod = M()
#m = mod.make(mode='FAST_RUN') #mode = 'FAST_RUN'
mode = ProfileMode(optimizer='fast_run', linker=theano.gof.OpWiseCLinker()) mode = ProfileMode(optimizer='fast_run', linker=theano.gof.OpWiseCLinker())
print mod.pretty(mode=mode)
m = mod.make(mode=mode) m = mod.make(mode=mode)
neg, nout, nhid, niter = [int(a) for a in sys.argv[1:]] neg, nout, nhid, niter = [int(a) for a in sys.argv[1:]]
...@@ -200,5 +203,10 @@ t = time.time() ...@@ -200,5 +203,10 @@ t = time.time()
for i in xrange(niter): for i in xrange(niter):
err = m.step(x) err = m.step(x)
print 'time: ',time.time() - t, 'err: ', err print 'time: ',time.time() - t, 'err: ', err
mode.print_summary() try:
mode.print_summary()
pass
except:
raise
...@@ -4,6 +4,8 @@ import numpy as N ...@@ -4,6 +4,8 @@ import numpy as N
import sys import sys
import time import time
# c: aa.cc
neg, nout, nhid, niter = [int(a) for a in sys.argv[1:]] neg, nout, nhid, niter = [int(a) for a in sys.argv[1:]]
lr = 0.01 lr = 0.01
...@@ -14,12 +16,20 @@ a = rng.randn(nhid) * 0.0 ...@@ -14,12 +16,20 @@ a = rng.randn(nhid) * 0.0
b = rng.randn(nout) * 0.0 b = rng.randn(nout) * 0.0
x = (rng.rand(neg, nout)-0.5) * 1.5 x = (rng.rand(neg, nout)-0.5) * 1.5
dot_time = 0.0
t = time.time() t = time.time()
for i in xrange(niter): for i in xrange(niter):
hid = N.tanh(N.dot(x, w) + a) tt = time.time()
d = N.dot(x, w)
dot_time += time.time() - tt
hid = N.tanh(d + a)
out = N.tanh(N.dot(hid, w.T) + b) tt = time.time()
d = N.dot(hid, w.T)
dot_time += time.time() - tt
out = N.tanh(d + b)
g_out = out - x g_out = out - x
err = 0.5 * N.sum(g_out**2) err = 0.5 * N.sum(g_out**2)
...@@ -28,12 +38,23 @@ for i in xrange(niter): ...@@ -28,12 +38,23 @@ for i in xrange(niter):
b -= lr * N.sum(g_hidwt, axis=0) b -= lr * N.sum(g_hidwt, axis=0)
tt = time.time()
g_hid = N.dot(g_hidwt, w) g_hid = N.dot(g_hidwt, w)
dot_time += time.time() - tt
g_hidin = g_hid * (1.0 - hid**2) g_hidin = g_hid * (1.0 - hid**2)
w -= lr * (N.dot(g_hidwt.T, hid) + N.dot(x.T, g_hidin)) tt = time.time()
d = N.dot(g_hidwt.T, hid)
dd = N.dot(x.T, g_hidin)
dot_time += time.time() - tt
gw = (d + dd)
w -= lr * gw
a -= lr * N.sum(g_hidin, axis=0) a -= lr * N.sum(g_hidin, axis=0)
print 'time: ',time.time() - t, 'err: ', err total_time = time.time() - t
print 'time: ',total_time, 'err: ', err
print ' of which', dot_time, 'was spent on dot. Fraction:', dot_time / total_time
...@@ -23,11 +23,12 @@ from op import \ ...@@ -23,11 +23,12 @@ from op import \
from opt import \ from opt import \
Optimizer, optimizer, SeqOptimizer, \ Optimizer, optimizer, SeqOptimizer, \
MergeOptimizer, MergeOptMerge, \ MergeOptimizer, MergeOptMerge, \
LocalOptimizer, local_optimizer, LocalOptGroup, LocalOpKeyOptGroup, \ LocalOptimizer, local_optimizer, LocalOptGroup, \
OpSub, OpRemove, PatternSub, \ OpSub, OpRemove, PatternSub, \
NavigatorOptimizer, TopoOptimizer, OpKeyOptimizer, EquilibriumOptimizer, \ NavigatorOptimizer, TopoOptimizer, EquilibriumOptimizer, \
keep_going, warn, \ keep_going, warn, \
InplaceOptimizer, PureThenInplaceOptimizer InplaceOptimizer, PureThenInplaceOptimizer
#LocalOpKeyOptGroup, OpKeyOptimizer
from optdb import \ from optdb import \
DB, Query, \ DB, Query, \
......
...@@ -736,6 +736,7 @@ def _execute(cthunk, init_tasks, tasks, error_storage): ...@@ -736,6 +736,7 @@ def _execute(cthunk, init_tasks, tasks, error_storage):
else: else:
return tasks[failure_code - n] return tasks[failure_code - n]
def execute(): def execute():
execute.cthunk = cthunk
failure = cutils.run_cthunk(cthunk) failure = cutils.run_cthunk(cthunk)
if failure: if failure:
task, taskname, id = find_task(failure) task, taskname, id = find_task(failure)
......
...@@ -13,6 +13,7 @@ from collections import deque ...@@ -13,6 +13,7 @@ from collections import deque
import utils import utils
_creation_idx = [0]
class Apply(utils.object2): class Apply(utils.object2):
""" """
...@@ -121,6 +122,13 @@ class Apply(utils.object2): ...@@ -121,6 +122,13 @@ class Apply(utils.object2):
def __asapply__(self): def __asapply__(self):
return self return self
def __hash__(self):
if not hasattr(self, '_creation_idx'):
self._creation_idx = _creation_idx[0]
_creation_idx[0] += 1
return self._creation_idx
def clone(self): def clone(self):
"""Duplicate this Apply instance with inputs = self.inputs. """Duplicate this Apply instance with inputs = self.inputs.
...@@ -567,7 +575,10 @@ def general_toposort(r_out, deps, debug_print = False): ...@@ -567,7 +575,10 @@ def general_toposort(r_out, deps, debug_print = False):
deps(i) should behave like a pure function (no funny business with internal state) deps(i) should behave like a pure function (no funny business with internal state)
:note: :note:
deps(i) can/should be cached by the deps function to be fast deps(i) will be cached by this function (to be fast)
:note:
The order of the return value list is determined by the order of nodes returned by the deps() function.
""" """
deps_cache = {} deps_cache = {}
def _deps(io): def _deps(io):
...@@ -611,6 +622,7 @@ def general_toposort(r_out, deps, debug_print = False): ...@@ -611,6 +622,7 @@ def general_toposort(r_out, deps, debug_print = False):
def io_toposort(i, o, orderings = {}): def io_toposort(i, o, orderings = {}):
"""WRITEME """WRITEME
""" """
#the inputs are used only here in the function that decides what 'predecessors' to explore
iset = set(i) iset = set(i)
def deps(obj): def deps(obj):
rval = [] rval = []
......
...@@ -17,6 +17,9 @@ import sys ...@@ -17,6 +17,9 @@ import sys
_optimizer_idx = [0] _optimizer_idx = [0]
def _list_of_nodes(env):
return graph.io_toposort(env.inputs, env.outputs)
class Optimizer(object): class Optimizer(object):
"""WRITEME """WRITEME
An L{Optimizer} can be applied to an L{Env} to transform it. An L{Optimizer} can be applied to an L{Env} to transform it.
...@@ -73,7 +76,7 @@ class FromFunctionOptimizer(Optimizer): ...@@ -73,7 +76,7 @@ class FromFunctionOptimizer(Optimizer):
env.extend(toolbox.ReplaceValidate()) env.extend(toolbox.ReplaceValidate())
def optimizer(f): def optimizer(f):
"""WRITEME""" """decorator for FromFunctionOptimizer"""
return FromFunctionOptimizer(f) return FromFunctionOptimizer(f)
...@@ -137,6 +140,10 @@ class _metadict: ...@@ -137,6 +140,10 @@ class _metadict:
try: try:
self.d[item] = value self.d[item] = value
except: except:
for i, (key,val) in enumerate(self.l):
if key == item:
self.l[i] = (item, value)
return
self.l.append((item, value)) self.l.append((item, value))
def get(self, item, default): def get(self, item, default):
try: try:
...@@ -191,7 +198,7 @@ class MergeOptimizer(Optimizer): ...@@ -191,7 +198,7 @@ class MergeOptimizer(Optimizer):
cid[r] = i cid[r] = i
inv_cid[i] = r inv_cid[i] = r
for node in graph.io_toposort(env.inputs, env.outputs): for node in _list_of_nodes(env):
node_cid = (node.op, tuple([cid[input] for input in node.inputs])) node_cid = (node.op, tuple([cid[input] for input in node.inputs]))
dup = inv_cid.get(node_cid, None) dup = inv_cid.get(node_cid, None)
success = False success = False
...@@ -229,10 +236,33 @@ def MergeOptMerge(opt): ...@@ -229,10 +236,33 @@ def MergeOptMerge(opt):
### Local Optimizers ### ### Local Optimizers ###
######################## ########################
class LocalOptimizer(Optimizer, utils.object2): class LocalOptimizer(object):
"""WRITEME""" """A class for node-based optimizations.
Instances should implement the transform function,
and be passed to configure a env-based Optimizer instance.
"""
def __hash__(self):
if not hasattr(self, '_optimizer_idx'):
self._optimizer_idx = _optimizer_idx[0]
_optimizer_idx[0] += 1
return self._optimizer_idx
def transform(self, node): def transform(self, node):
"""Transform a subgraph whose output is `node`.
Subclasses should implement this function so that it returns one of two
kinds of things:
- False to indicate that no optimization can be applied to this `node`; or
- <list of results> to use in place of `node`'s outputs in the greater graph.
:type node: an Apply instance
"""
raise utils.AbstractFunctionError() raise utils.AbstractFunctionError()
...@@ -272,7 +302,7 @@ class LocalOptGroup(LocalOptimizer): ...@@ -272,7 +302,7 @@ class LocalOptGroup(LocalOptimizer):
return repl return repl
class LocalOpKeyOptGroup(LocalOptGroup): class _LocalOpKeyOptGroup(LocalOptGroup):
"""WRITEME""" """WRITEME"""
def __init__(self, optimizers): def __init__(self, optimizers):
...@@ -515,9 +545,29 @@ class PatternSub(LocalOptimizer): ...@@ -515,9 +545,29 @@ class PatternSub(LocalOptimizer):
class NavigatorOptimizer(Optimizer): class NavigatorOptimizer(Optimizer):
"""WRITEME""" """Abstract class
"""
def __init__(self, local_opt, ignore_newtrees = 'auto', failure_callback = None): def __init__(self, local_opt, ignore_newtrees = 'auto', failure_callback = None):
"""
:param local_opt: a LocalOptimizer to apply over a Env.
:param ignore_newtrees:
- True: new subgraphs returned by an optimization is not a candidate for optimization
- False: new subgraphs returned by an optimization is a candidate for optimization
- 'auto': let the local_opt set this parameter via its 'reentrant' attribute.
:param failure_callback:
a function that takes (exception, navigator, [(old, new),
(old,new),...]) and we call it if there's an exception.
If the trouble is from local_opt.transform(), the new variables will be 'None'.
If the trouble is from validation (the new types don't match for
example) then the new variables will be the ones created by
transform().
If this parameter is None, then exceptions are not caught here (raised normally).
"""
self.local_opt = local_opt self.local_opt = local_opt
if ignore_newtrees == 'auto': if ignore_newtrees == 'auto':
self.ignore_newtrees = not getattr(local_opt, 'reentrant', True) self.ignore_newtrees = not getattr(local_opt, 'reentrant', True)
...@@ -526,6 +576,15 @@ class NavigatorOptimizer(Optimizer): ...@@ -526,6 +576,15 @@ class NavigatorOptimizer(Optimizer):
self.failure_callback = failure_callback self.failure_callback = failure_callback
def attach_updater(self, env, importer, pruner, chin = None): def attach_updater(self, env, importer, pruner, chin = None):
"""Install some Env listeners to help the navigator deal with the ignore_trees-related functionality.
:param importer: function that will be called whenever when optimizations add stuff to the graph.
:param pruner: function to be called when optimizations remove stuff from graph.
:param chin: "on change input" called whenever an node's inputs change.
:returns: The Env plugin that handles the three tasks. Keep this around so that you can detach later!
"""
if self.ignore_newtrees: if self.ignore_newtrees:
importer = None importer = None
...@@ -548,6 +607,12 @@ class NavigatorOptimizer(Optimizer): ...@@ -548,6 +607,12 @@ class NavigatorOptimizer(Optimizer):
return u return u
def detach_updater(self, env, u): def detach_updater(self, env, u):
"""Undo the work of attach_updater.
:param u: a return-value of attach_updater
:returns: None.
"""
if u is not None: if u is not None:
env.remove_feature(u) env.remove_feature(u)
...@@ -610,7 +675,7 @@ class TopoOptimizer(NavigatorOptimizer): ...@@ -610,7 +675,7 @@ class TopoOptimizer(NavigatorOptimizer):
except: except:
self.detach_updater(env, u) self.detach_updater(env, u)
raise raise
self.detach_updater(env, u)
class OpKeyOptimizer(NavigatorOptimizer): class OpKeyOptimizer(NavigatorOptimizer):
...@@ -642,6 +707,7 @@ class OpKeyOptimizer(NavigatorOptimizer): ...@@ -642,6 +707,7 @@ class OpKeyOptimizer(NavigatorOptimizer):
except: except:
self.detach_updater(env, u) self.detach_updater(env, u)
raise raise
self.detach_updater(env, u)
def add_requirements(self, env): def add_requirements(self, env):
""" """
...@@ -654,38 +720,70 @@ class OpKeyOptimizer(NavigatorOptimizer): ...@@ -654,38 +720,70 @@ class OpKeyOptimizer(NavigatorOptimizer):
# class EquilibriumOptimizer(NavigatorOptimizer):
# """WRITEME"""
# def __init__(self, local_optimizers, failure_callback = None):
# NavigatorOptimizer.__init__(self, local_opt, ignore_newtrees, failure_callback)
# def apply(self, env):
# op = self.local_opt.op_key()
# if isinstance(op, (list, tuple)):
# q = reduce(list.__iadd__, map(env.get_nodes, op))
# else:
# q = list(env.get_nodes(op))
# def importer(node):
# if node.op == op: q.append(node)
# def pruner(node):
# if node is not current_node and node.op == op:
# try: q.remove(node)
# except ValueError: pass
# u = self.attach_updater(env, importer, pruner)
# try:
# while q:
# node = q.pop()
# current_node = node
# self.process_node(env, node)
# except:
# self.detach_updater(env, u)
# raise
from utils import D from utils import D
class EquilibriumOptimizer(NavigatorOptimizer): class EquilibriumOptimizer(NavigatorOptimizer):
def __init__(self,
local_optimizers,
failure_callback = None,
max_depth = None,
max_use_ratio = None):
"""
:param max_use_ratio: each optimizer can be applied at most (size of graph * this number)
"""
super(EquilibriumOptimizer, self).__init__(
None,
ignore_newtrees = True,
failure_callback = failure_callback)
self.local_optimizers = local_optimizers
self.max_depth = max_depth
self.max_use_ratio = max_use_ratio
def apply(self, env, start_from = None):
if start_from is None:
start_from = env.outputs
changed = True
max_use_abort = False
process_count = {}
while changed and not max_use_abort:
changed = False
q = deque(graph.io_toposort(env.inputs, start_from))
max_use = len(q) * self.max_use_ratio
def importer(node):
q.append(node)
def pruner(node):
if node is not current_node:
try: q.remove(node)
except ValueError: pass
u = self.attach_updater(env, importer, pruner)
try:
while q:
node = q.pop()
current_node = node
for lopt in self.local_optimizers:
process_count.setdefault(lopt, 0)
if process_count[lopt] > max_use:
max_use_abort = True
else:
lopt_change = self.process_node(env, node, lopt)
process_count[lopt] += 1 if lopt_change else 0
changed |= lopt_change
except:
self.detach_updater(env, u)
raise
self.detach_updater(env, u)
if max_use_abort:
print >> sys.stderr, "WARNING: EquilibriumOptimizer max'ed out"
class _EquilibriumOptimizer(NavigatorOptimizer):
def __init__(self, def __init__(self,
local_optimizers, local_optimizers,
...@@ -780,10 +878,11 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -780,10 +878,11 @@ class EquilibriumOptimizer(NavigatorOptimizer):
# importer(node) # importer(node)
for node in env.nodes: for node in env.toposort():
tasks[node].extend(lopt for track, i, lopt in self.fetch_tracks0(node.op)) tasks[node].extend(lopt for track, i, lopt in self.fetch_tracks0(node.op))
u = self.attach_updater(env, importer, pruner, chin) u = self.attach_updater(env, importer, pruner, chin)
print 'KEYS', map(hash, tasks.keys())
while tasks: while tasks:
for node in tasks.iterkeys(): for node in tasks.iterkeys():
todo = tasks.pop(node) todo = tasks.pop(node)
......
...@@ -18,7 +18,7 @@ class DB(object): ...@@ -18,7 +18,7 @@ class DB(object):
# N.B. obj is not an instance of class Optimizer. # N.B. obj is not an instance of class Optimizer.
# It is an instance of a DB.In the tests for example, # It is an instance of a DB.In the tests for example,
# this is not always the case. # this is not always the case.
if not isinstance(obj, (DB, opt.Optimizer)): if not isinstance(obj, (DB, opt.Optimizer, opt.LocalOptimizer)):
raise Exception('wtf', obj) raise Exception('wtf', obj)
obj.name = name obj.name = name
......
...@@ -375,7 +375,7 @@ class TestEquilibrium(object): ...@@ -375,7 +375,7 @@ class TestEquilibrium(object):
x, y, z = map(MyResult, 'xyz') x, y, z = map(MyResult, 'xyz')
e = op3(op4(x, y)) e = op3(op4(x, y))
g = Env([x, y, z], [e]) g = Env([x, y, z], [e])
print g print 'before', g
sys.stderr = sys.stdout # display pesky warnings along with stdout sys.stderr = sys.stdout # display pesky warnings along with stdout
opt = EquilibriumOptimizer( opt = EquilibriumOptimizer(
[PatternSub((op1, 'x', 'y'), (op2, 'x', 'y')), [PatternSub((op1, 'x', 'y'), (op2, 'x', 'y')),
...@@ -384,7 +384,7 @@ class TestEquilibrium(object): ...@@ -384,7 +384,7 @@ class TestEquilibrium(object):
], ],
max_use_ratio = 1. / len(g.nodes)) # each opt can only be applied once max_use_ratio = 1. / len(g.nodes)) # each opt can only be applied once
opt.optimize(g) opt.optimize(g)
print g print 'after', g
assert str(g) == '[Op4(x, y)]' assert str(g) == '[Op4(x, y)]'
......
...@@ -2,6 +2,7 @@ import gof #, gof.result ...@@ -2,6 +2,7 @@ import gof #, gof.result
import numpy #for numeric_grad import numpy #for numeric_grad
from gof.python25 import all from gof.python25 import all
import gof.utils
_msg_retType = 'op.grad(...) returned a non-list' _msg_retType = 'op.grad(...) returned a non-list'
_msg_badlen = 'op.grad(...) returned wrong number of gradients' _msg_badlen = 'op.grad(...) returned wrong number of gradients'
...@@ -55,7 +56,7 @@ def grad_sources_inputs(sources, graph_inputs): ...@@ -55,7 +56,7 @@ def grad_sources_inputs(sources, graph_inputs):
else: else:
gmap[r] = g_r gmap[r] = g_r
graph_outputs = gmap.keys() graph_outputs = gof.utils.uniq([r for r,g in sources])
if graph_inputs is None: if graph_inputs is None:
graph_inputs = gof.graph.inputs(graph_outputs) graph_inputs = gof.graph.inputs(graph_outputs)
......
...@@ -2,6 +2,7 @@ from __future__ import absolute_import ...@@ -2,6 +2,7 @@ from __future__ import absolute_import
import time import time
import numpy import numpy
from ..gof.cutils import run_cthunk
from ..gof.link import WrapLinker from ..gof.link import WrapLinker
from ..compile.mode import Mode from ..compile.mode import Mode
...@@ -107,19 +108,42 @@ class ProfileMode(Mode): ...@@ -107,19 +108,42 @@ class ProfileMode(Mode):
local_time = [0.0] local_time = [0.0]
apply_time = {} apply_time = {}
op_time = {} op_time = {}
op_cimpl = {}
def blah(i, node, *thunks): def blah(i, node, *thunks):
if 0:
t0 = time.time() t0 = time.time()
for th in thunks: for th in thunks:
th() th()
dt = time.time() - t0 dt = time.time() - t0
elif 0: #more precise timing
for th in thunks:
t0 = time.time()
th()
dt = time.time() - t0
elif 1:
for th in thunks:
if hasattr(th, 'cthunk'):
t0 = time.time()
run_cthunk(th.cthunk)
dt = time.time() - t0
else:
t0 = time.time()
th()
dt = time.time() - t0
elif 1:
pass
else:
raise Exception('one of the cases has to run the thunks!')
local_time[0] += dt local_time[0] += dt
apply_time[(i,node.op)] = apply_time.get((i,node.op), 0.0) + dt apply_time[(i,node.op)] = apply_time.get((i,node.op), 0.0) + dt
op_time[node.op] = op_time.get(node.op, 0.0) + dt op_time[node.op] = op_time.get(node.op, 0.0) + dt
op_cimpl[node.op] = hasattr(thunks[0], 'cthunk')
self.local_time = local_time self.local_time = local_time
self.apply_time = apply_time self.apply_time = apply_time
self.op_time = op_time self.op_time = op_time
self.op_cimpl = op_cimpl
wrap_linker = WrapLinkerMany([linker], [blah]) wrap_linker = WrapLinkerMany([linker], [blah])
if optimizer: if optimizer:
...@@ -142,13 +166,18 @@ class ProfileMode(Mode): ...@@ -142,13 +166,18 @@ class ProfileMode(Mode):
atimes.sort() atimes.sort()
atimes.reverse() atimes.reverse()
for t,a in atimes[:15]: for t,a in atimes[:15]:
print ' ', t, a print '\t%.3f\t%i\t%s' % (t, a[0], a[1])
print ' ... (ignoring %i other Apply instances)'%max(0, len(atimes)-15) print ' ... (remaining %i Apply instances account for %.2f of the runtime)'\
%(max(0, len(atimes)-15), sum(t for t, a in atimes[15:]))
print 'Op-wise summary: <fraction of local_time spent on this kind of Op> <Op name>' print 'Op-wise summary: <fraction of local_time spent on this kind of Op> <Op name>'
otimes = [(t/local_time, a) for a, t in op_time.items()] otimes = [(t/local_time, a, self.op_cimpl[a]) for a, t in op_time.items()]
otimes.sort() otimes.sort()
otimes.reverse() otimes.reverse()
for t,a in otimes[:15]: for t,a,ci in otimes[:15]:
print ' ', t, a print '\t%.3f\t%s %s' % (t, '*' if ci else ' ', a)
print ' ... (ignoring %i other kinds Ops)'%max(0, len(otimes)-15) print ' ... (remaining %i Ops account for %.2f of the runtime)'\
%(max(0, len(otimes)-15), sum(t for t, a, ci in otimes[15:]))
print '(*) Op is running a c implementation'
...@@ -1089,7 +1089,8 @@ pprint.assign(pow, printing.OperatorPrinter('**', 1, 'right')) ...@@ -1089,7 +1089,8 @@ pprint.assign(pow, printing.OperatorPrinter('**', 1, 'right'))
# View Operations # View Operations
########################## ##########################
class TransposeInplace(Op): if 0:
class _TransposeInplace(Op):
view_map = {0: [0]} view_map = {0: [0]}
def make_node(self, input): def make_node(self, input):
...@@ -1114,12 +1115,16 @@ class TransposeInplace(Op): ...@@ -1114,12 +1115,16 @@ class TransposeInplace(Op):
def __str__(self): def __str__(self):
return "TransposeView" return "TransposeView"
_transpose_inplace = TransposeInplace() _transpose_inplace = _TransposeInplace()
def transpose(x, **kwargs): def _old_transpose(x, **kwargs):
"""WRITEME""" """WRITEME"""
return _transpose_inplace(tensor_copy(x), **kwargs) return _transpose_inplace(tensor_copy(x), **kwargs)
def transpose(x, **kwargs):
dims = range(x.ndim-1, -1, -1)
return DimShuffle(x.broadcastable, dims, inplace=True)(tensor_copy(x))
......
...@@ -181,7 +181,7 @@ class DimShuffle(Op): ...@@ -181,7 +181,7 @@ class DimShuffle(Op):
for i, v in enumerate(self.new_order): for i, v in enumerate(self.new_order):
if v != 'x': if v != 'x':
grad_order[v] = i grad_order[v] = i
return DimShuffle(gz.type.broadcastable, grad_order)(gz), return [DimShuffle(gz.type.broadcastable, grad_order, inplace=True)(Elemwise(scalar.identity)(gz))]
......
from basic import _scal_elemwise, _transpose_inplace from basic import _scal_elemwise #, _transpose_inplace
from .. import scalar as scal from .. import scalar as scal
import elemwise import elemwise
from .. import printing from .. import printing
...@@ -183,9 +183,11 @@ pprint.assign(div_inplace, printing.OperatorPrinter('/=', -1, 'left')) ...@@ -183,9 +183,11 @@ pprint.assign(div_inplace, printing.OperatorPrinter('/=', -1, 'left'))
pprint.assign(pow_inplace, printing.OperatorPrinter('**=', 1, 'right')) pprint.assign(pow_inplace, printing.OperatorPrinter('**=', 1, 'right'))
transpose_inplace = _transpose_inplace def transpose_inplace(x, **kwargs):
"""WRITEME""" """Perform a transpose on a tensor without copying the underlying storage"""
dims = range(x.ndim-1, -1, -1)
return elemwise.DimShuffle(x.broadcastable, dims, inplace=True)(x)
pprint.assign(transpose_inplace, printing.MemberPrinter('T')) #pprint.assign(transpose_inplace, printing.MemberPrinter('T'))
...@@ -53,6 +53,7 @@ dot_to_gemm = gof.PatternSub((T.dot, 'a', 'b'), ...@@ -53,6 +53,7 @@ dot_to_gemm = gof.PatternSub((T.dot, 'a', 'b'),
allow_multiple_clients = False) allow_multiple_clients = False)
def _insert_inplace_optimizer(env): def _insert_inplace_optimizer(env):
""" """
Usage: inplace_optimizer.optimize(env) Usage: inplace_optimizer.optimize(env)
...@@ -216,6 +217,13 @@ register_canonicalize(local_shape_lift_dot) ...@@ -216,6 +217,13 @@ register_canonicalize(local_shape_lift_dot)
################ ################
def encompasses_broadcastable(b1, b2): def encompasses_broadcastable(b1, b2):
"""
Returns True if the broadcastable patterns b1 and b2 are such that b2 is
broadcasted to b1's shape and not the opposite.
:param b1: the broadcastable attribute of a tensor type
:param b2: the broadcastable attribute of a tensor type
"""
if len(b1) < len(b2): if len(b1) < len(b2):
return False return False
b1 = b1[-len(b2):] b1 = b1[-len(b2):]
...@@ -330,6 +338,7 @@ def local_fill_cut(node): ...@@ -330,6 +338,7 @@ def local_fill_cut(node):
register_canonicalize(local_fill_cut) register_canonicalize(local_fill_cut)
register_canonicalize(gof.OpRemove(T.tensor_copy), name='remove_tensor_copy' )
@gof.local_optimizer([None, T.fill]) @gof.local_optimizer([None, T.fill])
def local_fill_sink(node): def local_fill_sink(node):
...@@ -550,6 +559,7 @@ def local_neg_to_mul(node): ...@@ -550,6 +559,7 @@ def local_neg_to_mul(node):
return [-1 * node.inputs[0]] return [-1 * node.inputs[0]]
else: else:
return False return False
register_canonicalize(local_neg_to_mul)
@gof.local_optimizer([T.mul]) @gof.local_optimizer([T.mul])
def local_mul_to_neg(node): def local_mul_to_neg(node):
...@@ -557,6 +567,7 @@ def local_mul_to_neg(node): ...@@ -557,6 +567,7 @@ def local_mul_to_neg(node):
return [-local_mul_canonizer.merge_num_denum(node.inputs[1:], [])] return [-local_mul_canonizer.merge_num_denum(node.inputs[1:], [])]
else: else:
return False return False
register_specialize(local_mul_to_neg)
@gof.local_optimizer([T.div]) @gof.local_optimizer([T.div])
def local_div_to_inv(node): def local_div_to_inv(node):
...@@ -564,10 +575,57 @@ def local_div_to_inv(node): ...@@ -564,10 +575,57 @@ def local_div_to_inv(node):
return [T.inv(local_mul_canonizer.merge_num_denum(node.inputs[1:], []))] return [T.inv(local_mul_canonizer.merge_num_denum(node.inputs[1:], []))]
else: else:
return False return False
register_canonicalize(local_neg_to_mul)
register_specialize(local_mul_to_neg)
register_specialize(local_div_to_inv) register_specialize(local_div_to_inv)
@gof.local_optimizer([T.inv])
def local_inv_canon(node):
if node.op == T.inv:
return [T.pow(node.inputs[0], -1.0)]
else:
return False
register_canonicalize(local_inv_canon)
@gof.local_optimizer([T.pow])
def local_pow_canonicalize(node):
if node.op == T.pow:
if N.all(local_mul_canonizer.get_constant(node.inputs[1]) == 1.0):
return [T.fill(node.inputs[1], node.inputs[0])]
if N.all(local_mul_canonizer.get_constant(node.inputs[1]) == 0.0):
#extra fills here are to make sure the size of the output stays constant.
return [T.fill(node.inputs[0], T.fill(node.inputs[1], 1.0))]
else:
return False
register_canonicalize(local_pow_canonicalize)
@gof.local_optimizer([T.pow])
def local_pow_specialize(node):
#here, we are past the point of canonicalization, so we don't want to put in un-necessary fills.
if node.op == T.pow:
#the idea here is that we have pow(x, y)
xsym = node.inputs[0]
ysym = node.inputs[1]
y = local_mul_canonizer.get_constant(ysym)
if (y is not None) \
and encompasses_broadcastable(xsym.type.broadcastable, ysym.type.broadcastable):
if N.all(y == 2.0):
return [T.sqr(xsym)]
if N.all(y == 1.0):
return [xsym]
if N.all(y == 0.0):
return [T.fill(xsym, 1.0)]
if N.all(y == 0.5):
return [T.sqrt(xsym)]
if N.all(y == -0.5):
return [T.inv(T.sqrt(xsym))]
if N.all(y == -1.0):
return [T.inv(xsym)]
if N.all(y == -2.0):
return [T.inv(T.sqr(xsym))]
else:
return False
register_specialize(local_pow_specialize)
register_canonicalize(local_mul_canonizer, name = 'local_mul_canonizer') register_canonicalize(local_mul_canonizer, name = 'local_mul_canonizer')
......
...@@ -662,56 +662,6 @@ class T_max_and_argmax(unittest.TestCase): ...@@ -662,56 +662,6 @@ class T_max_and_argmax(unittest.TestCase):
self.failUnless(i.shape == (2,3)) self.failUnless(i.shape == (2,3))
class T_transpose(unittest.TestCase):
def test0(self):
n = as_tensor(numpy.ones(()))
t = transpose(n)
self.failUnless(t.owner.op == inplace.transpose_inplace)
f = function([n], t)
tval = f(n.data)
self.failUnless(tval.shape == n.data.shape)
#test aliasing
tval += 55.0
self.failUnless(n.data == 1.0)
def test1(self):
n = as_tensor(numpy.ones(5))
t = transpose(n)
self.failUnless(t.owner.op == inplace.transpose_inplace)
f = function([n], t)
tval = f(n.data)
self.failUnless(tval.shape == n.data.shape)
#test aliasing
tval += 55.0
self.failUnless(n.data[0] == 1.0)
def test2(self):
n = as_tensor(numpy.ones((5,3)))
t = transpose(n)
self.failUnless(t.owner.op == inplace.transpose_inplace)
f = function([n], t)
tval = f(n.data)
self.failUnless(tval.shape == (3,5))
#test aliasing
tval += 55.0
self.failUnless(n.data[0,0] == 1.0)
def test3(self):
"""Test transpose of tensor, inplace version"""
n = as_tensor(numpy.ones((5,3,2)))
t = inplace.transpose_inplace(n)
self.failUnless(t.owner.op == inplace.transpose_inplace)
f = function([n], t)
tval = f(n.data)
self.failUnless(tval.shape == (2,3,5))
#test aliasing
tval += 55.0
self.failUnless(n.data[0,0,0] == 56.0)
def test_grad(self):
verify_grad(self, inplace.transpose_inplace, [numpy.random.rand(2, 3)])
verify_grad(self, inplace.transpose_inplace, [numpy.ones(3)])
class T_subtensor(unittest.TestCase): class T_subtensor(unittest.TestCase):
def setUp(self): def setUp(self):
Subtensor.debug = False Subtensor.debug = False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论