提交 7bac82c1 authored 作者: Frederic Bastien's avatar Frederic Bastien
aa.x : aa.cc aa.x : aa.cc
g++ -O3 -ffast-math aa.cc -o aa.x -L${PUB_PREFIX}/lib -lgsl -lcblas -lgoto -lgfortran -lm g++ -O3 -ffast-math aa.cc -o aa.x -L${PUB_PREFIX}/lib -lgsl ${THEANO_BLAS_LDFLAGS}
clean : clean :
rm aa.x rm aa.x
...@@ -28,6 +28,7 @@ int main(int argc, char **argv) ...@@ -28,6 +28,7 @@ int main(int argc, char **argv)
int neg = strtol(argv[1], 0, 0); int neg = strtol(argv[1], 0, 0);
int nout = strtol(argv[2], 0, 0); int nout = strtol(argv[2], 0, 0);
int nin = nout;
int nhid = strtol(argv[3], 0, 0); int nhid = strtol(argv[3], 0, 0);
int niter = strtol(argv[4], 0, 0); int niter = strtol(argv[4], 0, 0);
double lr = 0.01; double lr = 0.01;
...@@ -35,8 +36,8 @@ int main(int argc, char **argv) ...@@ -35,8 +36,8 @@ int main(int argc, char **argv)
gsl_rng_set(rng, 234); gsl_rng_set(rng, 234);
gsl_matrix * x = gsl_matrix_alloc(neg, nout); gsl_matrix * x = gsl_matrix_alloc(neg, nin);
gsl_matrix * w = gsl_matrix_alloc(nout, nhid); gsl_matrix * w = gsl_matrix_alloc(nin, nhid);
gsl_vector * a = gsl_vector_alloc(nhid); gsl_vector * a = gsl_vector_alloc(nhid);
gsl_vector * b = gsl_vector_alloc(nout); gsl_vector * b = gsl_vector_alloc(nout);
gsl_matrix * xw = gsl_matrix_alloc(neg, nhid); gsl_matrix * xw = gsl_matrix_alloc(neg, nhid);
...@@ -59,11 +60,17 @@ int main(int argc, char **argv) ...@@ -59,11 +60,17 @@ int main(int argc, char **argv)
struct timeval tv0, tv1; struct timeval tv0, tv1;
struct timeval tdot0, tdot1;
double time_of_dot = 0.0;
gettimeofday(&tv0, 0); gettimeofday(&tv0, 0);
double err = 0.0; double err = 0.0;
for (int iter = 0; iter < niter; ++iter) for (int iter = 0; iter < niter; ++iter)
{ {
gettimeofday(&tdot0, 0);
gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, x, w, 0.0, xw); gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, x, w, 0.0, xw);
gettimeofday(&tdot1, 0);
time_of_dot += pytime(&tdot1) - pytime(&tdot0);
for (int i = 0; i < neg; ++i) for (int i = 0; i < neg; ++i)
for (int j = 0; j < nhid; ++j) for (int j = 0; j < nhid; ++j)
...@@ -72,7 +79,10 @@ int main(int argc, char **argv) ...@@ -72,7 +79,10 @@ int main(int argc, char **argv)
hid->data[i*nhid+j] = tanh(act); hid->data[i*nhid+j] = tanh(act);
} }
gettimeofday(&tdot0, 0);
gsl_blas_dgemm(CblasNoTrans, CblasTrans, 1.0, hid, w, 0.0, hidwt); gsl_blas_dgemm(CblasNoTrans, CblasTrans, 1.0, hid, w, 0.0, hidwt);
gettimeofday(&tdot1, 0);
time_of_dot += pytime(&tdot1) - pytime(&tdot0);
for (int i = 0; i < nout; ++i) g_b->data[i] = 0.0; for (int i = 0; i < nout; ++i) g_b->data[i] = 0.0;
err = 0.0; err = 0.0;
...@@ -90,8 +100,11 @@ int main(int argc, char **argv) ...@@ -90,8 +100,11 @@ int main(int argc, char **argv)
if (1) if (1)
{ {
gettimeofday(&tdot0, 0);
gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, g_hidwt, w, 0.0, g_hid); gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, g_hidwt, w, 0.0, g_hid);
gsl_blas_dgemm(CblasTrans, CblasNoTrans, 1.0, g_hidwt, hid, 0.0, g_w); gsl_blas_dgemm(CblasTrans, CblasNoTrans, 1.0, g_hidwt, hid, 0.0, g_w);
gettimeofday(&tdot1, 0);
time_of_dot += pytime(&tdot1) - pytime(&tdot0);
for (int i = 0; i < neg; ++i) for (int i = 0; i < neg; ++i)
...@@ -101,14 +114,19 @@ int main(int argc, char **argv) ...@@ -101,14 +114,19 @@ int main(int argc, char **argv)
a->data[j] -= lr * g_hid->data[i*nhid+j]; a->data[j] -= lr * g_hid->data[i*nhid+j];
} }
gettimeofday(&tdot0, 0);
gsl_blas_dgemm(CblasTrans, CblasNoTrans, -lr, x, g_hid, 1.0, w); gsl_blas_dgemm(CblasTrans, CblasNoTrans, -lr, x, g_hid, 1.0, w);
gettimeofday(&tdot1, 0);
time_of_dot += pytime(&tdot1) - pytime(&tdot0);
for (int i = 0; i < nout*nhid; ++i) w->data[i] -= lr * g_w->data[i]; for (int i = 0; i < nout*nhid; ++i) w->data[i] -= lr * g_w->data[i];
} }
} }
gettimeofday(&tv1, 0); gettimeofday(&tv1, 0);
fprintf(stdout, "took = %lfs to get err %lf\n", pytime(&tv1) - pytime(&tv0), 0.5 * err); double total_time = pytime(&tv1) - pytime(&tv0);
fprintf(stdout, "took = %lfs to get err %lf\n", total_time, 0.5 * err);
fprintf(stdout, "... of which %.2lfs was spent in dgemm (fraction: %.2lf)\n", time_of_dot, time_of_dot / total_time);
//skip freeing //skip freeing
return 0; return 0;
} }
......
...@@ -10,6 +10,13 @@ import theano.sandbox ...@@ -10,6 +10,13 @@ import theano.sandbox
import theano.sandbox.wraplinker import theano.sandbox.wraplinker
from theano.compile import module, Mode from theano.compile import module, Mode
from theano.sandbox.wraplinker import ProfileMode from theano.sandbox.wraplinker import ProfileMode
from theano import gof, Op, Apply
from theano.tensor import blas, opt
# numpy: aa_numpy.py
# c : aa.cc
if 0: if 0:
class Opt(object): class Opt(object):
...@@ -131,7 +138,7 @@ if 0: ...@@ -131,7 +138,7 @@ if 0:
self.merge(env) self.merge(env)
def linker(print_prog=True): def print_graph_linker(print_prog=True):
if 1: if 1:
imap = {None:'-'} imap = {None:'-'}
def blah(i, node, thunk): def blah(i, node, thunk):
...@@ -146,7 +153,6 @@ def linker(print_prog=True): ...@@ -146,7 +153,6 @@ def linker(print_prog=True):
print 'node ', i, node, print 'node ', i, node,
print ':'.join([imap[inp.owner] for inp in node.inputs]) print ':'.join([imap[inp.owner] for inp in node.inputs])
#print theano.sandbox.pprint.pp.process_graph(inputs, outputs) #print theano.sandbox.pprint.pp.process_graph(inputs, outputs)
return theano.sandbox.wraplinker.WrapLinkerMany( return theano.sandbox.wraplinker.WrapLinkerMany(
[theano.gof.OpWiseCLinker()], [theano.gof.OpWiseCLinker()],
[theano.sandbox.wraplinker.run_all [theano.sandbox.wraplinker.run_all
...@@ -184,8 +190,11 @@ class M(module.Module): ...@@ -184,8 +190,11 @@ class M(module.Module):
self.step = module.Method([x], err, updates=dict(updates)) self.step = module.Method([x], err, updates=dict(updates))
mod = M() mod = M()
#m = mod.make(mode='FAST_RUN') mode = 'FAST_RUN'
mode = ProfileMode(optimizer='fast_run', linker=theano.gof.OpWiseCLinker()) #mode = ProfileMode(optimizer='fast_run', linker=theano.gof.OpWiseCLinker())
mode = Mode(optimizer='fast_run', linker=theano.gof.OpWiseCLinker(nice_errors=True))
mode = Mode(optimizer='fast_run', linker='c')
print mod.pretty(mode=mode)
m = mod.make(mode=mode) m = mod.make(mode=mode)
neg, nout, nhid, niter = [int(a) for a in sys.argv[1:]] neg, nout, nhid, niter = [int(a) for a in sys.argv[1:]]
...@@ -200,5 +209,10 @@ t = time.time() ...@@ -200,5 +209,10 @@ t = time.time()
for i in xrange(niter): for i in xrange(niter):
err = m.step(x) err = m.step(x)
print 'time: ',time.time() - t, 'err: ', err print 'time: ',time.time() - t, 'err: ', err
mode.print_summary() try:
mode.print_summary()
pass
except:
pass
...@@ -4,6 +4,8 @@ import numpy as N ...@@ -4,6 +4,8 @@ import numpy as N
import sys import sys
import time import time
# c: aa.cc
neg, nout, nhid, niter = [int(a) for a in sys.argv[1:]] neg, nout, nhid, niter = [int(a) for a in sys.argv[1:]]
lr = 0.01 lr = 0.01
...@@ -14,12 +16,20 @@ a = rng.randn(nhid) * 0.0 ...@@ -14,12 +16,20 @@ a = rng.randn(nhid) * 0.0
b = rng.randn(nout) * 0.0 b = rng.randn(nout) * 0.0
x = (rng.rand(neg, nout)-0.5) * 1.5 x = (rng.rand(neg, nout)-0.5) * 1.5
dot_time = 0.0
t = time.time() t = time.time()
for i in xrange(niter): for i in xrange(niter):
hid = N.tanh(N.dot(x, w) + a) tt = time.time()
d = N.dot(x, w)
dot_time += time.time() - tt
hid = N.tanh(d + a)
out = N.tanh(N.dot(hid, w.T) + b) tt = time.time()
d = N.dot(hid, w.T)
dot_time += time.time() - tt
out = N.tanh(d + b)
g_out = out - x g_out = out - x
err = 0.5 * N.sum(g_out**2) err = 0.5 * N.sum(g_out**2)
...@@ -28,12 +38,23 @@ for i in xrange(niter): ...@@ -28,12 +38,23 @@ for i in xrange(niter):
b -= lr * N.sum(g_hidwt, axis=0) b -= lr * N.sum(g_hidwt, axis=0)
tt = time.time()
g_hid = N.dot(g_hidwt, w) g_hid = N.dot(g_hidwt, w)
dot_time += time.time() - tt
g_hidin = g_hid * (1.0 - hid**2) g_hidin = g_hid * (1.0 - hid**2)
w -= lr * (N.dot(g_hidwt.T, hid) + N.dot(x.T, g_hidin)) tt = time.time()
d = N.dot(g_hidwt.T, hid)
dd = N.dot(x.T, g_hidin)
dot_time += time.time() - tt
gw = (d + dd)
w -= lr * gw
a -= lr * N.sum(g_hidin, axis=0) a -= lr * N.sum(g_hidin, axis=0)
print 'time: ',time.time() - t, 'err: ', err total_time = time.time() - t
print 'time: ',total_time, 'err: ', err
print ' of which', dot_time, 'was spent on dot. Fraction:', dot_time / total_time
...@@ -23,11 +23,12 @@ from op import \ ...@@ -23,11 +23,12 @@ from op import \
from opt import \ from opt import \
Optimizer, optimizer, SeqOptimizer, \ Optimizer, optimizer, SeqOptimizer, \
MergeOptimizer, MergeOptMerge, \ MergeOptimizer, MergeOptMerge, \
LocalOptimizer, local_optimizer, LocalOptGroup, LocalOpKeyOptGroup, \ LocalOptimizer, local_optimizer, LocalOptGroup, \
OpSub, OpRemove, PatternSub, \ OpSub, OpRemove, PatternSub, \
NavigatorOptimizer, TopoOptimizer, OpKeyOptimizer, EquilibriumOptimizer, \ NavigatorOptimizer, TopoOptimizer, EquilibriumOptimizer, \
keep_going, warn, \ keep_going, warn, \
InplaceOptimizer, PureThenInplaceOptimizer InplaceOptimizer, PureThenInplaceOptimizer
#LocalOpKeyOptGroup, OpKeyOptimizer
from optdb import \ from optdb import \
DB, Query, \ DB, Query, \
......
...@@ -686,14 +686,15 @@ class CLinker(link.Linker): ...@@ -686,14 +686,15 @@ class CLinker(link.Linker):
instantiate.customize.add_support_code(support_code) instantiate.customize.add_support_code(support_code)
instantiate.customize.add_support_code(self.struct_code) instantiate.customize.add_support_code(self.struct_code)
instantiate.customize.add_support_code(static) instantiate.customize.add_support_code(static)
for extra_arg in ("-w", #-w means supress all warnings for extra_arg in (
): "-O2",
#"-O3", "-ffast-math",
#"-ffast-math",
#"-fprefetch-loop-arrays", #"-fprefetch-loop-arrays",
#"-ftree-vect-loop-version", #"-ftree-vect-loop-version",
#"-ftree-loop-optimize", #"-ftree-loop-optimize",
#"-ftree-vectorize"): #"-ftree-vectorize"):
"-w" #-w means supress all warnings
):
instantiate.customize.add_extra_compile_arg(extra_arg) instantiate.customize.add_extra_compile_arg(extra_arg)
for arg in self.compile_args(): for arg in self.compile_args():
instantiate.customize.add_extra_compile_arg(arg) instantiate.customize.add_extra_compile_arg(arg)
...@@ -747,6 +748,7 @@ def _execute(cthunk, init_tasks, tasks, error_storage): ...@@ -747,6 +748,7 @@ def _execute(cthunk, init_tasks, tasks, error_storage):
exc_value = exc_type(_exc_value, task) exc_value = exc_type(_exc_value, task)
exc_value.__thunk_trace__ = trace # this can be used to retrieve the location the Op was declared exc_value.__thunk_trace__ = trace # this can be used to retrieve the location the Op was declared
raise exc_type, exc_value, exc_trace raise exc_type, exc_value, exc_trace
execute.cthunk = cthunk
return execute return execute
...@@ -769,9 +771,12 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -769,9 +771,12 @@ class OpWiseCLinker(link.LocalLinker):
__cache__ = {} __cache__ = {}
def __init__(self, fallback_on_perform = True): def __init__(self,
fallback_on_perform = True,
nice_errors = True):
self.env = None self.env = None
self.fallback_on_perform = fallback_on_perform self.fallback_on_perform = fallback_on_perform
self.nice_errors = nice_errors
def accept(self, env, no_recycling = []): def accept(self, env, no_recycling = []):
if self.env is not None and self.env is not env: if self.env is not None and self.env is not env:
...@@ -841,7 +846,9 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -841,7 +846,9 @@ class OpWiseCLinker(link.LocalLinker):
else: else:
no_recycling = [storage_map[r] for r in no_recycling if r not in env.inputs] no_recycling = [storage_map[r] for r in no_recycling if r not in env.inputs]
f = link.streamline(env, thunks, order, no_recycling = no_recycling, profiler = profiler) f = link.streamline(env, thunks, order,
no_recycling = no_recycling,
nice_errors = self.nice_errors)
return f, [link.Container(input, storage) for input, storage in zip(env.inputs, input_storage)], \ return f, [link.Container(input, storage) for input, storage in zip(env.inputs, input_storage)], \
[link.Container(output, storage, True) for output, storage in zip(env.outputs, output_storage)], \ [link.Container(output, storage, True) for output, storage in zip(env.outputs, output_storage)], \
...@@ -849,7 +856,6 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -849,7 +856,6 @@ class OpWiseCLinker(link.LocalLinker):
def _default_checker(x, y): def _default_checker(x, y):
"""WRITEME """WRITEME
Default checker for DualLinker. This checks that the Default checker for DualLinker. This checks that the
......
...@@ -13,6 +13,7 @@ from collections import deque ...@@ -13,6 +13,7 @@ from collections import deque
import utils import utils
_creation_idx = [0]
class Apply(utils.object2): class Apply(utils.object2):
""" """
...@@ -121,6 +122,13 @@ class Apply(utils.object2): ...@@ -121,6 +122,13 @@ class Apply(utils.object2):
def __asapply__(self): def __asapply__(self):
return self return self
def __hash__(self):
if not hasattr(self, '_creation_idx'):
self._creation_idx = _creation_idx[0]
_creation_idx[0] += 1
return self._creation_idx
def clone(self): def clone(self):
"""Duplicate this Apply instance with inputs = self.inputs. """Duplicate this Apply instance with inputs = self.inputs.
...@@ -567,7 +575,10 @@ def general_toposort(r_out, deps, debug_print = False): ...@@ -567,7 +575,10 @@ def general_toposort(r_out, deps, debug_print = False):
deps(i) should behave like a pure function (no funny business with internal state) deps(i) should behave like a pure function (no funny business with internal state)
:note: :note:
deps(i) can/should be cached by the deps function to be fast deps(i) will be cached by this function (to be fast)
:note:
The order of the return value list is determined by the order of nodes returned by the deps() function.
""" """
deps_cache = {} deps_cache = {}
def _deps(io): def _deps(io):
...@@ -611,8 +622,9 @@ def general_toposort(r_out, deps, debug_print = False): ...@@ -611,8 +622,9 @@ def general_toposort(r_out, deps, debug_print = False):
def io_toposort(i, o, orderings = {}): def io_toposort(i, o, orderings = {}):
"""WRITEME """WRITEME
""" """
#the inputs are used only here in the function that decides what 'predecessors' to explore
iset = set(i) iset = set(i)
def deps(obj): def deps(obj):
rval = [] rval = []
if obj not in iset: if obj not in iset:
if isinstance(obj, Result): if isinstance(obj, Result):
......
...@@ -5,6 +5,7 @@ from type import Type ...@@ -5,6 +5,7 @@ from type import Type
import sys, traceback import sys, traceback
from copy import copy from copy import copy
from cutils import run_cthunk
__excepthook = sys.excepthook __excepthook = sys.excepthook
...@@ -225,9 +226,27 @@ def clear_storage_thunk(stg): ...@@ -225,9 +226,27 @@ def clear_storage_thunk(stg):
thunk.inputs = [stg] thunk.inputs = [stg]
return thunk return thunk
def streamline(env, thunks, order, no_recycling = [], profiler = None): def streamline(env, thunks, order, no_recycling = [], profiler = None, nice_errors = True):
"""WRITEME""" """WRITEME
if profiler is None:
:param env:
:param thunks: the list of program instructions
:param order: the list of apply instances that gave rise to the thunks (same order as thunks)
:param no_recycling: storage elements that cannot be 'recycled' by repeatedly executing the
program. These storage elements are cleared before re-running.
:param profiler: deprecated
:param nice_errors: run in such a way that the double-traceback is printed. This costs a
bit of performance in the inner python loop.
"""
if profiler is not None:
raise NotImplementedError()
if nice_errors:
def f(): def f():
for x in no_recycling: for x in no_recycling:
x[0] = None x[0] = None
...@@ -237,14 +256,13 @@ def streamline(env, thunks, order, no_recycling = [], profiler = None): ...@@ -237,14 +256,13 @@ def streamline(env, thunks, order, no_recycling = [], profiler = None):
except: except:
raise_with_op(node) raise_with_op(node)
else: else:
# don't worry about raise_with_op, just go a little faster.
#there is a mix of python and c thunks
def f(): def f():
for x in no_recycling: for x in no_recycling:
x[0] = None x[0] = None
def g(): for thunk in thunks:
for thunk, node in zip(thunks, order): thunk()
profiler.profile_node(thunk, node)
profiler.profile_env(g, env)
f.profiler = profiler
return f return f
class LocalLinker(Linker): class LocalLinker(Linker):
......
...@@ -17,6 +17,9 @@ import sys ...@@ -17,6 +17,9 @@ import sys
_optimizer_idx = [0] _optimizer_idx = [0]
def _list_of_nodes(env):
return graph.io_toposort(env.inputs, env.outputs)
class Optimizer(object): class Optimizer(object):
"""WRITEME """WRITEME
An L{Optimizer} can be applied to an L{Env} to transform it. An L{Optimizer} can be applied to an L{Env} to transform it.
...@@ -73,7 +76,7 @@ class FromFunctionOptimizer(Optimizer): ...@@ -73,7 +76,7 @@ class FromFunctionOptimizer(Optimizer):
env.extend(toolbox.ReplaceValidate()) env.extend(toolbox.ReplaceValidate())
def optimizer(f): def optimizer(f):
"""WRITEME""" """decorator for FromFunctionOptimizer"""
return FromFunctionOptimizer(f) return FromFunctionOptimizer(f)
...@@ -137,6 +140,10 @@ class _metadict: ...@@ -137,6 +140,10 @@ class _metadict:
try: try:
self.d[item] = value self.d[item] = value
except: except:
for i, (key,val) in enumerate(self.l):
if key == item:
self.l[i] = (item, value)
return
self.l.append((item, value)) self.l.append((item, value))
def get(self, item, default): def get(self, item, default):
try: try:
...@@ -191,7 +198,7 @@ class MergeOptimizer(Optimizer): ...@@ -191,7 +198,7 @@ class MergeOptimizer(Optimizer):
cid[r] = i cid[r] = i
inv_cid[i] = r inv_cid[i] = r
for node in graph.io_toposort(env.inputs, env.outputs): for node in _list_of_nodes(env):
node_cid = (node.op, tuple([cid[input] for input in node.inputs])) node_cid = (node.op, tuple([cid[input] for input in node.inputs]))
dup = inv_cid.get(node_cid, None) dup = inv_cid.get(node_cid, None)
success = False success = False
...@@ -229,10 +236,33 @@ def MergeOptMerge(opt): ...@@ -229,10 +236,33 @@ def MergeOptMerge(opt):
### Local Optimizers ### ### Local Optimizers ###
######################## ########################
class LocalOptimizer(Optimizer, utils.object2): class LocalOptimizer(object):
"""WRITEME""" """A class for node-based optimizations.
Instances should implement the transform function,
and be passed to configure a env-based Optimizer instance.
"""
def __hash__(self):
if not hasattr(self, '_optimizer_idx'):
self._optimizer_idx = _optimizer_idx[0]
_optimizer_idx[0] += 1
return self._optimizer_idx
def transform(self, node): def transform(self, node):
"""Transform a subgraph whose output is `node`.
Subclasses should implement this function so that it returns one of two
kinds of things:
- False to indicate that no optimization can be applied to this `node`; or
- <list of results> to use in place of `node`'s outputs in the greater graph.
:type node: an Apply instance
"""
raise utils.AbstractFunctionError() raise utils.AbstractFunctionError()
...@@ -272,7 +302,7 @@ class LocalOptGroup(LocalOptimizer): ...@@ -272,7 +302,7 @@ class LocalOptGroup(LocalOptimizer):
return repl return repl
class LocalOpKeyOptGroup(LocalOptGroup): class _LocalOpKeyOptGroup(LocalOptGroup):
"""WRITEME""" """WRITEME"""
def __init__(self, optimizers): def __init__(self, optimizers):
...@@ -515,9 +545,29 @@ class PatternSub(LocalOptimizer): ...@@ -515,9 +545,29 @@ class PatternSub(LocalOptimizer):
class NavigatorOptimizer(Optimizer): class NavigatorOptimizer(Optimizer):
"""WRITEME""" """Abstract class
"""
def __init__(self, local_opt, ignore_newtrees = 'auto', failure_callback = None): def __init__(self, local_opt, ignore_newtrees = 'auto', failure_callback = None):
"""
:param local_opt: a LocalOptimizer to apply over a Env.
:param ignore_newtrees:
- True: new subgraphs returned by an optimization is not a candidate for optimization
- False: new subgraphs returned by an optimization is a candidate for optimization
- 'auto': let the local_opt set this parameter via its 'reentrant' attribute.
:param failure_callback:
a function that takes (exception, navigator, [(old, new),
(old,new),...]) and we call it if there's an exception.
If the trouble is from local_opt.transform(), the new variables will be 'None'.
If the trouble is from validation (the new types don't match for
example) then the new variables will be the ones created by
transform().
If this parameter is None, then exceptions are not caught here (raised normally).
"""
self.local_opt = local_opt self.local_opt = local_opt
if ignore_newtrees == 'auto': if ignore_newtrees == 'auto':
self.ignore_newtrees = not getattr(local_opt, 'reentrant', True) self.ignore_newtrees = not getattr(local_opt, 'reentrant', True)
...@@ -526,9 +576,18 @@ class NavigatorOptimizer(Optimizer): ...@@ -526,9 +576,18 @@ class NavigatorOptimizer(Optimizer):
self.failure_callback = failure_callback self.failure_callback = failure_callback
def attach_updater(self, env, importer, pruner, chin = None): def attach_updater(self, env, importer, pruner, chin = None):
"""Install some Env listeners to help the navigator deal with the ignore_trees-related functionality.
:param importer: function that will be called whenever when optimizations add stuff to the graph.
:param pruner: function to be called when optimizations remove stuff from graph.
:param chin: "on change input" called whenever an node's inputs change.
:returns: The Env plugin that handles the three tasks. Keep this around so that you can detach later!
"""
if self.ignore_newtrees: if self.ignore_newtrees:
importer = None importer = None
if importer is None and pruner is None: if importer is None and pruner is None:
return None return None
...@@ -542,12 +601,18 @@ class NavigatorOptimizer(Optimizer): ...@@ -542,12 +601,18 @@ class NavigatorOptimizer(Optimizer):
if chin is not None: if chin is not None:
def on_change_input(self, env, node, i, r, new_r): def on_change_input(self, env, node, i, r, new_r):
chin(node, i, r, new_r) chin(node, i, r, new_r)
u = Updater() u = Updater()
env.extend(u) env.extend(u)
return u return u
def detach_updater(self, env, u): def detach_updater(self, env, u):
"""Undo the work of attach_updater.
:param u: a return-value of attach_updater
:returns: None.
"""
if u is not None: if u is not None:
env.remove_feature(u) env.remove_feature(u)
...@@ -610,7 +675,7 @@ class TopoOptimizer(NavigatorOptimizer): ...@@ -610,7 +675,7 @@ class TopoOptimizer(NavigatorOptimizer):
except: except:
self.detach_updater(env, u) self.detach_updater(env, u)
raise raise
self.detach_updater(env, u)
class OpKeyOptimizer(NavigatorOptimizer): class OpKeyOptimizer(NavigatorOptimizer):
...@@ -642,6 +707,7 @@ class OpKeyOptimizer(NavigatorOptimizer): ...@@ -642,6 +707,7 @@ class OpKeyOptimizer(NavigatorOptimizer):
except: except:
self.detach_updater(env, u) self.detach_updater(env, u)
raise raise
self.detach_updater(env, u)
def add_requirements(self, env): def add_requirements(self, env):
""" """
...@@ -654,38 +720,70 @@ class OpKeyOptimizer(NavigatorOptimizer): ...@@ -654,38 +720,70 @@ class OpKeyOptimizer(NavigatorOptimizer):
# class EquilibriumOptimizer(NavigatorOptimizer): from utils import D
# """WRITEME"""
# def __init__(self, local_optimizers, failure_callback = None): class EquilibriumOptimizer(NavigatorOptimizer):
# NavigatorOptimizer.__init__(self, local_opt, ignore_newtrees, failure_callback) def __init__(self,
local_optimizers,
# def apply(self, env): failure_callback = None,
# op = self.local_opt.op_key() max_depth = None,
# if isinstance(op, (list, tuple)): max_use_ratio = None):
# q = reduce(list.__iadd__, map(env.get_nodes, op)) """
# else: :param max_use_ratio: each optimizer can be applied at most (size of graph * this number)
# q = list(env.get_nodes(op))
# def importer(node):
# if node.op == op: q.append(node)
# def pruner(node):
# if node is not current_node and node.op == op:
# try: q.remove(node)
# except ValueError: pass
# u = self.attach_updater(env, importer, pruner)
# try:
# while q:
# node = q.pop()
# current_node = node
# self.process_node(env, node)
# except:
# self.detach_updater(env, u)
# raise
"""
from utils import D super(EquilibriumOptimizer, self).__init__(
None,
ignore_newtrees = True,
failure_callback = failure_callback)
class EquilibriumOptimizer(NavigatorOptimizer): self.local_optimizers = local_optimizers
self.max_depth = max_depth
self.max_use_ratio = max_use_ratio
def apply(self, env, start_from = None):
if start_from is None:
start_from = env.outputs
changed = True
max_use_abort = False
process_count = {}
while changed and not max_use_abort:
changed = False
q = deque(graph.io_toposort(env.inputs, start_from))
max_use = len(q) * self.max_use_ratio
def importer(node):
q.append(node)
def pruner(node):
if node is not current_node:
try: q.remove(node)
except ValueError: pass
u = self.attach_updater(env, importer, pruner)
try:
while q:
node = q.pop()
current_node = node
for lopt in self.local_optimizers:
process_count.setdefault(lopt, 0)
if process_count[lopt] > max_use:
max_use_abort = True
else:
lopt_change = self.process_node(env, node, lopt)
process_count[lopt] += 1 if lopt_change else 0
changed |= lopt_change
except:
self.detach_updater(env, u)
raise
self.detach_updater(env, u)
if max_use_abort:
print >> sys.stderr, "WARNING: EquilibriumOptimizer max'ed out"
class _EquilibriumOptimizer(NavigatorOptimizer):
def __init__(self, def __init__(self,
local_optimizers, local_optimizers,
...@@ -780,10 +878,11 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -780,10 +878,11 @@ class EquilibriumOptimizer(NavigatorOptimizer):
# importer(node) # importer(node)
for node in env.nodes: for node in env.toposort():
tasks[node].extend(lopt for track, i, lopt in self.fetch_tracks0(node.op)) tasks[node].extend(lopt for track, i, lopt in self.fetch_tracks0(node.op))
u = self.attach_updater(env, importer, pruner, chin) u = self.attach_updater(env, importer, pruner, chin)
print 'KEYS', map(hash, tasks.keys())
while tasks: while tasks:
for node in tasks.iterkeys(): for node in tasks.iterkeys():
todo = tasks.pop(node) todo = tasks.pop(node)
......
...@@ -18,7 +18,7 @@ class DB(object): ...@@ -18,7 +18,7 @@ class DB(object):
# N.B. obj is not an instance of class Optimizer. # N.B. obj is not an instance of class Optimizer.
# It is an instance of a DB.In the tests for example, # It is an instance of a DB.In the tests for example,
# this is not always the case. # this is not always the case.
if not isinstance(obj, (DB, opt.Optimizer)): if not isinstance(obj, (DB, opt.Optimizer, opt.LocalOptimizer)):
raise Exception('wtf', obj) raise Exception('wtf', obj)
obj.name = name obj.name = name
......
...@@ -375,7 +375,7 @@ class TestEquilibrium(object): ...@@ -375,7 +375,7 @@ class TestEquilibrium(object):
x, y, z = map(MyResult, 'xyz') x, y, z = map(MyResult, 'xyz')
e = op3(op4(x, y)) e = op3(op4(x, y))
g = Env([x, y, z], [e]) g = Env([x, y, z], [e])
print g print 'before', g
sys.stderr = sys.stdout # display pesky warnings along with stdout sys.stderr = sys.stdout # display pesky warnings along with stdout
opt = EquilibriumOptimizer( opt = EquilibriumOptimizer(
[PatternSub((op1, 'x', 'y'), (op2, 'x', 'y')), [PatternSub((op1, 'x', 'y'), (op2, 'x', 'y')),
...@@ -384,7 +384,7 @@ class TestEquilibrium(object): ...@@ -384,7 +384,7 @@ class TestEquilibrium(object):
], ],
max_use_ratio = 1. / len(g.nodes)) # each opt can only be applied once max_use_ratio = 1. / len(g.nodes)) # each opt can only be applied once
opt.optimize(g) opt.optimize(g)
print g print 'after', g
assert str(g) == '[Op4(x, y)]' assert str(g) == '[Op4(x, y)]'
......
...@@ -2,6 +2,7 @@ import gof #, gof.result ...@@ -2,6 +2,7 @@ import gof #, gof.result
import numpy #for numeric_grad import numpy #for numeric_grad
from gof.python25 import all from gof.python25 import all
import gof.utils
_msg_retType = 'op.grad(...) returned a non-list' _msg_retType = 'op.grad(...) returned a non-list'
_msg_badlen = 'op.grad(...) returned wrong number of gradients' _msg_badlen = 'op.grad(...) returned wrong number of gradients'
...@@ -55,17 +56,17 @@ def grad_sources_inputs(sources, graph_inputs): ...@@ -55,17 +56,17 @@ def grad_sources_inputs(sources, graph_inputs):
else: else:
gmap[r] = g_r gmap[r] = g_r
graph_outputs = gmap.keys() graph_outputs = gof.utils.uniq([r for r,g in sources])
if graph_inputs is None: if graph_inputs is None:
graph_inputs = gof.graph.inputs(graph_outputs) graph_inputs = gof.graph.inputs(graph_outputs)
for node in gof.graph.io_toposort(graph_inputs, graph_outputs).__reversed__(): for node in gof.graph.io_toposort(graph_inputs, graph_outputs).__reversed__():
g_outputs = [gmap.get(o,None) for o in node.outputs] g_outputs = [gmap.get(o,None) for o in node.outputs]
#if all output gradients are None, continue #if all output gradients are None, continue
if all(map(lambda x:x is None, g_outputs)): continue if all(map(lambda x:x is None, g_outputs)): continue
output_arg = g_outputs output_arg = g_outputs
input_arg = node.inputs input_arg = node.inputs
......
...@@ -235,17 +235,27 @@ class PPrinter: ...@@ -235,17 +235,27 @@ class PPrinter:
else: else:
raise TypeError('Not enough arguments to call.') raise TypeError('Not enough arguments to call.')
use_ascii = True
if use_ascii:
special = dict(middle_dot = u"\u00B7", special = dict(middle_dot = "\dot",
big_sigma = u"\u03A3") big_sigma = "\Sigma")
greek = dict(alpha = u"\u03B1", greek = dict(alpha = "\alpha",
beta = u"\u03B2", beta = "\beta",
gamma = u"\u03B3", gamma = "\gamma",
delta = u"\u03B4", delta = "\delta",
epsilon = u"\u03B5") epsilon = "\epsilon")
else:
special = dict(middle_dot = u"\u00B7",
big_sigma = u"\u03A3")
greek = dict(alpha = u"\u03B1",
beta = u"\u03B2",
gamma = u"\u03B3",
delta = u"\u03B4",
epsilon = u"\u03B5")
pprint = PPrinter() pprint = PPrinter()
......
...@@ -2,6 +2,7 @@ from __future__ import absolute_import ...@@ -2,6 +2,7 @@ from __future__ import absolute_import
import time import time
import numpy import numpy
from ..gof.cutils import run_cthunk
from ..gof.link import WrapLinker from ..gof.link import WrapLinker
from ..compile.mode import Mode from ..compile.mode import Mode
...@@ -107,19 +108,42 @@ class ProfileMode(Mode): ...@@ -107,19 +108,42 @@ class ProfileMode(Mode):
local_time = [0.0] local_time = [0.0]
apply_time = {} apply_time = {}
op_time = {} op_time = {}
op_cimpl = {}
def blah(i, node, *thunks): def blah(i, node, *thunks):
t0 = time.time() if 0:
for th in thunks: t0 = time.time()
th() for th in thunks:
dt = time.time() - t0 th()
dt = time.time() - t0
elif 0: #more precise timing
for th in thunks:
t0 = time.time()
th()
dt = time.time() - t0
elif 1:
for th in thunks:
if hasattr(th, 'cthunk'):
t0 = time.time()
run_cthunk(th.cthunk)
dt = time.time() - t0
else:
t0 = time.time()
th()
dt = time.time() - t0
elif 1:
pass
else:
raise Exception('one of the cases has to run the thunks!')
local_time[0] += dt local_time[0] += dt
apply_time[(i,node.op)] = apply_time.get((i,node.op), 0.0) + dt apply_time[(i,node.op)] = apply_time.get((i,node.op), 0.0) + dt
op_time[node.op] = op_time.get(node.op, 0.0) + dt op_time[node.op] = op_time.get(node.op, 0.0) + dt
op_cimpl[node.op] = hasattr(thunks[0], 'cthunk')
self.local_time = local_time self.local_time = local_time
self.apply_time = apply_time self.apply_time = apply_time
self.op_time = op_time self.op_time = op_time
self.op_cimpl = op_cimpl
wrap_linker = WrapLinkerMany([linker], [blah]) wrap_linker = WrapLinkerMany([linker], [blah])
if optimizer: if optimizer:
...@@ -142,13 +166,20 @@ class ProfileMode(Mode): ...@@ -142,13 +166,20 @@ class ProfileMode(Mode):
atimes.sort() atimes.sort()
atimes.reverse() atimes.reverse()
for t,a in atimes[:15]: for t,a in atimes[:15]:
print ' ', t, a print '\t%.3f\t%i\t%s' % (t, a[0], a[1])
print ' ... (ignoring %i other Apply instances)'%max(0, len(atimes)-15) print ' ... (remaining %i Apply instances account for %.2f of the runtime)'\
%(max(0, len(atimes)-15), sum(t for t, a in atimes[15:]))
n_ops_to_print = 20
print 'Op-wise summary: <fraction of local_time spent on this kind of Op> <Op name>' print 'Op-wise summary: <fraction of local_time spent on this kind of Op> <Op name>'
otimes = [(t/local_time, a) for a, t in op_time.items()] otimes = [(t/local_time, a, self.op_cimpl[a]) for a, t in op_time.items()]
otimes.sort() otimes.sort()
otimes.reverse() otimes.reverse()
for t,a in otimes[:15]: for t,a,ci in otimes[:n_ops_to_print]:
print ' ', t, a print '\t%.3f\t%s %s' % (t, '*' if ci else ' ', a)
print ' ... (ignoring %i other kinds Ops)'%max(0, len(otimes)-15) print ' ... (remaining %i Ops account for %.2f of the runtime)'\
%(max(0, len(otimes)-n_ops_to_print), sum(t for t, a, ci in
otimes[n_ops_to_print:]))
print '(*) Op is running a c implementation'
...@@ -1089,38 +1089,9 @@ pprint.assign(pow, printing.OperatorPrinter('**', 1, 'right')) ...@@ -1089,38 +1089,9 @@ pprint.assign(pow, printing.OperatorPrinter('**', 1, 'right'))
# View Operations # View Operations
########################## ##########################
class TransposeInplace(Op):
view_map = {0: [0]}
def make_node(self, input):
return Apply(self, [input], [tensor(dtype = input.type.dtype,
broadcastable = reversed(input.type.broadcastable))])
def perform(self, node, (x, ), (z, )):
z[0] = x.T
def grad(self, (x,), (gz,)):
return transpose(gz),
def c_code(self, node, name, (x, ), (z, ), sub):
return """
PyArrayObject* transposed = (PyArrayObject*)PyArray_Transpose(%(x)s, NULL);
if (%(z)s) {
Py_XDECREF(%(z)s);
}
%(z)s = transposed;
""" % locals()
def __str__(self):
return "TransposeView"
_transpose_inplace = TransposeInplace()
def transpose(x, **kwargs): def transpose(x, **kwargs):
"""WRITEME""" dims = range(x.ndim-1, -1, -1)
return _transpose_inplace(tensor_copy(x), **kwargs) return DimShuffle(x.broadcastable, dims, inplace=True)(tensor_copy(x))
class Subtensor(Op): class Subtensor(Op):
...@@ -1781,6 +1752,7 @@ class Dot(Op): ...@@ -1781,6 +1752,7 @@ class Dot(Op):
# The error raised by numpy has no shape information, we mean to add that # The error raised by numpy has no shape information, we mean to add that
e.args = e.args + (x.shape, y.shape) e.args = e.args + (x.shape, y.shape)
raise raise
def grad(self, (x, y), (gz,)): def grad(self, (x, y), (gz,)):
if gz.type.ndim == 0: if gz.type.ndim == 0:
return gz * y, gz * x return gz * y, gz * x
......
...@@ -103,16 +103,18 @@ class DimShuffle(Op): ...@@ -103,16 +103,18 @@ class DimShuffle(Op):
for i, b in enumerate(input_broadcastable): for i, b in enumerate(input_broadcastable):
if i not in new_order: if i not in new_order:
# we want to drop this dimension because it's not a value in new_order # we want to drop this dimension because it's not a value in new_order
if b == 1: if b == 1: # 1 aka True
self.drop.append(i) self.drop.append(i)
else: else:
# we cannot drop non-broadcastable dimensions # we cannot drop non-broadcastable dimensions
raise NotImplementedError("You cannot drop a non-broadcastable dimension.") raise ValueError("You cannot drop a non-broadcastable dimension.")
else: else:
i2j[i] = j i2j[i] = j
j += 1 j += 1
# transposition of non-broadcastable dimensions # transposition of non-broadcastable dimensions
# This is how the dimensions will be permuted, without accounting for the extra
# 'x' broadcastable dimensions to insert.
self.shuffle = [i2j[x] for x in new_order if x != 'x'] self.shuffle = [i2j[x] for x in new_order if x != 'x']
# list of dimensions of the output that are broadcastable and were not in the original input # list of dimensions of the output that are broadcastable and were not in the original input
...@@ -144,7 +146,8 @@ class DimShuffle(Op): ...@@ -144,7 +146,8 @@ class DimShuffle(Op):
and self.input_broadcastable == other.input_broadcastable and self.input_broadcastable == other.input_broadcastable
def __hash__(self): def __hash__(self):
return hash(self.inplace) ^ hash(self.new_order) ^ hash(self.input_broadcastable) return hash(type(self)) ^ hash(self.inplace) \
^ hash(self.new_order) ^ hash(self.input_broadcastable)
def __str__(self): def __str__(self):
if self.inplace: if self.inplace:
...@@ -175,13 +178,78 @@ class DimShuffle(Op): ...@@ -175,13 +178,78 @@ class DimShuffle(Op):
storage[0] = res storage[0] = res
def c_code(self, node, name, (input,), (res,), sub):
def statements(lst):
return ';\n'.join(lst) + ';'
nd_in = len(self.input_broadcastable)
nd_out = len(self.new_order)
check_input_nd = [('if (%(input)s->nd != ' + str(nd_in) + ')'
'{PyErr_SetString(PyExc_NotImplementedError, "input nd"); %(fail)s;}')]
clear_output = ['if (%(res)s) {Py_XDECREF(%(res)s);}']
shape_statements = ['npy_intp dimensions[%i]'%nd_out]
shape_statements += [('dimensions['+str(i)+'] = %(input)s->dimensions['+str(o)+']')
if o != 'x' else
('dimensions['+str(i)+'] = 1')
for i, o in enumerate(self.new_order)]
strides_statements = ['npy_intp strides[%i]'%nd_out]
strides_statements += [('strides['+str(i)+'] = %(input)s->strides['+str(o)+']')
if o != 'x' else
('strides['+str(i)+'] = 0')
for i, o in enumerate(self.new_order)]
if self.inplace:
get_base = ['{ PyArrayObject * base = %(input)s', 'Py_INCREF((PyObject*)base)']
else:
get_base = [('{ PyArrayObject * base = (PyArrayObject*)PyArray_FromAny((PyObject*)%(input)s, NULL,'
'0, 0, NPY_ALIGNED|NPY_ENSURECOPY, NULL)')]
alloc_output = [('%(res)s = (PyArrayObject*)PyArray_New(&PyArray_Type, '
'' + str(nd_out) + ', dimensions, '
'PyArray_TYPE(base), strides, '
'base->data, base->descr->elsize, '
'PyArray_FLAGS(base), NULL)'),
'%(res)s->base = (PyObject*)base',
'}']
full_code = statements(check_input_nd
+ clear_output
+ shape_statements
+ strides_statements
+ get_base
+ alloc_output)
if 0:
print 'C_CODE'
print ''
print self
print "IN BROAD", self.input_broadcastable
print "NEW ORDER", self.new_order
print "SHUFFLE", self.shuffle
print "AUGMENT", self.augment
print '------------'
print ''
print full_code
if 0:
import sys
sys.exit()
return full_code % dict(locals(), **sub)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
gz = as_tensor(gz) gz = as_tensor(gz)
grad_order = ['x'] * len(x.type.broadcastable) grad_order = ['x'] * len(x.type.broadcastable)
for i, v in enumerate(self.new_order): for i, v in enumerate(self.new_order):
if v != 'x': if v != 'x':
grad_order[v] = i grad_order[v] = i
return DimShuffle(gz.type.broadcastable, grad_order)(gz), return [DimShuffle(gz.type.broadcastable, grad_order, inplace=True)(Elemwise(scalar.identity)(gz))]
......
from basic import _scal_elemwise, _transpose_inplace from basic import _scal_elemwise #, _transpose_inplace
from .. import scalar as scal from .. import scalar as scal
import elemwise import elemwise
from .. import printing from .. import printing
...@@ -183,9 +183,11 @@ pprint.assign(div_inplace, printing.OperatorPrinter('/=', -1, 'left')) ...@@ -183,9 +183,11 @@ pprint.assign(div_inplace, printing.OperatorPrinter('/=', -1, 'left'))
pprint.assign(pow_inplace, printing.OperatorPrinter('**=', 1, 'right')) pprint.assign(pow_inplace, printing.OperatorPrinter('**=', 1, 'right'))
transpose_inplace = _transpose_inplace def transpose_inplace(x, **kwargs):
"""WRITEME""" """Perform a transpose on a tensor without copying the underlying storage"""
dims = range(x.ndim-1, -1, -1)
return elemwise.DimShuffle(x.broadcastable, dims, inplace=True)(x)
pprint.assign(transpose_inplace, printing.MemberPrinter('T')) #pprint.assign(transpose_inplace, printing.MemberPrinter('T'))
差异被折叠。
...@@ -662,56 +662,6 @@ class T_max_and_argmax(unittest.TestCase): ...@@ -662,56 +662,6 @@ class T_max_and_argmax(unittest.TestCase):
self.failUnless(i.shape == (2,3)) self.failUnless(i.shape == (2,3))
class T_transpose(unittest.TestCase):
def test0(self):
n = as_tensor(numpy.ones(()))
t = transpose(n)
self.failUnless(t.owner.op == inplace.transpose_inplace)
f = function([n], t)
tval = f(n.data)
self.failUnless(tval.shape == n.data.shape)
#test aliasing
tval += 55.0
self.failUnless(n.data == 1.0)
def test1(self):
n = as_tensor(numpy.ones(5))
t = transpose(n)
self.failUnless(t.owner.op == inplace.transpose_inplace)
f = function([n], t)
tval = f(n.data)
self.failUnless(tval.shape == n.data.shape)
#test aliasing
tval += 55.0
self.failUnless(n.data[0] == 1.0)
def test2(self):
n = as_tensor(numpy.ones((5,3)))
t = transpose(n)
self.failUnless(t.owner.op == inplace.transpose_inplace)
f = function([n], t)
tval = f(n.data)
self.failUnless(tval.shape == (3,5))
#test aliasing
tval += 55.0
self.failUnless(n.data[0,0] == 1.0)
def test3(self):
"""Test transpose of tensor, inplace version"""
n = as_tensor(numpy.ones((5,3,2)))
t = inplace.transpose_inplace(n)
self.failUnless(t.owner.op == inplace.transpose_inplace)
f = function([n], t)
tval = f(n.data)
self.failUnless(tval.shape == (2,3,5))
#test aliasing
tval += 55.0
self.failUnless(n.data[0,0,0] == 56.0)
def test_grad(self):
verify_grad(self, inplace.transpose_inplace, [numpy.random.rand(2, 3)])
verify_grad(self, inplace.transpose_inplace, [numpy.ones(3)])
class T_subtensor(unittest.TestCase): class T_subtensor(unittest.TestCase):
def setUp(self): def setUp(self):
Subtensor.debug = False Subtensor.debug = False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论