提交 4f5945fe authored 作者: James Bergstra's avatar James Bergstra

merged

aa.x : aa.cc
g++ -O3 -ffast-math aa.cc -o aa.x -L${PUB_PREFIX}/lib -lgsl -lcblas -lgoto -lgfortran -lm
g++ -O3 -ffast-math -ftree-vectorize aa.cc -o aa.x -L${PUB_PREFIX}/lib -lgsl -lmkl
#g++ aa.cc -o aa.x -L${PUB_PREFIX}/lib -lgsl -lmkl
clean :
rm aa.x
......@@ -28,6 +28,7 @@ int main(int argc, char **argv)
int neg = strtol(argv[1], 0, 0);
int nout = strtol(argv[2], 0, 0);
int nin = nout;
int nhid = strtol(argv[3], 0, 0);
int niter = strtol(argv[4], 0, 0);
double lr = 0.01;
......@@ -35,8 +36,8 @@ int main(int argc, char **argv)
gsl_rng_set(rng, 234);
gsl_matrix * x = gsl_matrix_alloc(neg, nout);
gsl_matrix * w = gsl_matrix_alloc(nout, nhid);
gsl_matrix * x = gsl_matrix_alloc(neg, nin);
gsl_matrix * w = gsl_matrix_alloc(nin, nhid);
gsl_vector * a = gsl_vector_alloc(nhid);
gsl_vector * b = gsl_vector_alloc(nout);
gsl_matrix * xw = gsl_matrix_alloc(neg, nhid);
......@@ -59,11 +60,17 @@ int main(int argc, char **argv)
struct timeval tv0, tv1;
struct timeval tdot0, tdot1;
double time_of_dot = 0.0;
gettimeofday(&tv0, 0);
double err = 0.0;
for (int iter = 0; iter < niter; ++iter)
{
gettimeofday(&tdot0, 0);
gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, x, w, 0.0, xw);
gettimeofday(&tdot1, 0);
time_of_dot += pytime(&tdot1) - pytime(&tdot0);
for (int i = 0; i < neg; ++i)
for (int j = 0; j < nhid; ++j)
......@@ -72,7 +79,10 @@ int main(int argc, char **argv)
hid->data[i*nhid+j] = tanh(act);
}
gettimeofday(&tdot0, 0);
gsl_blas_dgemm(CblasNoTrans, CblasTrans, 1.0, hid, w, 0.0, hidwt);
gettimeofday(&tdot1, 0);
time_of_dot += pytime(&tdot1) - pytime(&tdot0);
for (int i = 0; i < nout; ++i) g_b->data[i] = 0.0;
err = 0.0;
......@@ -90,8 +100,11 @@ int main(int argc, char **argv)
if (1)
{
gettimeofday(&tdot0, 0);
gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, g_hidwt, w, 0.0, g_hid);
gsl_blas_dgemm(CblasTrans, CblasNoTrans, 1.0, g_hidwt, hid, 0.0, g_w);
gettimeofday(&tdot1, 0);
time_of_dot += pytime(&tdot1) - pytime(&tdot0);
for (int i = 0; i < neg; ++i)
......@@ -101,14 +114,19 @@ int main(int argc, char **argv)
a->data[j] -= lr * g_hid->data[i*nhid+j];
}
gettimeofday(&tdot0, 0);
gsl_blas_dgemm(CblasTrans, CblasNoTrans, -lr, x, g_hid, 1.0, w);
gettimeofday(&tdot1, 0);
time_of_dot += pytime(&tdot1) - pytime(&tdot0);
for (int i = 0; i < nout*nhid; ++i) w->data[i] -= lr * g_w->data[i];
}
}
gettimeofday(&tv1, 0);
fprintf(stdout, "took = %lfs to get err %lf\n", pytime(&tv1) - pytime(&tv0), 0.5 * err);
double total_time = pytime(&tv1) - pytime(&tv0);
fprintf(stdout, "took = %lfs to get err %lf\n", total_time, 0.5 * err);
fprintf(stdout, "... of which %.2lfs was spent in dgemm (fraction: %.2lf)\n", time_of_dot, time_of_dot / total_time);
//skip freeing
return 0;
}
......
......@@ -11,6 +11,9 @@ import theano.sandbox.wraplinker
from theano.compile import module, Mode
from theano.sandbox.wraplinker import ProfileMode
# numpy: aa_numpy.py
# c : aa.cc
if 0:
class Opt(object):
merge = theano.gof.MergeOptimizer()
......@@ -131,7 +134,7 @@ if 0:
self.merge(env)
def linker(print_prog=True):
def print_graph_linker(print_prog=True):
if 1:
imap = {None:'-'}
def blah(i, node, thunk):
......@@ -146,7 +149,6 @@ def linker(print_prog=True):
print 'node ', i, node,
print ':'.join([imap[inp.owner] for inp in node.inputs])
#print theano.sandbox.pprint.pp.process_graph(inputs, outputs)
return theano.sandbox.wraplinker.WrapLinkerMany(
[theano.gof.OpWiseCLinker()],
[theano.sandbox.wraplinker.run_all
......@@ -184,8 +186,9 @@ class M(module.Module):
self.step = module.Method([x], err, updates=dict(updates))
mod = M()
#m = mod.make(mode='FAST_RUN')
#mode = 'FAST_RUN'
mode = ProfileMode(optimizer='fast_run', linker=theano.gof.OpWiseCLinker())
print mod.pretty(mode=mode)
m = mod.make(mode=mode)
neg, nout, nhid, niter = [int(a) for a in sys.argv[1:]]
......@@ -200,5 +203,10 @@ t = time.time()
for i in xrange(niter):
err = m.step(x)
print 'time: ',time.time() - t, 'err: ', err
mode.print_summary()
try:
mode.print_summary()
pass
except:
raise
......@@ -4,6 +4,8 @@ import numpy as N
import sys
import time
# c: aa.cc
neg, nout, nhid, niter = [int(a) for a in sys.argv[1:]]
lr = 0.01
......@@ -14,12 +16,20 @@ a = rng.randn(nhid) * 0.0
b = rng.randn(nout) * 0.0
x = (rng.rand(neg, nout)-0.5) * 1.5
dot_time = 0.0
t = time.time()
for i in xrange(niter):
hid = N.tanh(N.dot(x, w) + a)
tt = time.time()
d = N.dot(x, w)
dot_time += time.time() - tt
hid = N.tanh(d + a)
out = N.tanh(N.dot(hid, w.T) + b)
tt = time.time()
d = N.dot(hid, w.T)
dot_time += time.time() - tt
out = N.tanh(d + b)
g_out = out - x
err = 0.5 * N.sum(g_out**2)
......@@ -28,12 +38,23 @@ for i in xrange(niter):
b -= lr * N.sum(g_hidwt, axis=0)
tt = time.time()
g_hid = N.dot(g_hidwt, w)
dot_time += time.time() - tt
g_hidin = g_hid * (1.0 - hid**2)
w -= lr * (N.dot(g_hidwt.T, hid) + N.dot(x.T, g_hidin))
tt = time.time()
d = N.dot(g_hidwt.T, hid)
dd = N.dot(x.T, g_hidin)
dot_time += time.time() - tt
gw = (d + dd)
w -= lr * gw
a -= lr * N.sum(g_hidin, axis=0)
print 'time: ',time.time() - t, 'err: ', err
total_time = time.time() - t
print 'time: ',total_time, 'err: ', err
print ' of which', dot_time, 'was spent on dot. Fraction:', dot_time / total_time
......@@ -23,11 +23,12 @@ from op import \
from opt import \
Optimizer, optimizer, SeqOptimizer, \
MergeOptimizer, MergeOptMerge, \
LocalOptimizer, local_optimizer, LocalOptGroup, LocalOpKeyOptGroup, \
LocalOptimizer, local_optimizer, LocalOptGroup, \
OpSub, OpRemove, PatternSub, \
NavigatorOptimizer, TopoOptimizer, OpKeyOptimizer, EquilibriumOptimizer, \
NavigatorOptimizer, TopoOptimizer, EquilibriumOptimizer, \
keep_going, warn, \
InplaceOptimizer, PureThenInplaceOptimizer
#LocalOpKeyOptGroup, OpKeyOptimizer
from optdb import \
DB, Query, \
......
......@@ -736,6 +736,7 @@ def _execute(cthunk, init_tasks, tasks, error_storage):
else:
return tasks[failure_code - n]
def execute():
execute.cthunk = cthunk
failure = cutils.run_cthunk(cthunk)
if failure:
task, taskname, id = find_task(failure)
......
......@@ -13,6 +13,7 @@ from collections import deque
import utils
_creation_idx = [0]
class Apply(utils.object2):
"""
......@@ -121,6 +122,13 @@ class Apply(utils.object2):
def __asapply__(self):
return self
    def __hash__(self):
        # Lazily assign a unique creation index from the module-level counter
        # `_creation_idx` the first time this Apply is hashed, and cache it on
        # the instance so the hash value never changes afterwards.
        if not hasattr(self, '_creation_idx'):
            self._creation_idx = _creation_idx[0]
            _creation_idx[0] += 1
        return self._creation_idx
def clone(self):
"""Duplicate this Apply instance with inputs = self.inputs.
......@@ -567,7 +575,10 @@ def general_toposort(r_out, deps, debug_print = False):
deps(i) should behave like a pure function (no funny business with internal state)
:note:
deps(i) can/should be cached by the deps function to be fast
deps(i) will be cached by this function (to be fast)
:note:
The order of the return value list is determined by the order of nodes returned by the deps() function.
"""
deps_cache = {}
def _deps(io):
......@@ -611,8 +622,9 @@ def general_toposort(r_out, deps, debug_print = False):
def io_toposort(i, o, orderings = {}):
"""WRITEME
"""
#the inputs are used only here in the function that decides what 'predecessors' to explore
iset = set(i)
def deps(obj):
def deps(obj):
rval = []
if obj not in iset:
if isinstance(obj, Result):
......
......@@ -17,6 +17,9 @@ import sys
_optimizer_idx = [0]
def _list_of_nodes(env):
    """Return `env`'s Apply nodes, ordered by graph.io_toposort(env.inputs, env.outputs)."""
    return graph.io_toposort(env.inputs, env.outputs)
class Optimizer(object):
"""WRITEME
An L{Optimizer} can be applied to an L{Env} to transform it.
......@@ -73,7 +76,7 @@ class FromFunctionOptimizer(Optimizer):
env.extend(toolbox.ReplaceValidate())
def optimizer(f):
    """Decorator: wrap function `f` in a FromFunctionOptimizer instance.

    `f` should have the signature expected by FromFunctionOptimizer's apply
    mechanism; the returned object can be registered wherever an Optimizer
    is accepted.
    """
    # NOTE(review): the source carried two stacked docstring lines (diff
    # residue); the second was a dead string statement and has been merged
    # into a single docstring.
    return FromFunctionOptimizer(f)
......@@ -137,6 +140,10 @@ class _metadict:
try:
self.d[item] = value
except:
for i, (key,val) in enumerate(self.l):
if key == item:
self.l[i] = (item, value)
return
self.l.append((item, value))
def get(self, item, default):
try:
......@@ -191,7 +198,7 @@ class MergeOptimizer(Optimizer):
cid[r] = i
inv_cid[i] = r
for node in graph.io_toposort(env.inputs, env.outputs):
for node in _list_of_nodes(env):
node_cid = (node.op, tuple([cid[input] for input in node.inputs]))
dup = inv_cid.get(node_cid, None)
success = False
......@@ -229,10 +236,33 @@ def MergeOptMerge(opt):
### Local Optimizers ###
########################
class LocalOptimizer(Optimizer, utils.object2):
"""WRITEME"""
class LocalOptimizer(object):
    """A class for node-based optimizations.

    Instances should implement the transform function,
    and be passed to configure a env-based Optimizer instance.
    """

    def __hash__(self):
        # Lazily assign a unique, stable id from the module-level counter
        # `_optimizer_idx` on first hash; cache it on the instance so the
        # hash value never changes afterwards.
        if not hasattr(self, '_optimizer_idx'):
            self._optimizer_idx = _optimizer_idx[0]
            _optimizer_idx[0] += 1
        return self._optimizer_idx

    def transform(self, node):
        """Transform a subgraph whose output is `node`.

        Subclasses should implement this function so that it returns one of two
        kinds of things:

        - False to indicate that no optimization can be applied to this `node`; or
        - <list of results> to use in place of `node`'s outputs in the greater graph.

        :type node: an Apply instance
        """
        # Abstract: concrete optimizers must override this.
        raise utils.AbstractFunctionError()
......@@ -272,7 +302,7 @@ class LocalOptGroup(LocalOptimizer):
return repl
class LocalOpKeyOptGroup(LocalOptGroup):
class _LocalOpKeyOptGroup(LocalOptGroup):
"""WRITEME"""
def __init__(self, optimizers):
......@@ -515,9 +545,29 @@ class PatternSub(LocalOptimizer):
class NavigatorOptimizer(Optimizer):
"""WRITEME"""
"""Abstract class
"""
def __init__(self, local_opt, ignore_newtrees = 'auto', failure_callback = None):
"""
:param local_opt: a LocalOptimizer to apply over a Env.
:param ignore_newtrees:
- True: new subgraphs returned by an optimization is not a candidate for optimization
- False: new subgraphs returned by an optimization is a candidate for optimization
- 'auto': let the local_opt set this parameter via its 'reentrant' attribute.
:param failure_callback:
a function that takes (exception, navigator, [(old, new),
(old,new),...]) and we call it if there's an exception.
If the trouble is from local_opt.transform(), the new variables will be 'None'.
If the trouble is from validation (the new types don't match for
example) then the new variables will be the ones created by
transform().
If this parameter is None, then exceptions are not caught here (raised normally).
"""
self.local_opt = local_opt
if ignore_newtrees == 'auto':
self.ignore_newtrees = not getattr(local_opt, 'reentrant', True)
......@@ -526,9 +576,18 @@ class NavigatorOptimizer(Optimizer):
self.failure_callback = failure_callback
def attach_updater(self, env, importer, pruner, chin = None):
"""Install some Env listeners to help the navigator deal with the ignore_trees-related functionality.
:param importer: function that will be called whenever when optimizations add stuff to the graph.
:param pruner: function to be called when optimizations remove stuff from graph.
:param chin: "on change input" called whenever an node's inputs change.
:returns: The Env plugin that handles the three tasks. Keep this around so that you can detach later!
"""
if self.ignore_newtrees:
importer = None
if importer is None and pruner is None:
return None
......@@ -542,12 +601,18 @@ class NavigatorOptimizer(Optimizer):
if chin is not None:
def on_change_input(self, env, node, i, r, new_r):
chin(node, i, r, new_r)
u = Updater()
env.extend(u)
return u
def detach_updater(self, env, u):
"""Undo the work of attach_updater.
:param u: a return-value of attach_updater
:returns: None.
"""
if u is not None:
env.remove_feature(u)
......@@ -610,7 +675,7 @@ class TopoOptimizer(NavigatorOptimizer):
except:
self.detach_updater(env, u)
raise
self.detach_updater(env, u)
class OpKeyOptimizer(NavigatorOptimizer):
......@@ -642,6 +707,7 @@ class OpKeyOptimizer(NavigatorOptimizer):
except:
self.detach_updater(env, u)
raise
self.detach_updater(env, u)
def add_requirements(self, env):
"""
......@@ -654,38 +720,70 @@ class OpKeyOptimizer(NavigatorOptimizer):
# class EquilibriumOptimizer(NavigatorOptimizer):
# """WRITEME"""
from utils import D
# def __init__(self, local_optimizers, failure_callback = None):
# NavigatorOptimizer.__init__(self, local_opt, ignore_newtrees, failure_callback)
# def apply(self, env):
# op = self.local_opt.op_key()
# if isinstance(op, (list, tuple)):
# q = reduce(list.__iadd__, map(env.get_nodes, op))
# else:
# q = list(env.get_nodes(op))
# def importer(node):
# if node.op == op: q.append(node)
# def pruner(node):
# if node is not current_node and node.op == op:
# try: q.remove(node)
# except ValueError: pass
# u = self.attach_updater(env, importer, pruner)
# try:
# while q:
# node = q.pop()
# current_node = node
# self.process_node(env, node)
# except:
# self.detach_updater(env, u)
# raise
class EquilibriumOptimizer(NavigatorOptimizer):
def __init__(self,
local_optimizers,
failure_callback = None,
max_depth = None,
max_use_ratio = None):
"""
:param max_use_ratio: each optimizer can be applied at most (size of graph * this number)
"""
from utils import D
super(EquilibriumOptimizer, self).__init__(
None,
ignore_newtrees = True,
failure_callback = failure_callback)
class EquilibriumOptimizer(NavigatorOptimizer):
self.local_optimizers = local_optimizers
self.max_depth = max_depth
self.max_use_ratio = max_use_ratio
def apply(self, env, start_from = None):
if start_from is None:
start_from = env.outputs
changed = True
max_use_abort = False
process_count = {}
while changed and not max_use_abort:
changed = False
q = deque(graph.io_toposort(env.inputs, start_from))
max_use = len(q) * self.max_use_ratio
def importer(node):
q.append(node)
def pruner(node):
if node is not current_node:
try: q.remove(node)
except ValueError: pass
u = self.attach_updater(env, importer, pruner)
try:
while q:
node = q.pop()
current_node = node
for lopt in self.local_optimizers:
process_count.setdefault(lopt, 0)
if process_count[lopt] > max_use:
max_use_abort = True
else:
lopt_change = self.process_node(env, node, lopt)
process_count[lopt] += 1 if lopt_change else 0
changed |= lopt_change
except:
self.detach_updater(env, u)
raise
self.detach_updater(env, u)
if max_use_abort:
print >> sys.stderr, "WARNING: EquilibriumOptimizer max'ed out"
class _EquilibriumOptimizer(NavigatorOptimizer):
def __init__(self,
local_optimizers,
......@@ -780,10 +878,11 @@ class EquilibriumOptimizer(NavigatorOptimizer):
# importer(node)
for node in env.nodes:
for node in env.toposort():
tasks[node].extend(lopt for track, i, lopt in self.fetch_tracks0(node.op))
u = self.attach_updater(env, importer, pruner, chin)
print 'KEYS', map(hash, tasks.keys())
while tasks:
for node in tasks.iterkeys():
todo = tasks.pop(node)
......
......@@ -18,7 +18,7 @@ class DB(object):
# N.B. obj is not an instance of class Optimizer.
# It is an instance of a DB.In the tests for example,
# this is not always the case.
if not isinstance(obj, (DB, opt.Optimizer)):
if not isinstance(obj, (DB, opt.Optimizer, opt.LocalOptimizer)):
raise Exception('wtf', obj)
obj.name = name
......
......@@ -375,7 +375,7 @@ class TestEquilibrium(object):
x, y, z = map(MyResult, 'xyz')
e = op3(op4(x, y))
g = Env([x, y, z], [e])
print g
print 'before', g
sys.stderr = sys.stdout # display pesky warnings along with stdout
opt = EquilibriumOptimizer(
[PatternSub((op1, 'x', 'y'), (op2, 'x', 'y')),
......@@ -384,7 +384,7 @@ class TestEquilibrium(object):
],
max_use_ratio = 1. / len(g.nodes)) # each opt can only be applied once
opt.optimize(g)
print g
print 'after', g
assert str(g) == '[Op4(x, y)]'
......
......@@ -2,6 +2,7 @@ import gof #, gof.result
import numpy #for numeric_grad
from gof.python25 import all
import gof.utils
_msg_retType = 'op.grad(...) returned a non-list'
_msg_badlen = 'op.grad(...) returned wrong number of gradients'
......@@ -55,17 +56,17 @@ def grad_sources_inputs(sources, graph_inputs):
else:
gmap[r] = g_r
graph_outputs = gmap.keys()
graph_outputs = gof.utils.uniq([r for r,g in sources])
if graph_inputs is None:
graph_inputs = gof.graph.inputs(graph_outputs)
for node in gof.graph.io_toposort(graph_inputs, graph_outputs).__reversed__():
g_outputs = [gmap.get(o,None) for o in node.outputs]
#if all output gradients are None, continue
if all(map(lambda x:x is None, g_outputs)): continue
output_arg = g_outputs
input_arg = node.inputs
......
......@@ -2,6 +2,7 @@ from __future__ import absolute_import
import time
import numpy
from ..gof.cutils import run_cthunk
from ..gof.link import WrapLinker
from ..compile.mode import Mode
......@@ -107,19 +108,42 @@ class ProfileMode(Mode):
local_time = [0.0]
apply_time = {}
op_time = {}
op_cimpl = {}
def blah(i, node, *thunks):
t0 = time.time()
for th in thunks:
th()
dt = time.time() - t0
if 0:
t0 = time.time()
for th in thunks:
th()
dt = time.time() - t0
elif 0: #more precise timing
for th in thunks:
t0 = time.time()
th()
dt = time.time() - t0
elif 1:
for th in thunks:
if hasattr(th, 'cthunk'):
t0 = time.time()
run_cthunk(th.cthunk)
dt = time.time() - t0
else:
t0 = time.time()
th()
dt = time.time() - t0
elif 1:
pass
else:
raise Exception('one of the cases has to run the thunks!')
local_time[0] += dt
apply_time[(i,node.op)] = apply_time.get((i,node.op), 0.0) + dt
op_time[node.op] = op_time.get(node.op, 0.0) + dt
op_cimpl[node.op] = hasattr(thunks[0], 'cthunk')
self.local_time = local_time
self.apply_time = apply_time
self.op_time = op_time
self.op_cimpl = op_cimpl
wrap_linker = WrapLinkerMany([linker], [blah])
if optimizer:
......@@ -142,13 +166,18 @@ class ProfileMode(Mode):
atimes.sort()
atimes.reverse()
for t,a in atimes[:15]:
print ' ', t, a
print ' ... (ignoring %i other Apply instances)'%max(0, len(atimes)-15)
print '\t%.3f\t%i\t%s' % (t, a[0], a[1])
print ' ... (remaining %i Apply instances account for %.2f of the runtime)'\
%(max(0, len(atimes)-15), sum(t for t, a in atimes[15:]))
print 'Op-wise summary: <fraction of local_time spent on this kind of Op> <Op name>'
otimes = [(t/local_time, a) for a, t in op_time.items()]
otimes = [(t/local_time, a, self.op_cimpl[a]) for a, t in op_time.items()]
otimes.sort()
otimes.reverse()
for t,a in otimes[:15]:
print ' ', t, a
print ' ... (ignoring %i other kinds Ops)'%max(0, len(otimes)-15)
for t,a,ci in otimes[:15]:
print '\t%.3f\t%s %s' % (t, '*' if ci else ' ', a)
print ' ... (remaining %i Ops account for %.2f of the runtime)'\
%(max(0, len(otimes)-15), sum(t for t, a, ci in otimes[15:]))
print '(*) Op is running a c implementation'
......@@ -1089,36 +1089,41 @@ pprint.assign(pow, printing.OperatorPrinter('**', 1, 'right'))
# View Operations
##########################
class TransposeInplace(Op):
view_map = {0: [0]}
def make_node(self, input):
return Apply(self, [input], [tensor(dtype = input.type.dtype,
broadcastable = reversed(input.type.broadcastable))])
def perform(self, node, (x, ), (z, )):
z[0] = x.T
def grad(self, (x,), (gz,)):
return transpose(gz),
def c_code(self, node, name, (x, ), (z, ), sub):
return """
PyArrayObject* transposed = (PyArrayObject*)PyArray_Transpose(%(x)s, NULL);
if (%(z)s) {
Py_XDECREF(%(z)s);
}
%(z)s = transposed;
""" % locals()
if 0:
class _TransposeInplace(Op):
view_map = {0: [0]}
def make_node(self, input):
return Apply(self, [input], [tensor(dtype = input.type.dtype,
broadcastable = reversed(input.type.broadcastable))])
def perform(self, node, (x, ), (z, )):
z[0] = x.T
def grad(self, (x,), (gz,)):
return transpose(gz),
def c_code(self, node, name, (x, ), (z, ), sub):
return """
PyArrayObject* transposed = (PyArrayObject*)PyArray_Transpose(%(x)s, NULL);
if (%(z)s) {
Py_XDECREF(%(z)s);
}
%(z)s = transposed;
""" % locals()
def __str__(self):
return "TransposeView"
def __str__(self):
return "TransposeView"
_transpose_inplace = _TransposeInplace()
_transpose_inplace = TransposeInplace()
def _old_transpose(x, **kwargs):
    """Superseded transpose: copy `x` (tensor_copy), then apply the inplace
    transpose op to the copy so the caller's tensor is untouched."""
    return _transpose_inplace(tensor_copy(x), **kwargs)
def transpose(x, **kwargs):
    """Return the transpose of `x` (all axes reversed), computed on a copy.

    A fresh copy is made with tensor_copy, so the inplace DimShuffle cannot
    alias the caller's tensor.  # NOTE(review): inplace=True is presumed safe
    # only because it operates on that fresh copy — confirm against DimShuffle.
    """
    # The source carried both the old body (via _transpose_inplace) and the
    # new DimShuffle body, leaving unreachable code after the first return;
    # the newer DimShuffle implementation is kept.
    dims = range(x.ndim-1, -1, -1)
    return DimShuffle(x.broadcastable, dims, inplace=True)(tensor_copy(x))
......
......@@ -181,7 +181,7 @@ class DimShuffle(Op):
for i, v in enumerate(self.new_order):
if v != 'x':
grad_order[v] = i
return DimShuffle(gz.type.broadcastable, grad_order)(gz),
return [DimShuffle(gz.type.broadcastable, grad_order, inplace=True)(Elemwise(scalar.identity)(gz))]
......
from basic import _scal_elemwise, _transpose_inplace
from basic import _scal_elemwise #, _transpose_inplace
from .. import scalar as scal
import elemwise
from .. import printing
......@@ -183,9 +183,11 @@ pprint.assign(div_inplace, printing.OperatorPrinter('/=', -1, 'left'))
pprint.assign(pow_inplace, printing.OperatorPrinter('**=', 1, 'right'))
transpose_inplace = _transpose_inplace
"""WRITEME"""
def transpose_inplace(x, **kwargs):
    """Perform a transpose on a tensor without copying the underlying storage"""
    reversed_axes = list(range(x.ndim))
    reversed_axes.reverse()
    shuffle_op = elemwise.DimShuffle(x.broadcastable, reversed_axes, inplace=True)
    return shuffle_op(x)
pprint.assign(transpose_inplace, printing.MemberPrinter('T'))
#pprint.assign(transpose_inplace, printing.MemberPrinter('T'))
......@@ -50,7 +50,8 @@ dot_to_gemm = gof.PatternSub((T.dot, 'a', 'b'),
(T.Subtensor([slice(0, 1)]), (T.shape, 'a')),
(T.Subtensor([slice(1, 2)]), (T.shape, 'b')))),
T.constant(1.0), 'a', 'b', T.constant(1.0)),
allow_multiple_clients = False)
allow_multiple_clients = False)
def _insert_inplace_optimizer(env):
......@@ -216,6 +217,13 @@ register_canonicalize(local_shape_lift_dot)
################
def encompasses_broadcastable(b1, b2):
"""
Returns True if the broadcastable patterns b1 and b2 are such that b2 is
broadcasted to b1's shape and not the opposite.
:param b1: the broadcastable attribute of a tensor type
:param b2: the broadcastable attribute of a tensor type
"""
if len(b1) < len(b2):
return False
b1 = b1[-len(b2):]
......@@ -330,6 +338,7 @@ def local_fill_cut(node):
register_canonicalize(local_fill_cut)
register_canonicalize(gof.OpRemove(T.tensor_copy), name='remove_tensor_copy' )
@gof.local_optimizer([None, T.fill])
def local_fill_sink(node):
......@@ -550,6 +559,7 @@ def local_neg_to_mul(node):
return [-1 * node.inputs[0]]
else:
return False
register_canonicalize(local_neg_to_mul)
@gof.local_optimizer([T.mul])
def local_mul_to_neg(node):
......@@ -557,6 +567,7 @@ def local_mul_to_neg(node):
return [-local_mul_canonizer.merge_num_denum(node.inputs[1:], [])]
else:
return False
register_specialize(local_mul_to_neg)
@gof.local_optimizer([T.div])
def local_div_to_inv(node):
......@@ -564,10 +575,57 @@ def local_div_to_inv(node):
return [T.inv(local_mul_canonizer.merge_num_denum(node.inputs[1:], []))]
else:
return False
register_canonicalize(local_neg_to_mul)
register_specialize(local_mul_to_neg)
register_specialize(local_div_to_inv)
@gof.local_optimizer([T.inv])
def local_inv_canon(node):
    """Canonicalize inv(x) into pow(x, -1.0)."""
    if node.op != T.inv:
        return False
    base = node.inputs[0]
    return [T.pow(base, -1.0)]
register_canonicalize(local_inv_canon)
@gof.local_optimizer([T.pow])
def local_pow_canonicalize(node):
    """Canonicalize pow(x, c) for constant exponents c == 1 or c == 0.

    pow(x, 1) -> fill(c, x) and pow(x, 0) -> fill(x, fill(c, 1.0)); the
    fills keep the broadcasted output shape identical to the original pow.
    Returns a one-element replacement list, or False when no rewrite applies.
    """
    if node.op != T.pow:
        return False
    # Constant value of the exponent, or None when it is not a constant;
    # N.all(None == c) evaluates to False, so non-constants fall through.
    y = local_mul_canonizer.get_constant(node.inputs[1])
    if N.all(y == 1.0):
        return [T.fill(node.inputs[1], node.inputs[0])]
    if N.all(y == 0.0):
        # extra fills here are to make sure the size of the output stays constant.
        return [T.fill(node.inputs[0], T.fill(node.inputs[1], 1.0))]
    # Fix: the original fell off the end here and returned None, violating
    # the LocalOptimizer.transform contract (False or a list of results).
    return False
register_canonicalize(local_pow_canonicalize)
@gof.local_optimizer([T.pow])
def local_pow_specialize(node):
    """Specialize pow(x, c) for small constant exponents c.

    Rewrites pow into sqr/sqrt/inv/identity forms for c in
    {2, 1, 0, 0.5, -0.5, -1, -2}.  We are past canonicalization here, so no
    compensating fills are inserted.  Returns a one-element replacement
    list, or False when no rewrite applies.
    """
    if node.op != T.pow:
        return False
    # the idea here is that we have pow(x, y)
    xsym = node.inputs[0]
    ysym = node.inputs[1]
    y = local_mul_canonizer.get_constant(ysym)
    # Only rewrite when x's broadcastable pattern encompasses y's, so that
    # dropping y from the graph cannot change the output shape.
    if y is None or not encompasses_broadcastable(
            xsym.type.broadcastable, ysym.type.broadcastable):
        return False
    if N.all(y == 2.0):
        return [T.sqr(xsym)]
    if N.all(y == 1.0):
        return [xsym]
    if N.all(y == 0.0):
        return [T.fill(xsym, 1.0)]
    if N.all(y == 0.5):
        return [T.sqrt(xsym)]
    if N.all(y == -0.5):
        return [T.inv(T.sqrt(xsym))]
    if N.all(y == -1.0):
        return [T.inv(xsym)]
    if N.all(y == -2.0):
        return [T.inv(T.sqr(xsym))]
    # Fix: the original returned None implicitly when y matched none of the
    # special cases, violating the transform contract (False or a list).
    return False
register_specialize(local_pow_specialize)
register_canonicalize(local_mul_canonizer, name = 'local_mul_canonizer')
......
......@@ -662,56 +662,6 @@ class T_max_and_argmax(unittest.TestCase):
self.failUnless(i.shape == (2,3))
class T_transpose(unittest.TestCase):
def test0(self):
n = as_tensor(numpy.ones(()))
t = transpose(n)
self.failUnless(t.owner.op == inplace.transpose_inplace)
f = function([n], t)
tval = f(n.data)
self.failUnless(tval.shape == n.data.shape)
#test aliasing
tval += 55.0
self.failUnless(n.data == 1.0)
def test1(self):
n = as_tensor(numpy.ones(5))
t = transpose(n)
self.failUnless(t.owner.op == inplace.transpose_inplace)
f = function([n], t)
tval = f(n.data)
self.failUnless(tval.shape == n.data.shape)
#test aliasing
tval += 55.0
self.failUnless(n.data[0] == 1.0)
def test2(self):
n = as_tensor(numpy.ones((5,3)))
t = transpose(n)
self.failUnless(t.owner.op == inplace.transpose_inplace)
f = function([n], t)
tval = f(n.data)
self.failUnless(tval.shape == (3,5))
#test aliasing
tval += 55.0
self.failUnless(n.data[0,0] == 1.0)
def test3(self):
"""Test transpose of tensor, inplace version"""
n = as_tensor(numpy.ones((5,3,2)))
t = inplace.transpose_inplace(n)
self.failUnless(t.owner.op == inplace.transpose_inplace)
f = function([n], t)
tval = f(n.data)
self.failUnless(tval.shape == (2,3,5))
#test aliasing
tval += 55.0
self.failUnless(n.data[0,0,0] == 56.0)
def test_grad(self):
verify_grad(self, inplace.transpose_inplace, [numpy.random.rand(2, 3)])
verify_grad(self, inplace.transpose_inplace, [numpy.ones(3)])
class T_subtensor(unittest.TestCase):
def setUp(self):
Subtensor.debug = False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论