提交 d653a636 authored 作者: desjagui@atchoum.iro.umontreal.ca's avatar desjagui@atchoum.iro.umontreal.ca

merge

aa.x : aa.cc aa.x : aa.cc
g++ -O3 -ffast-math aa.cc -o aa.x -L${PUB_PREFIX}/lib -lgsl -lcblas -lgoto -lgfortran -lm g++ -O3 -ffast-math aa.cc -o aa.x -L${PUB_PREFIX}/lib -lgsl ${THEANO_BLAS_LDFLAGS}
clean : clean :
rm aa.x rm aa.x
...@@ -28,6 +28,7 @@ int main(int argc, char **argv) ...@@ -28,6 +28,7 @@ int main(int argc, char **argv)
int neg = strtol(argv[1], 0, 0); int neg = strtol(argv[1], 0, 0);
int nout = strtol(argv[2], 0, 0); int nout = strtol(argv[2], 0, 0);
int nin = nout;
int nhid = strtol(argv[3], 0, 0); int nhid = strtol(argv[3], 0, 0);
int niter = strtol(argv[4], 0, 0); int niter = strtol(argv[4], 0, 0);
double lr = 0.01; double lr = 0.01;
...@@ -35,8 +36,8 @@ int main(int argc, char **argv) ...@@ -35,8 +36,8 @@ int main(int argc, char **argv)
gsl_rng_set(rng, 234); gsl_rng_set(rng, 234);
gsl_matrix * x = gsl_matrix_alloc(neg, nout); gsl_matrix * x = gsl_matrix_alloc(neg, nin);
gsl_matrix * w = gsl_matrix_alloc(nout, nhid); gsl_matrix * w = gsl_matrix_alloc(nin, nhid);
gsl_vector * a = gsl_vector_alloc(nhid); gsl_vector * a = gsl_vector_alloc(nhid);
gsl_vector * b = gsl_vector_alloc(nout); gsl_vector * b = gsl_vector_alloc(nout);
gsl_matrix * xw = gsl_matrix_alloc(neg, nhid); gsl_matrix * xw = gsl_matrix_alloc(neg, nhid);
...@@ -59,11 +60,17 @@ int main(int argc, char **argv) ...@@ -59,11 +60,17 @@ int main(int argc, char **argv)
struct timeval tv0, tv1; struct timeval tv0, tv1;
struct timeval tdot0, tdot1;
double time_of_dot = 0.0;
gettimeofday(&tv0, 0); gettimeofday(&tv0, 0);
double err = 0.0; double err = 0.0;
for (int iter = 0; iter < niter; ++iter) for (int iter = 0; iter < niter; ++iter)
{ {
gettimeofday(&tdot0, 0);
gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, x, w, 0.0, xw); gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, x, w, 0.0, xw);
gettimeofday(&tdot1, 0);
time_of_dot += pytime(&tdot1) - pytime(&tdot0);
for (int i = 0; i < neg; ++i) for (int i = 0; i < neg; ++i)
for (int j = 0; j < nhid; ++j) for (int j = 0; j < nhid; ++j)
...@@ -72,7 +79,10 @@ int main(int argc, char **argv) ...@@ -72,7 +79,10 @@ int main(int argc, char **argv)
hid->data[i*nhid+j] = tanh(act); hid->data[i*nhid+j] = tanh(act);
} }
gettimeofday(&tdot0, 0);
gsl_blas_dgemm(CblasNoTrans, CblasTrans, 1.0, hid, w, 0.0, hidwt); gsl_blas_dgemm(CblasNoTrans, CblasTrans, 1.0, hid, w, 0.0, hidwt);
gettimeofday(&tdot1, 0);
time_of_dot += pytime(&tdot1) - pytime(&tdot0);
for (int i = 0; i < nout; ++i) g_b->data[i] = 0.0; for (int i = 0; i < nout; ++i) g_b->data[i] = 0.0;
err = 0.0; err = 0.0;
...@@ -90,8 +100,11 @@ int main(int argc, char **argv) ...@@ -90,8 +100,11 @@ int main(int argc, char **argv)
if (1) if (1)
{ {
gettimeofday(&tdot0, 0);
gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, g_hidwt, w, 0.0, g_hid); gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, g_hidwt, w, 0.0, g_hid);
gsl_blas_dgemm(CblasTrans, CblasNoTrans, 1.0, g_hidwt, hid, 0.0, g_w); gsl_blas_dgemm(CblasTrans, CblasNoTrans, 1.0, g_hidwt, hid, 0.0, g_w);
gettimeofday(&tdot1, 0);
time_of_dot += pytime(&tdot1) - pytime(&tdot0);
for (int i = 0; i < neg; ++i) for (int i = 0; i < neg; ++i)
...@@ -101,14 +114,19 @@ int main(int argc, char **argv) ...@@ -101,14 +114,19 @@ int main(int argc, char **argv)
a->data[j] -= lr * g_hid->data[i*nhid+j]; a->data[j] -= lr * g_hid->data[i*nhid+j];
} }
gettimeofday(&tdot0, 0);
gsl_blas_dgemm(CblasTrans, CblasNoTrans, -lr, x, g_hid, 1.0, w); gsl_blas_dgemm(CblasTrans, CblasNoTrans, -lr, x, g_hid, 1.0, w);
gettimeofday(&tdot1, 0);
time_of_dot += pytime(&tdot1) - pytime(&tdot0);
for (int i = 0; i < nout*nhid; ++i) w->data[i] -= lr * g_w->data[i]; for (int i = 0; i < nout*nhid; ++i) w->data[i] -= lr * g_w->data[i];
} }
} }
gettimeofday(&tv1, 0); gettimeofday(&tv1, 0);
fprintf(stdout, "took = %lfs to get err %lf\n", pytime(&tv1) - pytime(&tv0), 0.5 * err); double total_time = pytime(&tv1) - pytime(&tv0);
fprintf(stdout, "took = %lfs to get err %lf\n", total_time, 0.5 * err);
fprintf(stdout, "... of which %.2lfs was spent in dgemm (fraction: %.2lf)\n", time_of_dot, time_of_dot / total_time);
//skip freeing //skip freeing
return 0; return 0;
} }
......
...@@ -8,7 +8,15 @@ import theano ...@@ -8,7 +8,15 @@ import theano
import theano.tensor as T import theano.tensor as T
import theano.sandbox import theano.sandbox
import theano.sandbox.wraplinker import theano.sandbox.wraplinker
from theano.compile import module from theano.compile import module, Mode
from theano.sandbox.wraplinker import ProfileMode
from theano import gof, Op, Apply
from theano.tensor import blas, opt
# numpy: aa_numpy.py
# c : aa.cc
if 0: if 0:
class Opt(object): class Opt(object):
...@@ -130,32 +138,29 @@ if 0: ...@@ -130,32 +138,29 @@ if 0:
self.merge(env) self.merge(env)
def linker(print_prog=False): def print_graph_linker(print_prog=True):
if 1: if 1:
print 'wtf?' imap = {None:'-'}
#return theano.gof.OpWiseCLinker() def blah(i, node, thunk):
imap = {None:'-'} imap[node] = str(i)
def blah(i, node, thunk): if print_prog:# and node.op.__class__ is T.DimShuffle:
imap[node] = str(i) if False and node.op == T.DimShuffle((), ['x', 'x'], inplace = True):
if print_prog:# and node.op.__class__ is T.DimShuffle: print node.op == T.DimShuffle((), ['x', 'x'], inplace = True),
if False and node.op == T.DimShuffle((), ['x', 'x'], inplace = True): print node.inputs[0], type(node.inputs[0]),
print node.op == T.DimShuffle((), ['x', 'x'], inplace = True), print node.inputs[0].equals(T.constant(2)),
print node.inputs[0], type(node.inputs[0]), outputs = node.outputs
print node.inputs[0].equals(T.constant(2)), inputs = theano.gof.graph.inputs(outputs)
outputs = node.outputs print 'node ', i, node,
inputs = theano.gof.graph.inputs(outputs) print ':'.join([imap[inp.owner] for inp in node.inputs])
print 'node ', i, node, #print theano.sandbox.pprint.pp.process_graph(inputs, outputs)
print ':'.join([imap[inp.owner] for inp in node.inputs]) return theano.sandbox.wraplinker.WrapLinkerMany(
#print theano.sandbox.pprint.pp.process_graph(inputs, outputs) [theano.gof.OpWiseCLinker()],
[theano.sandbox.wraplinker.run_all
return theano.sandbox.wraplinker.WrapLinkerMany( ,blah
[theano.gof.OpWiseCLinker()], #,theano.sandbox.wraplinker.numpy_notall_isfinite
[theano.sandbox.wraplinker.run_all ])
,blah else:
#,theano.sandbox.wraplinker.numpy_notall_isfinite return theano.gof.OpWiseCLinker()
])
else:
return theano.gof.OpWiseCLinker()
class M(module.Module): class M(module.Module):
...@@ -167,11 +172,14 @@ class M(module.Module): ...@@ -167,11 +172,14 @@ class M(module.Module):
self.a = module.Member(T.vector('a')) # hid bias self.a = module.Member(T.vector('a')) # hid bias
self.b = module.Member(T.vector('b')) # output bias self.b = module.Member(T.vector('b')) # output bias
hid = T.tanh(T.dot(x, self.w) + self.a) self.hid = T.tanh(T.dot(x, self.w) + self.a)
hid = self.hid
out = T.tanh(T.dot(hid, self.w.T) + self.b) self.out = T.tanh(T.dot(hid, self.w.T) + self.b)
out = self.out
err = 0.5 * T.sum((out - x)**2) self.err = 0.5 * T.sum((out - x)**2)
err = self.err
params = [self.w, self.a, self.b] params = [self.w, self.a, self.b]
...@@ -182,7 +190,13 @@ class M(module.Module): ...@@ -182,7 +190,13 @@ class M(module.Module):
self.step = module.Method([x], err, updates=dict(updates)) self.step = module.Method([x], err, updates=dict(updates))
mod = M() mod = M()
m = mod.make(mode='FAST_RUN') mode = 'FAST_RUN'
#mode = ProfileMode(optimizer='fast_run', linker=theano.gof.OpWiseCLinker())
mode = Mode(optimizer='fast_run', linker=theano.gof.OpWiseCLinker(nice_errors=True))
mode = Mode(optimizer='fast_run', linker='c')
mode = Mode(optimizer='fast_run', linker='c|py')
print mod.pretty(mode=mode)
m = mod.make(mode=mode)
neg, nout, nhid, niter = [int(a) for a in sys.argv[1:]] neg, nout, nhid, niter = [int(a) for a in sys.argv[1:]]
rng = numpy.random.RandomState(342) rng = numpy.random.RandomState(342)
...@@ -196,4 +210,10 @@ t = time.time() ...@@ -196,4 +210,10 @@ t = time.time()
for i in xrange(niter): for i in xrange(niter):
err = m.step(x) err = m.step(x)
print 'time: ',time.time() - t, 'err: ', err print 'time: ',time.time() - t, 'err: ', err
try:
mode.print_summary()
pass
except:
pass
#!/usr/bin/env python2.5
from __future__ import absolute_import
import numpy as N
import sys
import time
# c: aa.cc
neg, nout, nhid, niter = [int(a) for a in sys.argv[1:]]
lr = 0.01
rng = N.random.RandomState(342)
w = rng.rand(nout, nhid)
a = rng.randn(nhid) * 0.0
b = rng.randn(nout) * 0.0
x = (rng.rand(neg, nout)-0.5) * 1.5
dot_time = 0.0
t = time.time()
for i in xrange(niter):
tt = time.time()
d = N.dot(x, w)
dot_time += time.time() - tt
hid = N.tanh(d + a)
tt = time.time()
d = N.dot(hid, w.T)
dot_time += time.time() - tt
out = N.tanh(d + b)
g_out = out - x
err = 0.5 * N.sum(g_out**2)
g_hidwt = g_out * (1.0 - out**2)
b -= lr * N.sum(g_hidwt, axis=0)
tt = time.time()
g_hid = N.dot(g_hidwt, w)
dot_time += time.time() - tt
g_hidin = g_hid * (1.0 - hid**2)
tt = time.time()
d = N.dot(g_hidwt.T, hid)
dd = N.dot(x.T, g_hidin)
dot_time += time.time() - tt
gw = (d + dd)
w -= lr * gw
a -= lr * N.sum(g_hidin, axis=0)
total_time = time.time() - t
print 'time: ',total_time, 'err: ', err
print ' of which', dot_time, 'was spent on dot. Fraction:', dot_time / total_time
...@@ -89,8 +89,9 @@ Get the source and run the tests like this: ...@@ -89,8 +89,9 @@ Get the source and run the tests like this:
.. code-block:: bash .. code-block:: bash
hg clone http://pylearn.org/hg/theano theano hg clone http://pylearn.org/hg/theano Theano
cd theano ln -s Theano/theano <someplace on your PYTHONPATH>/theano
cd Theano
nosetests nosetests
To update your library to the latest on pylearn.org, change directory (`cd`) to this `theano` folder and type To update your library to the latest on pylearn.org, change directory (`cd`) to this `theano` folder and type
......
import unittest import unittest
from theano import gof
from theano import compile
from theano.compile.function_module import *
from theano.scalar import *
import theano import theano
from theano import tensor import numpy as N
from theano import tensor as T from theano import tensor as T
from theano.tensor import nnet as NN from theano.tensor import nnet as NN
import random
import numpy as N
from theano.compile import module as M from theano.compile import module as M
class RegressionLayer(M.Module): class Blah(M.ModuleInstance):
# self.component #refer the Module
# def __init__(self, input = None, target = None, regularize = True):
# super(Blah, self)
def initialize(self,input_size = None, target_size = None, seed = 1827,
**init):
if input_size and target_size:
# initialize w and b in a special way using input_size and target_size
sz = (input_size, target_size)
rng = N.random.RandomState(seed)
self.w = rng.uniform(size = sz, low = -0.5, high = 0.5)
self.b = N.zeros(target_size)
self.stepsize = 0.01
#we call default_initialize after as we want the parameter to superseed the default value.
M.default_initialize(self,**init)#equivalent to previous line.
def __eq__(self, other):
if not isinstance(other.component, SoftmaxXERegression1) and not isinstance(other.component, SoftmaxXERegression2):
raise NotImplemented
#we compare the member.
if (self.w==other.w).all() and (self.b==other.b).all() and self.stepsize == other.stepsize:
return True
return False
def __hash__(self):
raise NotImplemented
def fit(self, train, test):
pass
class RegressionLayer1(M.Module):
InstanceType=Blah
def __init__(self, input = None, target = None, regularize = True):
super(RegressionLayer1, self).__init__() #boilerplate
# MODEL CONFIGURATION
self.regularize = regularize
# ACQUIRE/MAKE INPUT AND TARGET
if not input:
input = T.matrix('input')
if not target:
target = T.matrix('target')
# HYPER-PARAMETERS
self.stepsize = M.Member(T.scalar()) # a stepsize for gradient descent
# PARAMETERS
self.w = M.Member(T.matrix()) #the linear transform to apply to our input points
self.b = M.Member(T.vector()) #a vector of biases, which make our transform affine instead of linear
# REGRESSION MODEL
self.activation = T.dot(input, self.w) + self.b
self.prediction = self.build_prediction()
# CLASSIFICATION COST
self.classification_cost = self.build_classification_cost(target)
# REGULARIZATION COST
self.regularization = self.build_regularization()
# TOTAL COST
self.cost = self.classification_cost
if self.regularize:
self.cost = self.cost + self.regularization
# GET THE GRADIENTS NECESSARY TO FIT OUR PARAMETERS
self.grad_w, self.grad_b = T.grad(self.cost, [self.w, self.b])
# INTERFACE METHODS
self.update = M.Method([input, target],
self.cost,
w = self.w - self.stepsize * self.grad_w,
b = self.b - self.stepsize * self.grad_b)
self.apply = M.Method(input, self.prediction)
def params(self):
return self.w, self.b
def build_regularization(self):
return T.zero() # no regularization!
class RegressionLayer2(M.Module):
def __init__(self, input = None, target = None, regularize = True): def __init__(self, input = None, target = None, regularize = True):
super(RegressionLayer, self).__init__() #boilerplate super(RegressionLayer2, self).__init__() #boilerplate
# MODEL CONFIGURATION # MODEL CONFIGURATION
self.regularize = regularize self.regularize = regularize
# ACQUIRE/MAKE INPUT AND TARGET # ACQUIRE/MAKE INPUT AND TARGET
...@@ -48,25 +110,40 @@ class RegressionLayer(M.Module): ...@@ -48,25 +110,40 @@ class RegressionLayer(M.Module):
self.apply = M.Method(input, self.prediction) self.apply = M.Method(input, self.prediction)
def params(self): def params(self):
return self.w, self.b return self.w, self.b
def _instance_initialize(self, obj, input_size = None, target_size = None, **init): def _instance_initialize(self, obj, input_size = None, target_size = None,
seed = 1827, **init):
# obj is an "instance" of this module holding values for each member and # obj is an "instance" of this module holding values for each member and
# functions for each method # functions for each method
#super(RegressionLayer, self).initialize(obj, **init)
# here we call the superclass's initialize method, which takes all the name: value
# pairs in init and sets the property with that name to the provided value
# this covers setting stepsize, l2_coef; w and b can be set that way too
if input_size and target_size: if input_size and target_size:
# initialize w and b in a special way using input_size and target_size # initialize w and b in a special way using input_size and target_size
sz = (input_size, target_size) sz = (input_size, target_size)
obj.w = N.random.uniform(size = sz, low = -0.5, high = 0.5) rng = N.random.RandomState(seed)
obj.w = rng.uniform(size = sz, low = -0.5, high = 0.5)
obj.b = N.zeros(target_size) obj.b = N.zeros(target_size)
obj.stepsize = 0.01 obj.stepsize = 0.01
# here we call the default_initialize method, which takes all the name: value
# pairs in init and sets the property with that name to the provided value
# this covers setting stepsize, l2_coef; w and b can be set that way too
# we call it after as we want the parameter to superseed the default value.
M.default_initialize(obj,**init)
def build_regularization(self): def build_regularization(self):
return T.zero() # no regularization! return T.zero() # no regularization!
class SoftmaxXERegression1(RegressionLayer1):
""" XE mean cross entropy"""
def build_prediction(self):
return NN.softmax(self.activation)
def build_classification_cost(self, target):
#self.classification_cost_matrix = target * T.log(self.prediction) + (1 - target) * T.log(1 - self.prediction)
self.classification_cost_matrix = (target - self.prediction)**2
self.classification_costs = -T.sum(self.classification_cost_matrix, axis=1)
return T.sum(self.classification_costs)
def build_regularization(self):
self.l2_coef = M.Member(T.scalar()) # we can add a hyper parameter if we need to
return self.l2_coef * T.sum(self.w * self.w)
class SoftmaxXERegression(RegressionLayer): class SoftmaxXERegression2(RegressionLayer2):
""" XE mean cross entropy""" """ XE mean cross entropy"""
def build_prediction(self): def build_prediction(self):
return NN.softmax(self.activation) return NN.softmax(self.activation)
...@@ -80,8 +157,8 @@ class SoftmaxXERegression(RegressionLayer): ...@@ -80,8 +157,8 @@ class SoftmaxXERegression(RegressionLayer):
return self.l2_coef * T.sum(self.w * self.w) return self.l2_coef * T.sum(self.w * self.w)
class T_function_module(unittest.TestCase): class T_test_wiki_module(unittest.TestCase):
def test_Klass_basic_example1(self): def test_Module_basic_example1(self):
n, c = T.scalars('nc') n, c = T.scalars('nc')
inc = theano.function([n, ((c, c + n), 0)], []) inc = theano.function([n, ((c, c + n), 0)], [])
dec = theano.function([n, ((c, c - n), inc.container[c])], []) # we need to pass inc's container in order to share dec = theano.function([n, ((c, c - n), inc.container[c])], []) # we need to pass inc's container in order to share
...@@ -93,12 +170,15 @@ class T_function_module(unittest.TestCase): ...@@ -93,12 +170,15 @@ class T_function_module(unittest.TestCase):
assert inc[c] == -1 and dec[c] == inc[c] assert inc[c] == -1 and dec[c] == inc[c]
assert plus10() == 9 assert plus10() == 9
def test_Klass_basic_example2(self): def test_Module_basic_example2(self):
m = M.Module() m = M.Module()
n = T.scalar('n') n = T.scalar('n')
m.c = M.Member(T.scalar()) # state variables must be wrapped with ModuleMember m.c = M.Member(T.scalar()) # state variables must be wrapped with ModuleMember
m.inc = M.Method(n, [], c = m.c + n) # m.c <= m.c + n m.inc = M.Method(n, [], c = m.c + n) # m.c <= m.c + n
m.dec = M.Method(n, [], c = m.c - n) # k.c <= k.c - n m.dec = M.Method(n, [], c = m.c - n) # k.c <= k.c - n
m.dec = M.Method(n, [], updates = {m.c: m.c - n})
#m.dec = M.Method(n, [], updates = {c: m.c - n})#global c don't exist
#m.dec = M.Method(n, [], m.c = m.c - n) #python don't suppor this syntax
m.plus10 = M.Method([], m.c + 10) # m.c is always accessible since it is a member of this mlass m.plus10 = M.Method([], m.c + 10) # m.c is always accessible since it is a member of this mlass
inst = m.make(c = 0) # here, we make an "instance" of the module with c initialized to 0 inst = m.make(c = 0) # here, we make an "instance" of the module with c initialized to 0
assert inst.c == 0 assert inst.c == 0
...@@ -108,7 +188,7 @@ class T_function_module(unittest.TestCase): ...@@ -108,7 +188,7 @@ class T_function_module(unittest.TestCase):
assert inst.c == -1 assert inst.c == -1
assert inst.plus10() == 9 assert inst.plus10() == 9
def test_Klass_nesting_example1(self): def test_Module_nesting_example1(self):
def make_incdec_function(): def make_incdec_function():
n, c = T.scalars('nc') n, c = T.scalars('nc')
inc = theano.function([n, ((c, c + n), 0)], []) inc = theano.function([n, ((c, c + n), 0)], [])
...@@ -126,7 +206,7 @@ class T_function_module(unittest.TestCase): ...@@ -126,7 +206,7 @@ class T_function_module(unittest.TestCase):
assert inc1['c'] == -2 and inc2['c'] == 6 assert inc1['c'] == -2 and inc2['c'] == 6
assert sum() == 4 # -2 + 6 assert sum() == 4 # -2 + 6
def test_Klass_nesting_example2(self): def test_Module_nesting_example2(self):
def make_incdec_module(): def make_incdec_module():
m = M.Module() m = M.Module()
n = T.scalar('n') n = T.scalar('n')
...@@ -140,71 +220,63 @@ class T_function_module(unittest.TestCase): ...@@ -140,71 +220,63 @@ class T_function_module(unittest.TestCase):
m.incdec2 = make_incdec_module() m.incdec2 = make_incdec_module()
m.sum = M.Method([], m.incdec1.c + m.incdec2.c) m.sum = M.Method([], m.incdec1.c + m.incdec2.c)
inst = m.make(incdec1 = dict(c=0), incdec2 = dict(c=0)) inst = m.make(incdec1 = dict(c=0), incdec2 = dict(c=0))
assert inst.incdec1.c==0 and inst.incdec2.c==0
inst.incdec1.inc(2) inst.incdec1.inc(2)
inst.incdec1.dec(4) inst.incdec1.dec(4)
inst.incdec2.inc(6) inst.incdec2.inc(6)
assert inst.incdec1.c == -2 and inst.incdec2.c == 6 assert inst.incdec1.c == -2 and inst.incdec2.c == 6
assert inst.sum() == 4 # -2 + 6 assert inst.sum() == 4 # -2 + 6
def test_Klass_Advanced_example(self): def test_Module_Advanced_example(self):
model_module = SoftmaxXERegression(regularize = False)
model = model_module.make(input_size = 10,
target_size = 1,
stepsize = 0.1)
data_x = N.random.randn(4, 10) data_x = N.random.randn(4, 10)
data_y = [ [int(x)] for x in N.random.randn(4) > 0] data_y = [ [int(x)] for x in N.random.randn(4) > 0]
print data_x def test(model):
print model = model.make(input_size = 10,
print data_y target_size = 1,
for i in xrange(1000): stepsize = 0.1)
xe = model.update(data_x, data_y) print model.stepsize
if i % 100 == 0: self.failUnless( model.w.shape == (10,1) and model.b.shape == (1,))
print i, xe assert model.stepsize == 0.1
for i in xrange(1000):
#for inputs, targets in my_training_set(): xe = model.update(data_x, data_y)
#print "cost:", model.update(inputs, targets) if i % 100 == 0:
print i, xe
pass
print "final weights:", model.w #for inputs, targets in my_training_set():
print "final biases:", model.b #print "cost:", model.update(inputs, targets)
#print "some prediction:", model.prediction(some_inputs)
print "final weights:", model.w
print "final biases:", model.b
def test_Klass_extending_klass_methods(self):
model_module = SoftmaxXERegression(regularize = False) #Print "some prediction:", model.prediction(some_inputs)
return model
m1=test(SoftmaxXERegression1(regularize = False))
m2=test(SoftmaxXERegression2(regularize = False))
print "m1",m1
print "m2",m2
print m2==m1
print m1==m2
assert m2==m1 and m1==m2
def test_Module_extending_module_methods(self):
model_module = SoftmaxXERegression1(regularize = False)
model_module.sum = M.Member(T.scalar()) # we add a module member to hold the sum model_module.sum = M.Member(T.scalar()) # we add a module member to hold the sum
model_module.update.extend(sum = model_module.sum + model_module.cost) # now update will also update sum! model_module.update.updates.update(sum = model_module.sum + model_module.cost) # now update will also update sum!
model = model_module.make(input_size = 4, model = model_module.make(input_size = 4,
target_size = 2, target_size = 2,
stepsize = 0.1, stepsize = 0.1,
sum = 0) # we mustn't forget to initialize the sum sum = 0) # we mustn't forget to initialize the sum
print model.stepsize
test = model.update([[0,0,1,0]], [[0,1]]) + model.update([[0,1,0,0]], [[1,0]]) self.failUnless( model.w.shape == (4,2) and model.b.shape == (2,))
assert model.stepsize == 0.1
test = model.update([[0,0,1,0]], [[0,1]])
test += model.update([[0,1,0,0]], [[1,0]])
assert model.sum == test assert model.sum == test
def test_Module_basic_example2_more(self):
def make_incdec_function():
n, c = T.scalars('nc')
inc = theano.function([n, ((c, c + n), 0)], [])
dec = theano.function([n, ((c, c - n), inc.container[c])], [])
return inc,dec
inc1, dec1 = make_incdec_function()
inc2, dec2 = make_incdec_function()
a, b = T.scalars('ab')
sum = theano.function([(a, inc1.container['c']), (b, inc2.container['c'])], a + b)
inc1(2)
dec1(4)
inc2(6)
assert inc1['c'] == -2 and inc2['c'] == 6
assert sum() == 4 # -2 + 6
def test_Klass_basic_example2_more(self):
m = M.Module() m = M.Module()
m2 = M.Module() m2 = M.Module()
m2.name="m2" # for better error m2.name="m2" # for better error
...@@ -231,26 +303,11 @@ class T_function_module(unittest.TestCase): ...@@ -231,26 +303,11 @@ class T_function_module(unittest.TestCase):
# self.assertRaises(m.make(c = 0), Error) # self.assertRaises(m.make(c = 0), Error)
m.inc = M.Method(n, [], updates={m2.c: m.c + n})#work! should be allowed? m.inc = M.Method(n, [], updates={m2.c: m.c + n})#work! should be allowed?
# self.assertRaises(m.make(c = 0), Error) # self.assertRaises(m.make(c = 0), Error)
# m.inc = M.Method(n, [], updates={m2.c: m2.c + n})#work! should be allowed? # m.inc = M.Method(n, [], updates={m.c: m2.c + m.c+ n})#work! should be allowed?
m2.inc = M.Method(n, [], updates={m2.c: m2.c + 2*m.c+ n})#work! should be allowed?
# self.assertRaises(m.make(c = 0), Error) # self.assertRaises(m.make(c = 0), Error)
if __name__ == '__main__': if __name__ == '__main__':
from theano.tests import main
if 0: main("test_wiki")
unittest.main()
elif 1:
module = __import__("test_wiki")
tests = unittest.TestLoader().loadTestsFromModule(module)
tests.debug()
else:
testcases = []
testcases.append(T_function_module)
#<testsuite boilerplate>
testloader = unittest.TestLoader()
suite = unittest.TestSuite()
for testcase in testcases:
suite.addTest(testloader.loadTestsFromTestCase(testcase))
unittest.TextTestRunner(verbosity=2).run(suite)
#</boilerplate>
...@@ -664,6 +664,10 @@ class ComponentList(Composite): ...@@ -664,6 +664,10 @@ class ComponentList(Composite):
return self.__class__(*[c.dup() for c in self._components]) return self.__class__(*[c.dup() for c in self._components])
def default_initialize(self, init = {}, **kwinit):
for k, initv in dict(init, **kwinit).iteritems():
self[k] = initv
class ComponentDictInstance(CompositeInstance): class ComponentDictInstance(CompositeInstance):
""" """
ComponentDictInstance is meant to be instantiated by ComponentDict. ComponentDictInstance is meant to be instantiated by ComponentDict.
......
...@@ -23,11 +23,12 @@ from op import \ ...@@ -23,11 +23,12 @@ from op import \
from opt import \ from opt import \
Optimizer, optimizer, SeqOptimizer, \ Optimizer, optimizer, SeqOptimizer, \
MergeOptimizer, MergeOptMerge, \ MergeOptimizer, MergeOptMerge, \
LocalOptimizer, local_optimizer, LocalOptGroup, LocalOpKeyOptGroup, \ LocalOptimizer, local_optimizer, LocalOptGroup, \
OpSub, OpRemove, PatternSub, \ OpSub, OpRemove, PatternSub, \
NavigatorOptimizer, TopoOptimizer, OpKeyOptimizer, EquilibriumOptimizer, \ NavigatorOptimizer, TopoOptimizer, EquilibriumOptimizer, \
keep_going, warn, \ keep_going, warn, \
InplaceOptimizer, PureThenInplaceOptimizer InplaceOptimizer, PureThenInplaceOptimizer
#LocalOpKeyOptGroup, OpKeyOptimizer
from optdb import \ from optdb import \
DB, Query, \ DB, Query, \
......
...@@ -686,7 +686,16 @@ class CLinker(link.Linker): ...@@ -686,7 +686,16 @@ class CLinker(link.Linker):
instantiate.customize.add_support_code(support_code) instantiate.customize.add_support_code(support_code)
instantiate.customize.add_support_code(self.struct_code) instantiate.customize.add_support_code(self.struct_code)
instantiate.customize.add_support_code(static) instantiate.customize.add_support_code(static)
instantiate.customize.add_extra_compile_arg("-w") for extra_arg in (
"-O2",
"-ffast-math",
#"-fprefetch-loop-arrays",
#"-ftree-vect-loop-version",
#"-ftree-loop-optimize",
#"-ftree-vectorize"):
"-w" #-w means supress all warnings
):
instantiate.customize.add_extra_compile_arg(extra_arg)
for arg in self.compile_args(): for arg in self.compile_args():
instantiate.customize.add_extra_compile_arg(arg) instantiate.customize.add_extra_compile_arg(arg)
for header in self.headers(): for header in self.headers():
...@@ -739,6 +748,7 @@ def _execute(cthunk, init_tasks, tasks, error_storage): ...@@ -739,6 +748,7 @@ def _execute(cthunk, init_tasks, tasks, error_storage):
exc_value = exc_type(_exc_value, task) exc_value = exc_type(_exc_value, task)
exc_value.__thunk_trace__ = trace # this can be used to retrieve the location the Op was declared exc_value.__thunk_trace__ = trace # this can be used to retrieve the location the Op was declared
raise exc_type, exc_value, exc_trace raise exc_type, exc_value, exc_trace
execute.cthunk = cthunk
return execute return execute
...@@ -761,9 +771,12 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -761,9 +771,12 @@ class OpWiseCLinker(link.LocalLinker):
__cache__ = {} __cache__ = {}
def __init__(self, fallback_on_perform = True): def __init__(self,
fallback_on_perform = True,
nice_errors = True):
self.env = None self.env = None
self.fallback_on_perform = fallback_on_perform self.fallback_on_perform = fallback_on_perform
self.nice_errors = nice_errors
def accept(self, env, no_recycling = []): def accept(self, env, no_recycling = []):
if self.env is not None and self.env is not env: if self.env is not None and self.env is not env:
...@@ -833,7 +846,9 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -833,7 +846,9 @@ class OpWiseCLinker(link.LocalLinker):
else: else:
no_recycling = [storage_map[r] for r in no_recycling if r not in env.inputs] no_recycling = [storage_map[r] for r in no_recycling if r not in env.inputs]
f = link.streamline(env, thunks, order, no_recycling = no_recycling, profiler = profiler) f = link.streamline(env, thunks, order,
no_recycling = no_recycling,
nice_errors = self.nice_errors)
return f, [link.Container(input, storage) for input, storage in zip(env.inputs, input_storage)], \ return f, [link.Container(input, storage) for input, storage in zip(env.inputs, input_storage)], \
[link.Container(output, storage, True) for output, storage in zip(env.outputs, output_storage)], \ [link.Container(output, storage, True) for output, storage in zip(env.outputs, output_storage)], \
...@@ -841,7 +856,6 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -841,7 +856,6 @@ class OpWiseCLinker(link.LocalLinker):
def _default_checker(x, y): def _default_checker(x, y):
"""WRITEME """WRITEME
Default checker for DualLinker. This checks that the Default checker for DualLinker. This checks that the
......
...@@ -13,6 +13,7 @@ from collections import deque ...@@ -13,6 +13,7 @@ from collections import deque
import utils import utils
_creation_idx = [0]
class Apply(utils.object2): class Apply(utils.object2):
""" """
...@@ -121,6 +122,13 @@ class Apply(utils.object2): ...@@ -121,6 +122,13 @@ class Apply(utils.object2):
def __asapply__(self): def __asapply__(self):
return self return self
def __hash__(self):
if not hasattr(self, '_creation_idx'):
self._creation_idx = _creation_idx[0]
_creation_idx[0] += 1
return self._creation_idx
def clone(self): def clone(self):
"""Duplicate this Apply instance with inputs = self.inputs. """Duplicate this Apply instance with inputs = self.inputs.
...@@ -567,7 +575,10 @@ def general_toposort(r_out, deps, debug_print = False): ...@@ -567,7 +575,10 @@ def general_toposort(r_out, deps, debug_print = False):
deps(i) should behave like a pure function (no funny business with internal state) deps(i) should behave like a pure function (no funny business with internal state)
:note: :note:
deps(i) can/should be cached by the deps function to be fast deps(i) will be cached by this function (to be fast)
:note:
The order of the return value list is determined by the order of nodes returned by the deps() function.
""" """
deps_cache = {} deps_cache = {}
def _deps(io): def _deps(io):
...@@ -611,8 +622,9 @@ def general_toposort(r_out, deps, debug_print = False): ...@@ -611,8 +622,9 @@ def general_toposort(r_out, deps, debug_print = False):
def io_toposort(i, o, orderings = {}): def io_toposort(i, o, orderings = {}):
"""WRITEME """WRITEME
""" """
#the inputs are used only here in the function that decides what 'predecessors' to explore
iset = set(i) iset = set(i)
def deps(obj): def deps(obj):
rval = [] rval = []
if obj not in iset: if obj not in iset:
if isinstance(obj, Result): if isinstance(obj, Result):
......
...@@ -5,6 +5,7 @@ from type import Type ...@@ -5,6 +5,7 @@ from type import Type
import sys, traceback import sys, traceback
from copy import copy from copy import copy
from cutils import run_cthunk
__excepthook = sys.excepthook __excepthook = sys.excepthook
...@@ -225,9 +226,27 @@ def clear_storage_thunk(stg): ...@@ -225,9 +226,27 @@ def clear_storage_thunk(stg):
thunk.inputs = [stg] thunk.inputs = [stg]
return thunk return thunk
def streamline(env, thunks, order, no_recycling = [], profiler = None): def streamline(env, thunks, order, no_recycling = [], profiler = None, nice_errors = True):
"""WRITEME""" """WRITEME
if profiler is None:
:param env:
:param thunks: the list of program instructions
:param order: the list of apply instances that gave rise to the thunks (same order as thunks)
:param no_recycling: storage elements that cannot be 'recycled' by repeatedly executing the
program. These storage elements are cleared before re-running.
:param profiler: deprecated
:param nice_errors: run in such a way that the double-traceback is printed. This costs a
bit of performance in the inner python loop.
"""
if profiler is not None:
raise NotImplementedError()
if nice_errors:
def f(): def f():
for x in no_recycling: for x in no_recycling:
x[0] = None x[0] = None
...@@ -237,14 +256,13 @@ def streamline(env, thunks, order, no_recycling = [], profiler = None): ...@@ -237,14 +256,13 @@ def streamline(env, thunks, order, no_recycling = [], profiler = None):
except: except:
raise_with_op(node) raise_with_op(node)
else: else:
# don't worry about raise_with_op, just go a little faster.
#there is a mix of python and c thunks
def f(): def f():
for x in no_recycling: for x in no_recycling:
x[0] = None x[0] = None
def g(): for thunk in thunks:
for thunk, node in zip(thunks, order): thunk()
profiler.profile_node(thunk, node)
profiler.profile_env(g, env)
f.profiler = profiler
return f return f
class LocalLinker(Linker): class LocalLinker(Linker):
......
...@@ -15,6 +15,10 @@ from collections import deque, defaultdict ...@@ -15,6 +15,10 @@ from collections import deque, defaultdict
import destroyhandler as dh import destroyhandler as dh
import sys import sys
_optimizer_idx = [0]
def _list_of_nodes(env):
return graph.io_toposort(env.inputs, env.outputs)
class Optimizer(object): class Optimizer(object):
"""WRITEME """WRITEME
...@@ -23,6 +27,12 @@ class Optimizer(object): ...@@ -23,6 +27,12 @@ class Optimizer(object):
of transformation you could apply to an L{Env}. of transformation you could apply to an L{Env}.
""" """
def __hash__(self):
if not hasattr(self, '_optimizer_idx'):
self._optimizer_idx = _optimizer_idx[0]
_optimizer_idx[0] += 1
return self._optimizer_idx
def apply(self, env): def apply(self, env):
"""WRITEME """WRITEME
Applies the optimization to the provided L{Env}. It may use all Applies the optimization to the provided L{Env}. It may use all
...@@ -66,12 +76,13 @@ class FromFunctionOptimizer(Optimizer): ...@@ -66,12 +76,13 @@ class FromFunctionOptimizer(Optimizer):
env.extend(toolbox.ReplaceValidate()) env.extend(toolbox.ReplaceValidate())
def optimizer(f): def optimizer(f):
"""WRITEME""" """decorator for FromFunctionOptimizer"""
return FromFunctionOptimizer(f) return FromFunctionOptimizer(f)
class SeqOptimizer(Optimizer, list): class SeqOptimizer(Optimizer, list):
#inherit from Optimizer first to get Optimizer.__hash__
"""WRITEME """WRITEME
Takes a list of L{Optimizer} instances and applies them Takes a list of L{Optimizer} instances and applies them
sequentially. sequentially.
...@@ -99,13 +110,13 @@ class SeqOptimizer(Optimizer, list): ...@@ -99,13 +110,13 @@ class SeqOptimizer(Optimizer, list):
raise raise
def __eq__(self, other): def __eq__(self, other):
#added to override the list's __eq__ implementation
return id(self) == id(other) return id(self) == id(other)
def __neq__(self, other): def __neq__(self, other):
#added to override the list's __neq__ implementation
return id(self) != id(other) return id(self) != id(other)
def __hash__(self):
return hash(id(self))
def __str__(self): def __str__(self):
return "SeqOpt(%s)" % list.__str__(self) return "SeqOpt(%s)" % list.__str__(self)
...@@ -129,6 +140,10 @@ class _metadict: ...@@ -129,6 +140,10 @@ class _metadict:
try: try:
self.d[item] = value self.d[item] = value
except: except:
for i, (key,val) in enumerate(self.l):
if key == item:
self.l[i] = (item, value)
return
self.l.append((item, value)) self.l.append((item, value))
def get(self, item, default): def get(self, item, default):
try: try:
...@@ -183,7 +198,7 @@ class MergeOptimizer(Optimizer): ...@@ -183,7 +198,7 @@ class MergeOptimizer(Optimizer):
cid[r] = i cid[r] = i
inv_cid[i] = r inv_cid[i] = r
for node in graph.io_toposort(env.inputs, env.outputs): for node in _list_of_nodes(env):
node_cid = (node.op, tuple([cid[input] for input in node.inputs])) node_cid = (node.op, tuple([cid[input] for input in node.inputs]))
dup = inv_cid.get(node_cid, None) dup = inv_cid.get(node_cid, None)
success = False success = False
...@@ -221,10 +236,33 @@ def MergeOptMerge(opt): ...@@ -221,10 +236,33 @@ def MergeOptMerge(opt):
### Local Optimizers ### ### Local Optimizers ###
######################## ########################
class LocalOptimizer(utils.object2): class LocalOptimizer(object):
"""WRITEME""" """A class for node-based optimizations.
Instances should implement the transform function,
and be passed to configure a env-based Optimizer instance.
"""
def __hash__(self):
if not hasattr(self, '_optimizer_idx'):
self._optimizer_idx = _optimizer_idx[0]
_optimizer_idx[0] += 1
return self._optimizer_idx
def transform(self, node): def transform(self, node):
"""Transform a subgraph whose output is `node`.
Subclasses should implement this function so that it returns one of two
kinds of things:
- False to indicate that no optimization can be applied to this `node`; or
- <list of results> to use in place of `node`'s outputs in the greater graph.
:type node: an Apply instance
"""
raise utils.AbstractFunctionError() raise utils.AbstractFunctionError()
...@@ -264,7 +302,7 @@ class LocalOptGroup(LocalOptimizer): ...@@ -264,7 +302,7 @@ class LocalOptGroup(LocalOptimizer):
return repl return repl
class LocalOpKeyOptGroup(LocalOptGroup): class _LocalOpKeyOptGroup(LocalOptGroup):
"""WRITEME""" """WRITEME"""
def __init__(self, optimizers): def __init__(self, optimizers):
...@@ -507,9 +545,29 @@ class PatternSub(LocalOptimizer): ...@@ -507,9 +545,29 @@ class PatternSub(LocalOptimizer):
class NavigatorOptimizer(Optimizer): class NavigatorOptimizer(Optimizer):
"""WRITEME""" """Abstract class
"""
def __init__(self, local_opt, ignore_newtrees = 'auto', failure_callback = None): def __init__(self, local_opt, ignore_newtrees = 'auto', failure_callback = None):
"""
:param local_opt: a LocalOptimizer to apply over a Env.
:param ignore_newtrees:
- True: new subgraphs returned by an optimization is not a candidate for optimization
- False: new subgraphs returned by an optimization is a candidate for optimization
- 'auto': let the local_opt set this parameter via its 'reentrant' attribute.
:param failure_callback:
a function that takes (exception, navigator, [(old, new),
(old,new),...]) and we call it if there's an exception.
If the trouble is from local_opt.transform(), the new variables will be 'None'.
If the trouble is from validation (the new types don't match for
example) then the new variables will be the ones created by
transform().
If this parameter is None, then exceptions are not caught here (raised normally).
"""
self.local_opt = local_opt self.local_opt = local_opt
if ignore_newtrees == 'auto': if ignore_newtrees == 'auto':
self.ignore_newtrees = not getattr(local_opt, 'reentrant', True) self.ignore_newtrees = not getattr(local_opt, 'reentrant', True)
...@@ -518,9 +576,18 @@ class NavigatorOptimizer(Optimizer): ...@@ -518,9 +576,18 @@ class NavigatorOptimizer(Optimizer):
self.failure_callback = failure_callback self.failure_callback = failure_callback
def attach_updater(self, env, importer, pruner, chin = None): def attach_updater(self, env, importer, pruner, chin = None):
"""Install some Env listeners to help the navigator deal with the ignore_trees-related functionality.
:param importer: function that will be called whenever when optimizations add stuff to the graph.
:param pruner: function to be called when optimizations remove stuff from graph.
:param chin: "on change input" called whenever an node's inputs change.
:returns: The Env plugin that handles the three tasks. Keep this around so that you can detach later!
"""
if self.ignore_newtrees: if self.ignore_newtrees:
importer = None importer = None
if importer is None and pruner is None: if importer is None and pruner is None:
return None return None
...@@ -534,12 +601,18 @@ class NavigatorOptimizer(Optimizer): ...@@ -534,12 +601,18 @@ class NavigatorOptimizer(Optimizer):
if chin is not None: if chin is not None:
def on_change_input(self, env, node, i, r, new_r): def on_change_input(self, env, node, i, r, new_r):
chin(node, i, r, new_r) chin(node, i, r, new_r)
u = Updater() u = Updater()
env.extend(u) env.extend(u)
return u return u
def detach_updater(self, env, u): def detach_updater(self, env, u):
"""Undo the work of attach_updater.
:param u: a return-value of attach_updater
:returns: None.
"""
if u is not None: if u is not None:
env.remove_feature(u) env.remove_feature(u)
...@@ -562,6 +635,13 @@ class NavigatorOptimizer(Optimizer): ...@@ -562,6 +635,13 @@ class NavigatorOptimizer(Optimizer):
except Exception, e: except Exception, e:
if self.failure_callback is not None: if self.failure_callback is not None:
self.failure_callback(e, self, repl_pairs) self.failure_callback(e, self, repl_pairs)
#DEBUG DONT PUSH
#print lopt
#print dir(lopt)
#raise
#END
return False return False
else: else:
raise raise
...@@ -602,7 +682,7 @@ class TopoOptimizer(NavigatorOptimizer): ...@@ -602,7 +682,7 @@ class TopoOptimizer(NavigatorOptimizer):
except: except:
self.detach_updater(env, u) self.detach_updater(env, u)
raise raise
self.detach_updater(env, u)
class OpKeyOptimizer(NavigatorOptimizer): class OpKeyOptimizer(NavigatorOptimizer):
...@@ -634,6 +714,7 @@ class OpKeyOptimizer(NavigatorOptimizer): ...@@ -634,6 +714,7 @@ class OpKeyOptimizer(NavigatorOptimizer):
except: except:
self.detach_updater(env, u) self.detach_updater(env, u)
raise raise
self.detach_updater(env, u)
def add_requirements(self, env): def add_requirements(self, env):
""" """
...@@ -646,176 +727,67 @@ class OpKeyOptimizer(NavigatorOptimizer): ...@@ -646,176 +727,67 @@ class OpKeyOptimizer(NavigatorOptimizer):
# class EquilibriumOptimizer(NavigatorOptimizer):
# """WRITEME"""
# def __init__(self, local_optimizers, failure_callback = None):
# NavigatorOptimizer.__init__(self, local_opt, ignore_newtrees, failure_callback)
# def apply(self, env):
# op = self.local_opt.op_key()
# if isinstance(op, (list, tuple)):
# q = reduce(list.__iadd__, map(env.get_nodes, op))
# else:
# q = list(env.get_nodes(op))
# def importer(node):
# if node.op == op: q.append(node)
# def pruner(node):
# if node is not current_node and node.op == op:
# try: q.remove(node)
# except ValueError: pass
# u = self.attach_updater(env, importer, pruner)
# try:
# while q:
# node = q.pop()
# current_node = node
# self.process_node(env, node)
# except:
# self.detach_updater(env, u)
# raise
from utils import D from utils import D
class EquilibriumOptimizer(NavigatorOptimizer): class EquilibriumOptimizer(NavigatorOptimizer):
def __init__(self, def __init__(self,
local_optimizers, local_optimizers,
failure_callback = None, failure_callback = None,
max_depth = None, max_depth = None,
max_use_ratio = None): max_use_ratio = None):
"""
:param max_use_ratio: each optimizer can be applied at most (size of graph * this number)
"""
super(EquilibriumOptimizer, self).__init__( super(EquilibriumOptimizer, self).__init__(
None, None,
ignore_newtrees = False, ignore_newtrees = True,
failure_callback = failure_callback) failure_callback = failure_callback)
self.local_optimizers = local_optimizers self.local_optimizers = local_optimizers
self.max_depth = max_depth self.max_depth = max_depth
self.max_use_ratio = max_use_ratio self.max_use_ratio = max_use_ratio
self.tracks = defaultdict(list) def apply(self, env, start_from = None):
self.tracks0 = defaultdict(list) if start_from is None:
max_depth = 0 start_from = env.outputs
for lopt in local_optimizers: changed = True
tracks = lopt.tracks() max_use_abort = False
for track in tracks: process_count = {}
max_depth = max(max_depth, len(track))
if self.max_depth is not None and max_depth > self.max_depth: while changed and not max_use_abort:
raise ValueError('One of the local optimizers exceeds the maximal depth.') changed = False
for i, op in enumerate(track):
if i == 0: q = deque(graph.io_toposort(env.inputs, start_from))
self.tracks0[op].append((track, i, lopt))
self.tracks[op].append((track, i, lopt)) max_use = len(q) * self.max_use_ratio
def importer(node):
def fetch_tracks(self, op): q.append(node)
return self.tracks[op] + self.tracks[None] def pruner(node):
if node is not current_node:
def fetch_tracks0(self, op): try: q.remove(node)
return self.tracks0[op] + self.tracks0[None] except ValueError: pass
def backtrack(self, node, tasks):
candidates = self.fetch_tracks(node.op)
tracks = []
def filter(node, depth):
new_candidates = []
for candidate in candidates:
track, i, lopt = candidate
if i < depth:
pass
elif track[i-depth] in (None, node.op):
if i == depth:
tasks[node].append(lopt)
else:
tracks.append(candidate)
else:
new_candidates.append(candidate)
return new_candidates
depth = 0
nodes = [node]
while candidates:
for node in nodes:
candidates = filter(node, depth)
depth += 1
_nodes = nodes
nodes = reduce(list.__iadd__,
[reduce(list.__iadd__,
[[n for n, i in out.clients if not isinstance(n, str)] for out in node.outputs],
[]) for node in nodes],
[])
candidates = tracks
tracks = []
def apply(self, env):
tasks = defaultdict(list)
if self.max_use_ratio is not None:
max_uses = self.max_use_ratio * len(env.nodes)
runs = defaultdict(int)
else:
runs = None
def importer(node):
#print 'IMPORTING', node
self.backtrack(node, tasks)
def pruner(node):
try:
del tasks[node]
except KeyError:
pass
def chin(node, i, r, new_r):
if new_r.owner and not r.clients:
self.backtrack(new_r.owner, tasks)
# # == NOT IDEAL == #
# for node in env.nodes:
# importer(node)
for node in env.nodes:
tasks[node].extend(lopt for track, i, lopt in self.fetch_tracks0(node.op))
u = self.attach_updater(env, importer, pruner, chin)
while tasks:
for node in tasks.iterkeys():
todo = tasks.pop(node)
break
for lopt in todo:
if runs is not None and runs[lopt] >= max_uses:
print >>sys.stderr, 'Warning: optimization exceeded its maximal use ratio: %s, %s' % (lopt, max_uses)
continue
success = self.process_node(env, node, lopt)
if success:
if runs is not None: runs[lopt] += 1
break
self.detach_updater(env, u)
# def match(self, node, candidates):
# candidates[:] = [candidate
# for candidate in candidates
# if candidate.current.op is None or candidate.current.op == node.op]
# for candidate in candidates:
# if candidate.current.inputs is not None:
# for in1, in2 in zip(candidate.current.inputs, node.inputs):
# if isinstance(in1, str):
# candidate.match[in1] = in2
# for client in node.clients:
u = self.attach_updater(env, importer, pruner)
# op = node.op try:
# patterns = self.pattern_base[(depth, op)].union(self.pattern_base[(depth, WILDCARD)]) while q:
# if not patterns: node = q.pop()
# return patterns current_node = node
# return self.match(node, depth + 1).intersection(patterns) for lopt in self.local_optimizers:
process_count.setdefault(lopt, 0)
if process_count[lopt] > max_use:
# def backtrack(self, node, q): max_use_abort = True
# for node2, i in node.clients: else:
# op2 = node2.op lopt_change = self.process_node(env, node, lopt)
process_count[lopt] += 1 if lopt_change else 0
changed |= lopt_change
except:
self.detach_updater(env, u)
raise
self.detach_updater(env, u)
if max_use_abort:
print >> sys.stderr, "WARNING: EquilibriumOptimizer max'ed out"
def keep_going(exc, nav, repl_pairs): def keep_going(exc, nav, repl_pairs):
...@@ -895,5 +867,3 @@ class PureThenInplaceOptimizer(Optimizer): ...@@ -895,5 +867,3 @@ class PureThenInplaceOptimizer(Optimizer):
...@@ -4,16 +4,31 @@ import opt ...@@ -4,16 +4,31 @@ import opt
class DB(object): class DB(object):
def __hash__(self):
if not hasattr(self, '_optimizer_idx'):
self._optimizer_idx = opt._optimizer_idx[0]
opt._optimizer_idx[0] += 1
return self._optimizer_idx
def __init__(self): def __init__(self):
self.__db__ = defaultdict(set) self.__db__ = defaultdict(set)
self._names = set()
def register(self, name, obj, *tags): def register(self, name, obj, *tags):
# N.B. obj is not an instance of class Optimizer.
# It is an instance of a DB.In the tests for example,
# this is not always the case.
if not isinstance(obj, (DB, opt.Optimizer, opt.LocalOptimizer)):
raise Exception('wtf', obj)
obj.name = name obj.name = name
if name in self.__db__: if name in self.__db__:
raise ValueError('The name of the object cannot be an existing tag or the name of another existing object.', obj, name) raise ValueError('The name of the object cannot be an existing tag or the name of another existing object.', obj, name)
self.__db__[name] = set([obj]) self.__db__[name] = set([obj])
self._names.add(name)
for tag in tags: for tag in tags:
if tag in self._names:
raise ValueError('The tag of the object collides with a name.', obj, tag)
self.__db__[tag].add(obj) self.__db__[tag].add(obj)
def __query__(self, q): def __query__(self, q):
......
if 0:
class _EquilibriumOptimizer(NavigatorOptimizer):
def __init__(self,
local_optimizers,
failure_callback = None,
max_depth = None,
max_use_ratio = None):
super(EquilibriumOptimizer, self).__init__(
None,
ignore_newtrees = False,
failure_callback = failure_callback)
self.local_optimizers = local_optimizers
self.max_depth = max_depth
self.max_use_ratio = max_use_ratio
self.tracks = defaultdict(list)
self.tracks0 = defaultdict(list)
max_depth = 0
for lopt in local_optimizers:
tracks = lopt.tracks()
for track in tracks:
max_depth = max(max_depth, len(track))
if self.max_depth is not None and max_depth > self.max_depth:
raise ValueError('One of the local optimizers exceeds the maximal depth.')
for i, op in enumerate(track):
if i == 0:
self.tracks0[op].append((track, i, lopt))
self.tracks[op].append((track, i, lopt))
def fetch_tracks(self, op):
return self.tracks[op] + self.tracks[None]
def fetch_tracks0(self, op):
return self.tracks0[op] + self.tracks0[None]
def backtrack(self, node, tasks):
candidates = self.fetch_tracks(node.op)
tracks = []
def filter(node, depth):
new_candidates = []
for candidate in candidates:
track, i, lopt = candidate
if i < depth:
pass
elif track[i-depth] in (None, node.op):
if i == depth:
tasks[node].append(lopt)
else:
tracks.append(candidate)
else:
new_candidates.append(candidate)
return new_candidates
depth = 0
nodes = [node]
while candidates:
for node in nodes:
candidates = filter(node, depth)
depth += 1
_nodes = nodes
nodes = reduce(list.__iadd__,
[reduce(list.__iadd__,
[[n for n, i in out.clients if not isinstance(n, str)] for out in node.outputs],
[]) for node in nodes],
[])
candidates = tracks
tracks = []
def apply(self, env):
tasks = defaultdict(list)
if self.max_use_ratio is not None:
max_uses = self.max_use_ratio * len(env.nodes)
runs = defaultdict(int)
else:
runs = None
def importer(node):
#print 'IMPORTING', node
self.backtrack(node, tasks)
def pruner(node):
try:
del tasks[node]
except KeyError:
pass
def chin(node, i, r, new_r):
if new_r.owner and not r.clients:
self.backtrack(new_r.owner, tasks)
# # == NOT IDEAL == #
# for node in env.nodes:
# importer(node)
for node in env.toposort():
tasks[node].extend(lopt for track, i, lopt in self.fetch_tracks0(node.op))
u = self.attach_updater(env, importer, pruner, chin)
print 'KEYS', map(hash, tasks.keys())
while tasks:
for node in tasks.iterkeys():
todo = tasks.pop(node)
break
for lopt in todo:
if runs is not None and runs[lopt] >= max_uses:
print >>sys.stderr, 'Warning: optimization exceeded its maximal use ratio: %s, %s' % (lopt, max_uses)
continue
success = self.process_node(env, node, lopt)
if success:
if runs is not None: runs[lopt] += 1
break
self.detach_updater(env, u)
# def match(self, node, candidates):
# candidates[:] = [candidate
# for candidate in candidates
# if candidate.current.op is None or candidate.current.op == node.op]
# for candidate in candidates:
# if candidate.current.inputs is not None:
# for in1, in2 in zip(candidate.current.inputs, node.inputs):
# if isinstance(in1, str):
# candidate.match[in1] = in2
# for client in node.clients:
# op = node.op
# patterns = self.pattern_base[(depth, op)].union(self.pattern_base[(depth, WILDCARD)])
# if not patterns:
# return patterns
# return self.match(node, depth + 1).intersection(patterns)
# def backtrack(self, node, q):
# for node2, i in node.clients:
# op2 = node2.op
...@@ -375,7 +375,7 @@ class TestEquilibrium(object): ...@@ -375,7 +375,7 @@ class TestEquilibrium(object):
x, y, z = map(MyResult, 'xyz') x, y, z = map(MyResult, 'xyz')
e = op3(op4(x, y)) e = op3(op4(x, y))
g = Env([x, y, z], [e]) g = Env([x, y, z], [e])
print g print 'before', g
sys.stderr = sys.stdout # display pesky warnings along with stdout sys.stderr = sys.stdout # display pesky warnings along with stdout
opt = EquilibriumOptimizer( opt = EquilibriumOptimizer(
[PatternSub((op1, 'x', 'y'), (op2, 'x', 'y')), [PatternSub((op1, 'x', 'y'), (op2, 'x', 'y')),
...@@ -384,7 +384,7 @@ class TestEquilibrium(object): ...@@ -384,7 +384,7 @@ class TestEquilibrium(object):
], ],
max_use_ratio = 1. / len(g.nodes)) # each opt can only be applied once max_use_ratio = 1. / len(g.nodes)) # each opt can only be applied once
opt.optimize(g) opt.optimize(g)
print g print 'after', g
assert str(g) == '[Op4(x, y)]' assert str(g) == '[Op4(x, y)]'
......
from theano.gof.optdb import *
from unittest import TestCase
class Test_DB(TestCase):
def test_0(self):
class Opt(opt.Optimizer): #inheritance buys __hash__
name = 'blah'
db = DB()
db.register('a', Opt())
db.register('b', Opt())
db.register('c', Opt(), 'z', 'asdf')
try:
db.register('c', Opt()) #name taken
self.fail()
except ValueError, e:
if e[0].startswith("The name"):
pass
else:
raise
except:
self.fail()
try:
db.register('z', Opt()) #name collides with tag
self.fail()
except ValueError, e:
if e[0].startswith("The name"):
pass
else:
raise
except:
self.fail()
try:
db.register('u', Opt(), 'b') #name new but tag collides with name
self.fail()
except ValueError, e:
if e[0].startswith("The tag"):
pass
else:
raise
except:
self.fail()
...@@ -2,6 +2,7 @@ import gof #, gof.result ...@@ -2,6 +2,7 @@ import gof #, gof.result
import numpy #for numeric_grad import numpy #for numeric_grad
from gof.python25 import all from gof.python25 import all
import gof.utils
_msg_retType = 'op.grad(...) returned a non-list' _msg_retType = 'op.grad(...) returned a non-list'
_msg_badlen = 'op.grad(...) returned wrong number of gradients' _msg_badlen = 'op.grad(...) returned wrong number of gradients'
...@@ -55,17 +56,17 @@ def grad_sources_inputs(sources, graph_inputs): ...@@ -55,17 +56,17 @@ def grad_sources_inputs(sources, graph_inputs):
else: else:
gmap[r] = g_r gmap[r] = g_r
graph_outputs = gmap.keys() graph_outputs = gof.utils.uniq([r for r,g in sources])
if graph_inputs is None: if graph_inputs is None:
graph_inputs = gof.graph.inputs(graph_outputs) graph_inputs = gof.graph.inputs(graph_outputs)
for node in gof.graph.io_toposort(graph_inputs, graph_outputs).__reversed__(): for node in gof.graph.io_toposort(graph_inputs, graph_outputs).__reversed__():
g_outputs = [gmap.get(o,None) for o in node.outputs] g_outputs = [gmap.get(o,None) for o in node.outputs]
#if all output gradients are None, continue #if all output gradients are None, continue
if all(map(lambda x:x is None, g_outputs)): continue if all(map(lambda x:x is None, g_outputs)): continue
output_arg = g_outputs output_arg = g_outputs
input_arg = node.inputs input_arg = node.inputs
......
...@@ -235,17 +235,27 @@ class PPrinter: ...@@ -235,17 +235,27 @@ class PPrinter:
else: else:
raise TypeError('Not enough arguments to call.') raise TypeError('Not enough arguments to call.')
use_ascii = True
if use_ascii:
special = dict(middle_dot = u"\u00B7", special = dict(middle_dot = "\dot",
big_sigma = u"\u03A3") big_sigma = "\Sigma")
greek = dict(alpha = u"\u03B1", greek = dict(alpha = "\alpha",
beta = u"\u03B2", beta = "\beta",
gamma = u"\u03B3", gamma = "\gamma",
delta = u"\u03B4", delta = "\delta",
epsilon = u"\u03B5") epsilon = "\epsilon")
else:
special = dict(middle_dot = u"\u00B7",
big_sigma = u"\u03A3")
greek = dict(alpha = u"\u03B1",
beta = u"\u03B2",
gamma = u"\u03B3",
delta = u"\u03B4",
epsilon = u"\u03B5")
pprint = PPrinter() pprint = PPrinter()
......
...@@ -2,6 +2,7 @@ from __future__ import absolute_import ...@@ -2,6 +2,7 @@ from __future__ import absolute_import
import time import time
import numpy import numpy
from ..gof.cutils import run_cthunk
from ..gof.link import WrapLinker from ..gof.link import WrapLinker
from ..compile.mode import Mode from ..compile.mode import Mode
...@@ -103,49 +104,82 @@ def DualLinker(linkers): ...@@ -103,49 +104,82 @@ def DualLinker(linkers):
class ProfileMode(Mode): class ProfileMode(Mode):
def __init__(self, local_linker, optimizer=None): def __init__(self, linker, optimizer=None):
local_time = [0.0] local_time = [0.0]
apply_time = {} apply_time = {}
op_time = {} op_time = {}
op_cimpl = {}
def blah(i, node, *thunks): def blah(i, node, *thunks):
t0 = time.time() if 0:
for th in thunks: t0 = time.time()
th() for th in thunks:
dt = time.time() - t0 th()
dt = time.time() - t0
elif 0: #more precise timing
for th in thunks:
t0 = time.time()
th()
dt = time.time() - t0
elif 1:
for th in thunks:
if hasattr(th, 'cthunk'):
t0 = time.time()
run_cthunk(th.cthunk)
dt = time.time() - t0
else:
t0 = time.time()
th()
dt = time.time() - t0
elif 1:
pass
else:
raise Exception('one of the cases has to run the thunks!')
local_time[0] += dt local_time[0] += dt
apply_time[(i,node.op)] = apply_time.get((i,node.op), 0.0) + dt apply_time[(i,node.op)] = apply_time.get((i,node.op), 0.0) + dt
op_time[node.op] = op_time.get(node.op, 0.0) + dt op_time[node.op] = op_time.get(node.op, 0.0) + dt
op_cimpl[node.op] = hasattr(thunks[0], 'cthunk')
self.local_time = local_time self.local_time = local_time
self.apply_time = apply_time self.apply_time = apply_time
self.op_time = op_time self.op_time = op_time
self.op_cimpl = op_cimpl
linker = WrapLinkerMany([local_linker], [blah]) wrap_linker = WrapLinkerMany([linker], [blah])
if optimizer: if optimizer:
Mode.__init__(self, linker, optimizer) super(ProfileMode, self).__init__(wrap_linker, optimizer)
else: else:
Mode.__init__(self, linker) super(ProfileMode, self).__init__(wrap_linker)
def print_summary(self): def print_summary(self):
local_time = self.local_time[0] local_time = self.local_time[0]
apply_time = self.apply_time apply_time = self.apply_time
op_time = self.op_time op_time = self.op_time
print 'local_time', local_time print ''
print 'apply-wise times' print 'ProfileMode.print_summary()'
print '---------------------------'
print ''
print 'local_time', local_time, '(Time spent running thunks)'
print 'Apply-wise summary: <fraction of local_time spent at this position> (<Apply position>, <Apply Op name>)'
atimes = [(t/local_time, (a[0], str(a[1]))) for a, t in apply_time.items()] atimes = [(t/local_time, (a[0], str(a[1]))) for a, t in apply_time.items()]
atimes.sort() atimes.sort()
atimes.reverse() atimes.reverse()
for t,a in atimes[:15]: for t,a in atimes[:15]:
print ' ', t, a print '\t%.3f\t%i\t%s' % (t, a[0], a[1])
print ' ...' #show that we are ignoring applies that don't take much time print ' ... (remaining %i Apply instances account for %.2f of the runtime)'\
print 'op-wise times' %(max(0, len(atimes)-15), sum(t for t, a in atimes[15:]))
otimes = [(t/local_time, a) for a, t in op_time.items()]
n_ops_to_print = 20
print 'Op-wise summary: <fraction of local_time spent on this kind of Op> <Op name>'
otimes = [(t/local_time, a, self.op_cimpl[a]) for a, t in op_time.items()]
otimes.sort() otimes.sort()
otimes.reverse() otimes.reverse()
for t,a in otimes[:15]: for t,a,ci in otimes[:n_ops_to_print]:
print ' ', t, a print '\t%.3f\t%s %s' % (t, '*' if ci else ' ', a)
print ' ...' #show that we are ignoring applies that don't take much time print ' ... (remaining %i Ops account for %.2f of the runtime)'\
print sum(t for a,t in op_time.items()) %(max(0, len(otimes)-n_ops_to_print), sum(t for t, a, ci in
otimes[n_ops_to_print:]))
print '(*) Op is running a c implementation'
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
from basic import * from basic import *
import opt import opt
import blas
import raw_random import raw_random
from raw_random import \ from raw_random import \
......
...@@ -13,7 +13,6 @@ from copy import copy ...@@ -13,7 +13,6 @@ from copy import copy
from .. import gof from .. import gof
from ..gof import Result, Op, utils, AbstractFunctionError, Type, Constant, Apply, Value from ..gof import Result, Op, utils, AbstractFunctionError, Type, Constant, Apply, Value
import blas # for gemm, dot
from .. import gradient from .. import gradient
import elemwise import elemwise
...@@ -403,6 +402,8 @@ scalars, fscalars, dscalars, iscalars, lscalars = _multi(scalar, fscalar, dscala ...@@ -403,6 +402,8 @@ scalars, fscalars, dscalars, iscalars, lscalars = _multi(scalar, fscalar, dscala
int_types = bscalar, wscalar, iscalar, lscalar int_types = bscalar, wscalar, iscalar, lscalar
float_types = fscalar, dscalar float_types = fscalar, dscalar
int_scalar_types = int_types
float_scalar_types = float_types
fvector = Tensor('float32', (False, )) fvector = Tensor('float32', (False, ))
dvector = Tensor('float64', (False, )) dvector = Tensor('float64', (False, ))
...@@ -1101,38 +1102,9 @@ pprint.assign(pow, printing.OperatorPrinter('**', 1, 'right')) ...@@ -1101,38 +1102,9 @@ pprint.assign(pow, printing.OperatorPrinter('**', 1, 'right'))
# View Operations # View Operations
########################## ##########################
class TransposeInplace(Op):
view_map = {0: [0]}
def make_node(self, input):
return Apply(self, [input], [tensor(dtype = input.type.dtype,
broadcastable = reversed(input.type.broadcastable))])
def perform(self, node, (x, ), (z, )):
z[0] = x.T
def grad(self, (x,), (gz,)):
return transpose(gz),
def c_code(self, node, name, (x, ), (z, ), sub):
return """
PyArrayObject* transposed = (PyArrayObject*)PyArray_Transpose(%(x)s, NULL);
if (%(z)s) {
Py_XDECREF(%(z)s);
}
%(z)s = transposed;
""" % locals()
def __str__(self):
return "TransposeView"
_transpose_inplace = TransposeInplace()
def transpose(x, **kwargs): def transpose(x, **kwargs):
"""WRITEME""" dims = range(x.ndim-1, -1, -1)
return _transpose_inplace(tensor_copy(x), **kwargs) return DimShuffle(x.broadcastable, dims, inplace=True)(tensor_copy(x))
class Subtensor(Op): class Subtensor(Op):
...@@ -1749,6 +1721,10 @@ pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, MakeVector), ...@@ -1749,6 +1721,10 @@ pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, MakeVector),
######################### #########################
# Linalg : Dot # Linalg : Dot
######################### #########################
#
# For BLAS-related ops see blas.py
#
# TODO: Dotinv should go here, Eigs, Svd, etc.
class Dot(Op): class Dot(Op):
"""Compute matrix-matrix, matrix-vector products and vector inner-products. """Compute matrix-matrix, matrix-vector products and vector inner-products.
...@@ -1801,6 +1777,7 @@ class Dot(Op): ...@@ -1801,6 +1777,7 @@ class Dot(Op):
# The error raised by numpy has no shape information, we mean to add that # The error raised by numpy has no shape information, we mean to add that
e.args = e.args + (x.shape, y.shape) e.args = e.args + (x.shape, y.shape)
raise raise
def grad(self, (x, y), (gz,)): def grad(self, (x, y), (gz,)):
if gz.type.ndim == 0: if gz.type.ndim == 0:
return gz * y, gz * x return gz * y, gz * x
...@@ -1841,249 +1818,6 @@ class Outer(Op): ...@@ -1841,249 +1818,6 @@ class Outer(Op):
return "outer" return "outer"
outer = Outer() outer = Outer()
class Gemm(Op):
"""In-place version of matrix-matrix multiplication (with accumulation):
When a and b are scalars and x, y, and z are matrices, then
gemm(z,a,x,y,b)
is similar to
b*z + a*dot(x,y)
The difference between the two is that the top form is destructive on z,
whereas the bottom form is not. Gemm works in-place on the storage
associated with z, and the L{Result} returned by Gemm has a storage that
will be aliased to the storage of the z argument. Because of this in-place
computation, an L{Apply} of this op will destroy the L{Result} z on
which it operates. (See L{DestructiveOps} for an explanation of what
destroying means in the context of theano graphs. See L{BlasLapackSupport} for
more optimized linear algebra operations.)
"""
E_rank = 'gemm only works for rank 2'
E_scalar = 'gemm requires scalar argument'
E_z_uniq = 'argument z aliased to x or y'
destroy_map = {0: [0]}
def make_node(self, *inputs):
inputs = map(as_tensor, inputs)
if len(inputs) != 5:
raise TypeError("Wrong number of inputs for %s (expected 5, got %s)" % (self, len(inputs)))
z, a, x, y, b = inputs
zr, xr, yr = [set(gof.view_roots(i)) for i in z,x,y]
if zr.intersection(xr):
raise ValueError(Gemm.E_z_uniq, (z, x))
if zr.intersection(yr):
raise ValueError(Gemm.E_z_uniq, (z, y))
bz, ba, bx, by, bb = [r.type.broadcastable for r in inputs]
if len(bz) != 2: raise ValueError(Gemm.E_rank, len(bz))
if len(bx) != 2: raise ValueError(Gemm.E_rank, len(bx))
if len(by) != 2: raise ValueError(Gemm.E_rank, len(by))
if len(ba): raise ValueError(Gemm.E_scalar, ba)
if len(bb): raise ValueError(Gemm.E_scalar, bb)
output = z.type()
return Apply(self, inputs, [output])
def perform(self, node, (z, a, x, y, b), (zout, )):
assert a.shape == ()
assert b.shape == ()
if z.shape == ():
z.itemset(z*a + b*numpy.dot(x,y))
zout[0] = z
else:
if b == 0.0:
if a == 1.0:
z[:] = numpy.dot(x,y)
elif a == -1.0:
z[:] = -numpy.dot(x,y)
else:
z[:] = a * numpy.dot(x,y)
elif b == 1.0:
if a == 1.0:
z += numpy.dot(x,y)
elif a == -1.0:
z -= numpy.dot(x,y)
else:
z += a * numpy.dot(x,y)
else:
z *= b
z += a * numpy.dot(x,y)
zout[0] = z
def grad(self, (z, a, x, y, b), (gz,)):
raise NotImplementedError()
def c_support_code(self):
#return blas.cblas_header_text()
mod_str = """
#ifndef MOD
#define MOD %
#endif
"""
return blas.blas_proto() + mod_str
def c_headers(self):
return ['<iostream>']
def c_libraries(self):
return blas.ldflags()
def c_code(self, node, name, (_z, _a, _x, _y, _b), (_zout, ), sub):
return """
int unit = 0;
int type_num = %(_x)s->descr->type_num;
int type_size = %(_x)s->descr->elsize; // in bytes
npy_intp* Nx = %(_x)s->dimensions;
npy_intp* Ny = %(_y)s->dimensions;
npy_intp* Nz = %(_z)s->dimensions;
npy_intp* Sx = %(_x)s->strides;
npy_intp* Sy = %(_y)s->strides;
npy_intp* Sz = %(_z)s->strides;
//strides for x, y, z in dimensions 0, 1
int sx_0, sx_1, sy_0, sy_1, sz_0, sz_1;
if (%(_zout)s != %(_z)s)
{
if (%(_zout)s)
{
Py_DECREF(%(_zout)s);
}
%(_zout)s = %(_z)s;
Py_INCREF(%(_zout)s);
}
if (%(_x)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(x) != 2"); %(fail)s;}
if (%(_y)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(y) != 2"); %(fail)s;}
if (%(_z)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(z) != 2"); %(fail)s;}
if ((%(_a)s->descr->type_num != PyArray_DOUBLE)
&& (%(_a)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(a) is not double or float"); %(fail)s;}
if ((%(_b)s->descr->type_num != PyArray_DOUBLE)
&& (%(_b)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(b) is not double or float"); %(fail)s;}
if ((%(_x)s->descr->type_num != PyArray_DOUBLE)
&& (%(_x)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float"); %(fail)s;}
if ((%(_y)s->descr->type_num != PyArray_DOUBLE)
&& (%(_y)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float"); %(fail)s;}
if ((%(_z)s->descr->type_num != PyArray_DOUBLE)
&& (%(_z)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float"); %(fail)s;}
if ((%(_x)s->descr->type_num != %(_y)s->descr->type_num)
||(%(_x)s->descr->type_num != %(_z)s->descr->type_num))
{ PyErr_SetString(PyExc_NotImplementedError, "type(z), type(y), type(z) are not all the same"); %(fail)s; }
if ((Nx[0] != Nz[0]) || (Nx[1] != Ny[0]) || (Ny[1] != Nz[1]))
{
PyErr_SetString(PyExc_ValueError, "Input dimensions do not agree");
%(fail)s;
}
if ((Sx[0] < 1) || (Sx[1] < 1) || (Sx[0] MOD type_size) || (Sx[1] MOD type_size)
|| (Sy[0] < 1) || (Sy[1] < 1) || (Sy[0] MOD type_size) || (Sy[1] MOD type_size)
|| (Sz[0] < 1) || (Sz[1] < 1) || (Sz[0] MOD type_size) || (Sz[1] MOD type_size))
{
PyErr_SetString(PyExc_ValueError, "stride is not multiple of element size"); %(fail)s;
}
/*
encode the stride structure of _x,_y,_z into a single integer
*/
unit |= ((Sx[1] == type_size) ? 0x0 : (Sx[0] == type_size) ? 0x1 : 0x2) << 8;
unit |= ((Sy[1] == type_size) ? 0x0 : (Sy[0] == type_size) ? 0x1 : 0x2) << 4;
unit |= ((Sz[1] == type_size) ? 0x0 : (Sz[0] == type_size) ? 0x1 : 0x2) << 0;
/* create appropriate strides for malformed matrices that are row or column
* vectors
*/
sx_0 = (Nx[0] > 1) ? Sx[0]/type_size : Nx[1];
sx_1 = (Nx[1] > 1) ? Sx[1]/type_size : Nx[0];
sy_0 = (Ny[0] > 1) ? Sy[0]/type_size : Ny[1];
sy_1 = (Ny[1] > 1) ? Sy[1]/type_size : Ny[0];
sz_0 = (Nz[0] > 1) ? Sz[0]/type_size : Nz[1];
sz_1 = (Nz[1] > 1) ? Sz[1]/type_size : Nz[0];
switch (type_num)
{
case PyArray_FLOAT:
{
#define REAL float
float a = (%(_a)s->descr->type_num == PyArray_FLOAT)
? (REAL)(((float*)%(_a)s->data)[0])
: (REAL)(((double*)%(_a)s->data)[0]);
float b = (%(_b)s->descr->type_num == PyArray_FLOAT) ?
(REAL)(((float*)%(_b)s->data)[0])
: (REAL)(((double*)%(_b)s->data)[0]);
float* x = (float*)PyArray_DATA(%(_x)s);
float* y = (float*)PyArray_DATA(%(_y)s);
float* z = (float*)PyArray_DATA(%(_z)s);
char N = 'N';
char T = 'T';
int Nz0 = Nz[0], Nz1 = Nz[1], Nx1 = Nx[1];
//std::cerr << (unit/256) MOD 16 << (unit / 16) MOD 16 << unit MOD 16<< '\\n';
switch(unit)
{
case 0x000: sgemm_(&N, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_0, &b, z, &sz_0); break;
case 0x100: sgemm_(&N, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_1, &b, z, &sz_0); break;
case 0x010: sgemm_(&T, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_0, &b, z, &sz_0); break;
case 0x110: sgemm_(&T, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_1, &b, z, &sz_0); break;
case 0x001: sgemm_(&T, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_0, &b, z, &sz_1); break;
case 0x101: sgemm_(&N, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_0, &b, z, &sz_1); break;
case 0x011: sgemm_(&T, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_1, &b, z, &sz_1); break;
case 0x111: sgemm_(&N, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_1, &b, z, &sz_1); break;
default: PyErr_SetString(PyExc_ValueError, "some matrix has no unit stride"); %(fail)s;
};
#undef REAL
}
break;
case PyArray_DOUBLE:
{
#define REAL double
double a = (%(_a)s->descr->type_num == PyArray_FLOAT)
? (REAL)(((float*)%(_a)s->data)[0])
: (REAL)(((double*)%(_a)s->data)[0]);
double b = (%(_b)s->descr->type_num == PyArray_FLOAT) ?
(REAL)(((float*)%(_b)s->data)[0])
: (REAL)(((double*)%(_b)s->data)[0]);
double* x = (double*)PyArray_DATA(%(_x)s);
double* y = (double*)PyArray_DATA(%(_y)s);
double* z = (double*)PyArray_DATA(%(_z)s);
char N = 'N';
char T = 'T';
int Nz0 = Nz[0], Nz1 = Nz[1], Nx1 = Nx[1];
//std::cerr << (unit/256) MOD 16 << (unit / 16) MOD 16 << unit MOD 16<< '\\n';
switch(unit)
{
case 0x000: dgemm_(&N, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_0, &b, z, &sz_0); break;
case 0x100: dgemm_(&N, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_1, &b, z, &sz_0); break;
case 0x010: dgemm_(&T, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_0, &b, z, &sz_0); break;
case 0x110: dgemm_(&T, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_1, &b, z, &sz_0); break;
case 0x001: dgemm_(&T, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_0, &b, z, &sz_1); break;
case 0x101: dgemm_(&N, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_0, &b, z, &sz_1); break;
case 0x011: dgemm_(&T, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_1, &b, z, &sz_1); break;
case 0x111: dgemm_(&N, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_1, &b, z, &sz_1); break;
default: PyErr_SetString(PyExc_ValueError, "some matrix has no unit stride"); %(fail)s;
};
#undef REAL
}
break;
}
""" % dict(locals(), **sub)
gemm = Gemm()
pprint.assign(gemm, printing.FunctionPrinter('gemm'))
######################### #########################
# Gradient # Gradient
######################### #########################
......
"""Ops and optimizations for using BLAS function calls to evaluate linear algebra expressions"""
import os, sys import os, sys
import scipy.weave as weave import numpy
from ..gof import utils
from ..gof import (utils, Op, Apply, view_roots, PatternSub,
InplaceOptimizer, SeqOptimizer, warn, local_optimizer)
""" from ..printing import pprint, FunctionPrinter
File: omega/blas.py from .opt import register_specialize, out2in, insert_inplace_optimizer
This file is in omega's core because it consists mostly of optimizations of the import basic as T
graphs that can be constructed from omega/core.py. The optimizations provided from ..tensor import as_tensor
by this file are aimed at the goal of inserting gemm Ops in place of more
fine-grained motifs of iadd, isub, scale, and dot. #NB: this clobbers the builtin 'compile' symbol
""" from .. import compile #to register the optimizer built by this file
def cblas_header_text(): from .blas_headers import cblas_header_text, blas_header_text
"""C header for the cblas interface"""
return """ JOSEPHS_BUG_SOLVED = False
//#include <stddef.h>
#undef __BEGIN_DECLS
#undef __END_DECLS
#ifdef __cplusplus
#define __BEGIN_DECLS extern "C" {
#define __END_DECLS }
#else
#define __BEGIN_DECLS /* empty */
#define __END_DECLS /* empty */
#endif
__BEGIN_DECLS
#define MOD %
/*
* Enumerated and derived types
*/
#define CBLAS_INDEX size_t /* this may vary between platforms */
enum CBLAS_ORDER {CblasRowMajor=101, CblasColMajor=102};
enum CBLAS_TRANSPOSE {CblasNoTrans=111, CblasTrans=112, CblasConjTrans=113};
enum CBLAS_UPLO {CblasUpper=121, CblasLower=122};
enum CBLAS_DIAG {CblasNonUnit=131, CblasUnit=132};
enum CBLAS_SIDE {CblasLeft=141, CblasRight=142};
float cblas_sdsdot(const int N, const float alpha, const float *X,
const int incX, const float *Y, const int incY);
double cblas_dsdot(const int N, const float *X, const int incX, const float *Y,
const int incY);
float cblas_sdot(const int N, const float *X, const int incX,
const float *Y, const int incY);
double cblas_ddot(const int N, const double *X, const int incX,
const double *Y, const int incY);
/*
* Functions having prefixes Z and C only
*/
void cblas_cdotu_sub(const int N, const void *X, const int incX,
const void *Y, const int incY, void *dotu);
void cblas_cdotc_sub(const int N, const void *X, const int incX,
const void *Y, const int incY, void *dotc);
void cblas_zdotu_sub(const int N, const void *X, const int incX,
const void *Y, const int incY, void *dotu);
void cblas_zdotc_sub(const int N, const void *X, const int incX,
const void *Y, const int incY, void *dotc);
/*
* Functions having prefixes S D SC DZ
*/
float cblas_snrm2(const int N, const float *X, const int incX);
float cblas_sasum(const int N, const float *X, const int incX);
double cblas_dnrm2(const int N, const double *X, const int incX);
double cblas_dasum(const int N, const double *X, const int incX);
float cblas_scnrm2(const int N, const void *X, const int incX);
float cblas_scasum(const int N, const void *X, const int incX);
double cblas_dznrm2(const int N, const void *X, const int incX);
double cblas_dzasum(const int N, const void *X, const int incX);
/*
* Functions having standard 4 prefixes (S D C Z)
*/
CBLAS_INDEX cblas_isamax(const int N, const float *X, const int incX);
CBLAS_INDEX cblas_idamax(const int N, const double *X, const int incX);
CBLAS_INDEX cblas_icamax(const int N, const void *X, const int incX);
CBLAS_INDEX cblas_izamax(const int N, const void *X, const int incX);
/*
* ===========================================================================
* Prototypes for level 1 BLAS routines
* ===========================================================================
*/
/*
* Routines with standard 4 prefixes (s, d, c, z)
*/
void cblas_sswap(const int N, float *X, const int incX,
float *Y, const int incY);
void cblas_scopy(const int N, const float *X, const int incX,
float *Y, const int incY);
void cblas_saxpy(const int N, const float alpha, const float *X,
const int incX, float *Y, const int incY);
void cblas_dswap(const int N, double *X, const int incX,
double *Y, const int incY);
void cblas_dcopy(const int N, const double *X, const int incX,
double *Y, const int incY);
void cblas_daxpy(const int N, const double alpha, const double *X,
const int incX, double *Y, const int incY);
void cblas_cswap(const int N, void *X, const int incX,
void *Y, const int incY);
void cblas_ccopy(const int N, const void *X, const int incX,
void *Y, const int incY);
void cblas_caxpy(const int N, const void *alpha, const void *X,
const int incX, void *Y, const int incY);
void cblas_zswap(const int N, void *X, const int incX,
void *Y, const int incY);
void cblas_zcopy(const int N, const void *X, const int incX,
void *Y, const int incY);
void cblas_zaxpy(const int N, const void *alpha, const void *X,
const int incX, void *Y, const int incY);
/*
* Routines with S and D prefix only
*/
void cblas_srotg(float *a, float *b, float *c, float *s);
void cblas_srotmg(float *d1, float *d2, float *b1, const float b2, float *P);
void cblas_srot(const int N, float *X, const int incX,
float *Y, const int incY, const float c, const float s);
void cblas_srotm(const int N, float *X, const int incX,
float *Y, const int incY, const float *P);
void cblas_drotg(double *a, double *b, double *c, double *s);
void cblas_drotmg(double *d1, double *d2, double *b1, const double b2, double *P);
void cblas_drot(const int N, double *X, const int incX,
double *Y, const int incY, const double c, const double s);
void cblas_drotm(const int N, double *X, const int incX,
double *Y, const int incY, const double *P);
/*
* Routines with S D C Z CS and ZD prefixes
*/
void cblas_sscal(const int N, const float alpha, float *X, const int incX);
void cblas_dscal(const int N, const double alpha, double *X, const int incX);
void cblas_cscal(const int N, const void *alpha, void *X, const int incX);
void cblas_zscal(const int N, const void *alpha, void *X, const int incX);
void cblas_csscal(const int N, const float alpha, void *X, const int incX);
void cblas_zdscal(const int N, const double alpha, void *X, const int incX);
/*
* ===========================================================================
* Prototypes for level 2 BLAS
* ===========================================================================
*/
/*
* Routines with standard 4 prefixes (S, D, C, Z)
*/
void cblas_sgemv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const float alpha, const float *A, const int lda,
const float *X, const int incX, const float beta,
float *Y, const int incY);
void cblas_sgbmv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const int KL, const int KU, const float alpha,
const float *A, const int lda, const float *X,
const int incX, const float beta, float *Y, const int incY);
void cblas_strmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const float *A, const int lda,
float *X, const int incX);
void cblas_stbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const float *A, const int lda,
float *X, const int incX);
void cblas_stpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const float *Ap, float *X, const int incX);
void cblas_strsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const float *A, const int lda, float *X,
const int incX);
void cblas_stbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const float *A, const int lda,
float *X, const int incX);
void cblas_stpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const float *Ap, float *X, const int incX);
void cblas_dgemv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const double alpha, const double *A, const int lda,
const double *X, const int incX, const double beta,
double *Y, const int incY);
void cblas_dgbmv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const int KL, const int KU, const double alpha,
const double *A, const int lda, const double *X,
const int incX, const double beta, double *Y, const int incY);
void cblas_dtrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const double *A, const int lda,
double *X, const int incX);
void cblas_dtbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const double *A, const int lda,
double *X, const int incX);
void cblas_dtpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const double *Ap, double *X, const int incX);
void cblas_dtrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const double *A, const int lda, double *X,
const int incX);
void cblas_dtbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const double *A, const int lda,
double *X, const int incX);
void cblas_dtpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const double *Ap, double *X, const int incX);
void cblas_cgemv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *X, const int incX, const void *beta,
void *Y, const int incY);
void cblas_cgbmv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const int KL, const int KU, const void *alpha,
const void *A, const int lda, const void *X,
const int incX, const void *beta, void *Y, const int incY);
void cblas_ctrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *A, const int lda,
void *X, const int incX);
void cblas_ctbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const void *A, const int lda,
void *X, const int incX);
void cblas_ctpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *Ap, void *X, const int incX);
void cblas_ctrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *A, const int lda, void *X,
const int incX);
void cblas_ctbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const void *A, const int lda,
void *X, const int incX);
void cblas_ctpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *Ap, void *X, const int incX);
void cblas_zgemv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *X, const int incX, const void *beta,
void *Y, const int incY);
void cblas_zgbmv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const int KL, const int KU, const void *alpha,
const void *A, const int lda, const void *X,
const int incX, const void *beta, void *Y, const int incY);
void cblas_ztrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *A, const int lda,
void *X, const int incX);
void cblas_ztbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const void *A, const int lda,
void *X, const int incX);
void cblas_ztpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *Ap, void *X, const int incX);
void cblas_ztrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *A, const int lda, void *X,
const int incX);
void cblas_ztbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const void *A, const int lda,
void *X, const int incX);
void cblas_ztpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *Ap, void *X, const int incX);
/*
* Routines with S and D prefixes only
*/
void cblas_ssymv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *A,
const int lda, const float *X, const int incX,
const float beta, float *Y, const int incY);
void cblas_ssbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const int K, const float alpha, const float *A,
const int lda, const float *X, const int incX,
const float beta, float *Y, const int incY);
void cblas_sspmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *Ap,
const float *X, const int incX,
const float beta, float *Y, const int incY);
void cblas_sger(const enum CBLAS_ORDER order, const int M, const int N,
const float alpha, const float *X, const int incX,
const float *Y, const int incY, float *A, const int lda);
void cblas_ssyr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *X,
const int incX, float *A, const int lda);
void cblas_sspr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *X,
const int incX, float *Ap);
void cblas_ssyr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *X,
const int incX, const float *Y, const int incY, float *A,
const int lda);
void cblas_sspr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *X,
const int incX, const float *Y, const int incY, float *A);
void cblas_dsymv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *A,
const int lda, const double *X, const int incX,
const double beta, double *Y, const int incY);
void cblas_dsbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const int K, const double alpha, const double *A,
const int lda, const double *X, const int incX,
const double beta, double *Y, const int incY);
void cblas_dspmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *Ap,
const double *X, const int incX,
const double beta, double *Y, const int incY);
void cblas_dger(const enum CBLAS_ORDER order, const int M, const int N,
const double alpha, const double *X, const int incX,
const double *Y, const int incY, double *A, const int lda);
void cblas_dsyr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *X,
const int incX, double *A, const int lda);
void cblas_dspr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *X,
const int incX, double *Ap);
void cblas_dsyr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *X,
const int incX, const double *Y, const int incY, double *A,
const int lda);
void cblas_dspr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *X,
const int incX, const double *Y, const int incY, double *A);
/*
* Routines with C and Z prefixes only
*/
void cblas_chemv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const void *alpha, const void *A,
const int lda, const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_chbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const int K, const void *alpha, const void *A,
const int lda, const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_chpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const void *alpha, const void *Ap,
const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_cgeru(const enum CBLAS_ORDER order, const int M, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_cgerc(const enum CBLAS_ORDER order, const int M, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_cher(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const void *X, const int incX,
void *A, const int lda);
void cblas_chpr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const void *X,
const int incX, void *A);
void cblas_cher2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_chpr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *Ap);
void cblas_zhemv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const void *alpha, const void *A,
const int lda, const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_zhbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const int K, const void *alpha, const void *A,
const int lda, const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_zhpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const void *alpha, const void *Ap,
const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_zgeru(const enum CBLAS_ORDER order, const int M, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_zgerc(const enum CBLAS_ORDER order, const int M, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_zher(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const void *X, const int incX,
void *A, const int lda);
void cblas_zhpr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const void *X,
const int incX, void *A);
void cblas_zher2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_zhpr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *Ap);
/*
* ===========================================================================
* Prototypes for level 3 BLAS
* ===========================================================================
*/
/*
* Routines with standard 4 prefixes (S, D, C, Z)
*/
void cblas_sgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const float alpha, const float *A,
const int lda, const float *B, const int ldb,
const float beta, float *C, const int ldc);
void cblas_ssymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const float alpha, const float *A, const int lda,
const float *B, const int ldb, const float beta,
float *C, const int ldc);
void cblas_ssyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const float alpha, const float *A, const int lda,
const float beta, float *C, const int ldc);
void cblas_ssyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const float alpha, const float *A, const int lda,
const float *B, const int ldb, const float beta,
float *C, const int ldc);
void cblas_strmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const float alpha, const float *A, const int lda,
float *B, const int ldb);
void cblas_strsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const float alpha, const float *A, const int lda,
float *B, const int ldb);
void cblas_dgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const double alpha, const double *A,
const int lda, const double *B, const int ldb,
const double beta, double *C, const int ldc);
void cblas_dsymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const double alpha, const double *A, const int lda,
const double *B, const int ldb, const double beta,
double *C, const int ldc);
void cblas_dsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const double alpha, const double *A, const int lda,
const double beta, double *C, const int ldc);
void cblas_dsyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const double alpha, const double *A, const int lda,
const double *B, const int ldb, const double beta,
double *C, const int ldc);
void cblas_dtrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const double alpha, const double *A, const int lda,
double *B, const int ldb);
void cblas_dtrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const double alpha, const double *A, const int lda,
double *B, const int ldb);
void cblas_cgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const void *alpha, const void *A,
const int lda, const void *B, const int ldb,
const void *beta, void *C, const int ldc);
void cblas_csymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_csyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *beta, void *C, const int ldc);
void cblas_csyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_ctrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const void *alpha, const void *A, const int lda,
void *B, const int ldb);
void cblas_ctrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const void *alpha, const void *A, const int lda,
void *B, const int ldb);
void cblas_zgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const void *alpha, const void *A,
const int lda, const void *B, const int ldb,
const void *beta, void *C, const int ldc);
void cblas_zsymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_zsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *beta, void *C, const int ldc);
void cblas_zsyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_ztrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const void *alpha, const void *A, const int lda,
void *B, const int ldb);
void cblas_ztrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const void *alpha, const void *A, const int lda,
void *B, const int ldb);
/*
* Routines with prefixes C and Z only
*/
void cblas_chemm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_cherk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const float alpha, const void *A, const int lda,
const float beta, void *C, const int ldc);
void cblas_cher2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const float beta,
void *C, const int ldc);
void cblas_zhemm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_zherk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const double alpha, const void *A, const int lda,
const double beta, void *C, const int ldc);
void cblas_zher2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const double beta,
void *C, const int ldc);
void cblas_xerbla(int p, const char *rout, const char *form, ...);
__END_DECLS
"""
def blas_proto():
"""C header for the fortran blas interface"""
return """
extern "C"
{
void xerbla_(char*, void *);
/***********/
/* Level 1 */
/***********/
/* Single Precision */
void srot_(const int*, float *, const int*, float *, const int*, const float *, const float *);
void srotg_(float *,float *,float *,float *);
void srotm_( const int*, float *, const int*, float *, const int*, const float *);
void srotmg_(float *,float *,float *,const float *, float *);
void sswap_( const int*, float *, const int*, float *, const int*);
void scopy_( const int*, const float *, const int*, float *, const int*);
void saxpy_( const int*, const float *, const float *, const int*, float *, const int*);
void sdot_sub_(const int*, const float *, const int*, const float *, const int*, float *);
void sdsdot_sub_( const int*, const float *, const float *, const int*, const float *, const int*, float *);
void sscal_( const int*, const float *, float *, const int*);
void snrm2_sub_( const int*, const float *, const int*, float *);
void sasum_sub_( const int*, const float *, const int*, float *);
void isamax_sub_( const int*, const float * , const int*, const int*);
/* Double Precision */
void drot_(const int*, double *, const int*, double *, const int*, const double *, const double *);
void drotg_(double *,double *,double *,double *);
void drotm_( const int*, double *, const int*, double *, const int*, const double *);
void drotmg_(double *,double *,double *,const double *, double *);
void dswap_( const int*, double *, const int*, double *, const int*);
void dcopy_( const int*, const double *, const int*, double *, const int*);
void daxpy_( const int*, const double *, const double *, const int*, double *, const int*);
void dswap_( const int*, double *, const int*, double *, const int*);
void dsdot_sub_(const int*, const float *, const int*, const float *, const int*, double *);
void ddot_sub_( const int*, const double *, const int*, const double *, const int*, double *);
void dscal_( const int*, const double *, double *, const int*);
void dnrm2_sub_( const int*, const double *, const int*, double *);
void dasum_sub_( const int*, const double *, const int*, double *);
void idamax_sub_( const int*, const double * , const int*, const int*);
/* Single Complex Precision */
void cswap_( const int*, void *, const int*, void *, const int*);
void ccopy_( const int*, const void *, const int*, void *, const int*);
void caxpy_( const int*, const void *, const void *, const int*, void *, const int*);
void cswap_( const int*, void *, const int*, void *, const int*);
void cdotc_sub_( const int*, const void *, const int*, const void *, const int*, void *);
void cdotu_sub_( const int*, const void *, const int*, const void *, const int*, void *);
void cscal_( const int*, const void *, void *, const int*);
void icamax_sub_( const int*, const void *, const int*, const int*);
void csscal_( const int*, const float *, void *, const int*);
void scnrm2_sub_( const int*, const void *, const int*, float *);
void scasum_sub_( const int*, const void *, const int*, float *);
/* Double Complex Precision */
void zswap_( const int*, void *, const int*, void *, const int*);
void zcopy_( const int*, const void *, const int*, void *, const int*);
void zaxpy_( const int*, const void *, const void *, const int*, void *, const int*);
void zswap_( const int*, void *, const int*, void *, const int*);
void zdotc_sub_( const int*, const void *, const int*, const void *, const int*, void *);
void zdotu_sub_( const int*, const void *, const int*, const void *, const int*, void *);
void zdscal_( const int*, const double *, void *, const int*);
void zscal_( const int*, const void *, void *, const int*);
void dznrm2_sub_( const int*, const void *, const int*, double *);
void dzasum_sub_( const int*, const void *, const int*, double *);
void izamax_sub_( const int*, const void *, const int*, const int*);
/***********/
/* Level 2 */
/***********/
/* Single Precision */
void sgemv_(char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void sgbmv_(char*, const int*, const int*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ssymv_(char*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ssbmv_(char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void sspmv_(char*, const int*, const float *, const float *, const float *, const int*, const float *, float *, const int*);
void strmv_( char*, char*, char*, const int*, const float *, const int*, float *, const int*);
void stbmv_( char*, char*, char*, const int*, const int*, const float *, const int*, float *, const int*);
void strsv_( char*, char*, char*, const int*, const float *, const int*, float *, const int*);
void stbsv_( char*, char*, char*, const int*, const int*, const float *, const int*, float *, const int*);
void stpmv_( char*, char*, char*, const int*, const float *, float *, const int*);
void stpsv_( char*, char*, char*, const int*, const float *, float *, const int*);
void sger_( const int*, const int*, const float *, const float *, const int*, const float *, const int*, float *, const int*);
void ssyr_(char*, const int*, const float *, const float *, const int*, float *, const int*);
void sspr_(char*, const int*, const float *, const float *, const int*, float *);
void sspr2_(char*, const int*, const float *, const float *, const int*, const float *, const int*, float *);
void ssyr2_(char*, const int*, const float *, const float *, const int*, const float *, const int*, float *, const int*);
/* Double Precision */
void dgemv_(char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dgbmv_(char*, const int*, const int*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dsymv_(char*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dsbmv_(char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dspmv_(char*, const int*, const double *, const double *, const double *, const int*, const double *, double *, const int*);
void dtrmv_( char*, char*, char*, const int*, const double *, const int*, double *, const int*);
void dtbmv_( char*, char*, char*, const int*, const int*, const double *, const int*, double *, const int*);
void dtrsv_( char*, char*, char*, const int*, const double *, const int*, double *, const int*);
void dtbsv_( char*, char*, char*, const int*, const int*, const double *, const int*, double *, const int*);
void dtpmv_( char*, char*, char*, const int*, const double *, double *, const int*);
void dtpsv_( char*, char*, char*, const int*, const double *, double *, const int*);
void dger_( const int*, const int*, const double *, const double *, const int*, const double *, const int*, double *, const int*);
void dsyr_(char*, const int*, const double *, const double *, const int*, double *, const int*);
void dspr_(char*, const int*, const double *, const double *, const int*, double *);
void dspr2_(char*, const int*, const double *, const double *, const int*, const double *, const int*, double *);
void dsyr2_(char*, const int*, const double *, const double *, const int*, const double *, const int*, double *, const int*);
/* Single Complex Precision */
void cgemv_(char*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void cgbmv_(char*, const int*, const int*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void chemv_(char*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void chbmv_(char*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void chpmv_(char*, const int*, const void *, const void *, const void *, const int*, const void *, void *, const int*);
void ctrmv_( char*, char*, char*, const int*, const void *, const int*, void *, const int*);
void ctbmv_( char*, char*, char*, const int*, const int*, const void *, const int*, void *, const int*);
void ctpmv_( char*, char*, char*, const int*, const void *, void *, const int*);
void ctrsv_( char*, char*, char*, const int*, const void *, const int*, void *, const int*);
void ctbsv_( char*, char*, char*, const int*, const int*, const void *, const int*, void *, const int*);
void ctpsv_( char*, char*, char*, const int*, const void *, void *,const int*);
void cgerc_( const int*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*);
void cgeru_( const int*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*);
void cher_(char*, const int*, const float *, const void *, const int*, void *, const int*);
void cher2_(char*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*);
void chpr_(char*, const int*, const float *, const void *, const int*, void *);
void chpr2_(char*, const int*, const float *, const void *, const int*, const void *, const int*, void *);
/* Double Complex Precision */
void zgemv_(char*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void zgbmv_(char*, const int*, const int*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void zhemv_(char*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void zhbmv_(char*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void zhpmv_(char*, const int*, const void *, const void *, const void *, const int*, const void *, void *, const int*);
void ztrmv_( char*, char*, char*, const int*, const void *, const int*, void *, const int*);
void ztbmv_( char*, char*, char*, const int*, const int*, const void *, const int*, void *, const int*);
void ztpmv_( char*, char*, char*, const int*, const void *, void *, const int*);
void ztrsv_( char*, char*, char*, const int*, const void *, const int*, void *, const int*);
void ztbsv_( char*, char*, char*, const int*, const int*, const void *, const int*, void *, const int*);
void ztpsv_( char*, char*, char*, const int*, const void *, void *,const int*);
void zgerc_( const int*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*);
void zgeru_( const int*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*);
void zher_(char*, const int*, const double *, const void *, const int*, void *, const int*);
void zher2_(char*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*);
void zhpr_(char*, const int*, const double *, const void *, const int*, void *);
void zhpr2_(char*, const int*, const double *, const void *, const int*, const void *, const int*, void *);
/***********/
/* Level 3 */
/***********/
/* Single Precision */
void sgemm_(char*, char*, const int*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ssymm_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ssyrk_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, float *, const int*);
void ssyr2k_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void strmm_(char*, char*, char*, char*, const int*, const int*, const float *, const float *, const int*, float *, const int*);
void strsm_(char*, char*, char*, char*, const int*, const int*, const float *, const float *, const int*, float *, const int*);
/* Double Precision */
void dgemm_(char*, char*, const int*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dsymm_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dsyrk_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, double *, const int*);
void dsyr2k_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dtrmm_(char*, char*, char*, char*, const int*, const int*, const double *, const double *, const int*, double *, const int*);
void dtrsm_(char*, char*, char*, char*, const int*, const int*, const double *, const double *, const int*, double *, const int*);
/* Single Complex Precision */
void cgemm_(char*, char*, const int*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void csymm_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void chemm_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void csyrk_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, float *, const int*);
void cherk_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, float *, const int*);
void csyr2k_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void cher2k_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ctrmm_(char*, char*, char*, char*, const int*, const int*, const float *, const float *, const int*, float *, const int*);
void ctrsm_(char*, char*, char*, char*, const int*, const int*, const float *, const float *, const int*, float *, const int*);
/* Double Complex Precision */
void zgemm_(char*, char*, const int*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void zsymm_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void zhemm_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void zsyrk_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, double *, const int*);
void zherk_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, double *, const int*);
void zsyr2k_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void zher2k_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void ztrmm_(char*, char*, char*, char*, const int*, const int*, const double *, const double *, const int*, double *, const int*);
void ztrsm_(char*, char*, char*, char*, const int*, const int*, const double *, const double *, const int*, double *, const int*);
}
"""
@utils.memoize @utils.memoize
def ldflags(): def ldflags():
...@@ -819,68 +42,105 @@ def ldflags(): ...@@ -819,68 +42,105 @@ def ldflags():
#print "blas linking against", rval #print "blas linking against", rval
return rval return rval
def gemm_code(check_ab, a_init, b_init): class GemmRelated(Op):
mod = '%' """Base class for Gemm and Dot22
return """
const char * error_string = NULL; This class provides a kind of templated gemm Op.
"""
int type_num = _x->descr->type_num; def c_support_code(self):
int type_size = _x->descr->elsize; // in bytes #return cblas_header_text()
mod_str = """
npy_intp* Nx = _x->dimensions; #ifndef MOD
npy_intp* Ny = _y->dimensions; #define MOD %
npy_intp* Nz = _z->dimensions; #endif
"""
npy_intp* Sx = _x->strides; return blas_header_text() + mod_str
npy_intp* Sy = _y->strides; def c_headers(self):
npy_intp* Sz = _z->strides; # std.cout doesn't require the '%' symbol to print stuff...
# so it works much better with python's string-substitution stuff.
size_t sx_0, sx_1, sy_0, sy_1, sz_0, sz_1; return ['<iostream>']
def c_libraries(self):
return ldflags()
declare_NS = """
int unit = 0; int unit = 0;
if (_x->nd != 2) goto _dot_execute_fallback; int type_num = %(_x)s->descr->type_num;
if (_y->nd != 2) goto _dot_execute_fallback; int type_size = %(_x)s->descr->elsize; // in bytes
if (_z->nd != 2) goto _dot_execute_fallback;
npy_intp* Nx = %(_x)s->dimensions;
%(check_ab)s npy_intp* Ny = %(_y)s->dimensions;
npy_intp* Nz = 0; //%(_z)s->dimensions;
if ((_x->descr->type_num != PyArray_DOUBLE)
&& (_x->descr->type_num != PyArray_FLOAT)) npy_intp* Sx = %(_x)s->strides;
goto _dot_execute_fallback; npy_intp* Sy = %(_y)s->strides;
npy_intp* Sz = 0; //%(_z)s->strides;
if ((_y->descr->type_num != PyArray_DOUBLE)
&& (_y->descr->type_num != PyArray_FLOAT)) //strides for x, y, z in dimensions 0, 1
goto _dot_execute_fallback; int sx_0, sx_1, sy_0, sy_1, sz_0, sz_1;
"""
if ((_y->descr->type_num != PyArray_DOUBLE)
&& (_y->descr->type_num != PyArray_FLOAT)) #setup_z_Nz_Sz = None
goto _dot_execute_fallback;
check_xyz_rank2 = """
if ((_x->descr->type_num != _y->descr->type_num) if (%(_x)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(x) != 2"); %(fail)s;}
||(_x->descr->type_num != _z->descr->type_num)) if (%(_y)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(y) != 2"); %(fail)s;}
goto _dot_execute_fallback; if (%(_z)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(z) != 2"); %(fail)s;}
"""
check_xyz_double_or_float = """
if ((%(_x)s->descr->type_num != PyArray_DOUBLE)
&& (%(_x)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float"); %(fail)s;}
if ((%(_y)s->descr->type_num != PyArray_DOUBLE)
&& (%(_y)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float"); %(fail)s;}
if ((%(_z)s->descr->type_num != PyArray_DOUBLE)
&& (%(_z)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float"); %(fail)s;}
if ((%(_x)s->descr->type_num != %(_y)s->descr->type_num)
||(%(_x)s->descr->type_num != %(_z)s->descr->type_num))
{ PyErr_SetString(PyExc_NotImplementedError, "type(z), type(y), type(z) are not all the same"); %(fail)s; }
"""
#it is not necessary that a or b have the same type as x,y,z
check_ab_double_or_float = """
if ((%(_a)s->descr->type_num != PyArray_DOUBLE)
&& (%(_a)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(a) is not double or float"); %(fail)s;}
if ((%(_b)s->descr->type_num != PyArray_DOUBLE)
&& (%(_b)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(b) is not double or float"); %(fail)s;}
"""
check_dims_strides = """
if ((Nx[0] != Nz[0]) || (Nx[1] != Ny[0]) || (Ny[1] != Nz[1])) if ((Nx[0] != Nz[0]) || (Nx[1] != Ny[0]) || (Ny[1] != Nz[1]))
{ {
error_string = "Input dimensions do not agree"; PyErr_SetString(PyExc_ValueError, "Input dimensions do not agree");
goto _dot_execute_fail; %(fail)s;
} }
if ((Sx[0] < 1) || (Sx[1] < 1) || (Sx[0] %(mod)s type_size) || (Sx[1] %(mod)s type_size) if ((Sx[0] < 1) || (Sx[1] < 1) || (Sx[0] MOD type_size) || (Sx[1] MOD type_size)
|| (Sy[0] < 1) || (Sy[1] < 1) || (Sy[0] %(mod)s type_size) || (Sy[1] %(mod)s type_size) || (Sy[0] < 1) || (Sy[1] < 1) || (Sy[0] MOD type_size) || (Sy[1] MOD type_size)
|| (Sz[0] < 1) || (Sz[1] < 1) || (Sz[0] %(mod)s type_size) || (Sz[1] %(mod)s type_size)) || (Sz[0] < 1) || (Sz[1] < 1) || (Sz[0] MOD type_size) || (Sz[1] MOD type_size))
{ {
goto _dot_execute_fallback; PyErr_SetString(PyExc_ValueError, "stride is not multiple of element size"); %(fail)s;
} }
"""
encode_strides_in_unit = """
/* /*
encode the stride structure of _x,_y,_z into a single integer encode the stride structure of _x,_y,_z into a single integer
*/ */
unit |= ((Sx[1] == type_size) ? 0x0 : (Sx[0] == type_size) ? 0x1 : 0x2) << 0; unit |= ((Sx[1] == type_size) ? 0x0 : (Sx[0] == type_size) ? 0x1 : 0x2) << 8;
unit |= ((Sy[1] == type_size) ? 0x0 : (Sy[0] == type_size) ? 0x1 : 0x2) << 4; unit |= ((Sy[1] == type_size) ? 0x0 : (Sy[0] == type_size) ? 0x1 : 0x2) << 4;
unit |= ((Sz[1] == type_size) ? 0x0 : (Sz[0] == type_size) ? 0x1 : 0x2) << 8; unit |= ((Sz[1] == type_size) ? 0x0 : (Sz[0] == type_size) ? 0x1 : 0x2) << 0;
"""
compute_strides = """
/* create appropriate strides for malformed matrices that are row or column /* create appropriate strides for malformed matrices that are row or column
* vectors * vectors
*/ */
...@@ -890,100 +150,421 @@ def gemm_code(check_ab, a_init, b_init): ...@@ -890,100 +150,421 @@ def gemm_code(check_ab, a_init, b_init):
sy_1 = (Ny[1] > 1) ? Sy[1]/type_size : Ny[0]; sy_1 = (Ny[1] > 1) ? Sy[1]/type_size : Ny[0];
sz_0 = (Nz[0] > 1) ? Sz[0]/type_size : Nz[1]; sz_0 = (Nz[0] > 1) ? Sz[0]/type_size : Nz[1];
sz_1 = (Nz[1] > 1) ? Sz[1]/type_size : Nz[0]; sz_1 = (Nz[1] > 1) ? Sz[1]/type_size : Nz[0];
"""
begin_switch_typenum = """
switch (type_num) switch (type_num)
{ {
"""
case_float = """
case PyArray_FLOAT: case PyArray_FLOAT:
{ {
#define REAL float """
float a = %(a_init)s;
float b = %(b_init)s; #case_float_ab_constants = None
float* x = (float*)PyArray_DATA(_x); case_float_gemm = """
float* y = (float*)PyArray_DATA(_y); float* x = (float*)PyArray_DATA(%(_x)s);
float* z = (float*)PyArray_DATA(_z); float* y = (float*)PyArray_DATA(%(_y)s);
float* z = (float*)PyArray_DATA(%(_z)s);
char N = 'N';
char T = 'T';
int Nz0 = Nz[0], Nz1 = Nz[1], Nx1 = Nx[1];
//std::cerr << (unit/256) MOD 16 << (unit / 16) MOD 16 << unit MOD 16<< '\\n';
switch(unit) switch(unit)
{ {
case 0x000: cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_0); break; case 0x000: sgemm_(&N, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_0, &b, z, &sz_0); break;
case 0x001: cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_0); break; case 0x100: sgemm_(&N, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_1, &b, z, &sz_0); break;
case 0x010: cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_0); break; case 0x010: sgemm_(&T, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_0, &b, z, &sz_0); break;
case 0x011: cblas_sgemm(CblasRowMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_0); break; case 0x110: sgemm_(&T, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_1, &b, z, &sz_0); break;
case 0x100: cblas_sgemm(CblasColMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_1); break; case 0x001: sgemm_(&T, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_0, &b, z, &sz_1); break;
case 0x101: cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_1); break; case 0x101: sgemm_(&N, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_0, &b, z, &sz_1); break;
case 0x110: cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_1); break; case 0x011: sgemm_(&T, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_1, &b, z, &sz_1); break;
case 0x111: cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_1); break; case 0x111: sgemm_(&N, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_1, &b, z, &sz_1); break;
default: goto _dot_execute_fallback; default: PyErr_SetString(PyExc_ValueError, "some matrix has no unit stride"); %(fail)s;
}; };
#undef REAL """
case_double = """
} }
break; break;
case PyArray_DOUBLE: case PyArray_DOUBLE:
{ {
#define REAL double """
double a = %(a_init)s;
double b = %(b_init)s; #case_double_ab_constants = None
double* x = (double*)PyArray_DATA(_x); case_double_gemm = """
double* y = (double*)PyArray_DATA(_y); double* x = (double*)PyArray_DATA(%(_x)s);
double* z = (double*)PyArray_DATA(_z); double* y = (double*)PyArray_DATA(%(_y)s);
double* z = (double*)PyArray_DATA(%(_z)s);
char N = 'N';
char T = 'T';
int Nz0 = Nz[0], Nz1 = Nz[1], Nx1 = Nx[1];
//std::cerr << (unit/256) MOD 16 << (unit / 16) MOD 16 << unit MOD 16<< '\\n';
switch(unit) switch(unit)
{ {
case 0x000: cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_0); break; case 0x000: dgemm_(&N, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_0, &b, z, &sz_0); break;
case 0x001: cblas_dgemm(CblasRowMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_0); break; case 0x100: dgemm_(&N, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_1, &b, z, &sz_0); break;
case 0x010: cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_0); break; case 0x010: dgemm_(&T, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_0, &b, z, &sz_0); break;
case 0x011: cblas_dgemm(CblasRowMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_0); break; case 0x110: dgemm_(&T, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_1, &b, z, &sz_0); break;
case 0x100: cblas_dgemm(CblasColMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_1); break; case 0x001: dgemm_(&T, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_0, &b, z, &sz_1); break;
case 0x101: cblas_dgemm(CblasColMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_1); break; case 0x101: dgemm_(&N, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_0, &b, z, &sz_1); break;
case 0x110: cblas_dgemm(CblasColMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_1); break; case 0x011: dgemm_(&T, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_1, &b, z, &sz_1); break;
case 0x111: cblas_dgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_1); break; case 0x111: dgemm_(&N, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_1, &b, z, &sz_1); break;
default: goto _dot_execute_fallback; default: PyErr_SetString(PyExc_ValueError, "some matrix has no unit stride"); %(fail)s;
}; };
#undef REAL """
end_switch_typenum = """
} }
break; break;
} }
"""
def build_gemm_call(self):
return reduce(str.__add__, (
self.declare_NS,
self.setup_z_Nz_Sz,
self.check_xyz_rank2,
self.check_xyz_double_or_float,
self.check_ab_double_or_float,
self.check_dims_strides,
self.encode_strides_in_unit,
self.compute_strides,
self.begin_switch_typenum,
self.case_float,
self.case_float_ab_constants,
self.case_float_gemm,
self.case_double,
self.case_double_ab_constants,
self.case_double_gemm,
self.end_switch_typenum), '')
class Gemm(GemmRelated):
"""In-place version of matrix-matrix multiplication (with accumulation):
When a and b are scalars and x, y, and z are matrices, then
gemm(z,a,x,y,b)
is similar to
return 0; //success! b*z + a*dot(x,y)
_dot_execute_fallback: The difference between the two is that the top form is destructive on z,
PyErr_SetString(PyExc_NotImplementedError, whereas the bottom form is not. Gemm works in-place on the storage
"dot->execute() fallback"); associated with z, and the L{Result} returned by Gemm has a storage that
return -1; will be aliased to the storage of the z argument. Because of this in-place
computation, an L{Apply} of this op will destroy the L{Result} z on
_dot_execute_fail: which it operates. (See L{DestructiveOps} for an explanation of what
if (error_string == NULL) destroying means in the context of theano graphs. See L{BlasLapackSupport} for
PyErr_SetString(PyExc_ValueError, more optimized linear algebra operations.)
"dot->execute() cant run on these inputs");
return -1; """
E_rank = 'gemm only works for rank 2'
/* v 1 */ E_scalar = 'gemm requires scalar argument'
""" % locals() E_z_uniq = 'argument z aliased to x or y'
destroy_map = {0: [0]}
# currently unused, preferring the fallback method (throwing def make_node(self, *inputs):
# NotImplementedError) for when gemm won't work. inputs = map(as_tensor, inputs)
_templated_memaligned_gemm = """ if len(inputs) != 5:
template <typename Ta, typename Tx, typename Ty, typename Tb, typename Tz> raise TypeError("Wrong number of inputs for %s (expected 5, got %s)" % (self, len(inputs)))
int general_gemm(int zM, int zN, int xN,. z, a, x, y, b = inputs
Ta a, zr, xr, yr = [set(view_roots(i)) for i in z,x,y]
Tx * x, int xm, int xn, if zr.intersection(xr):
Tx * y, int ym, int yn, raise ValueError(Gemm.E_z_uniq, (z, x))
Tb b, if zr.intersection(yr):
Tz * z, int zm, int zn) raise ValueError(Gemm.E_z_uniq, (z, y))
{ bz, ba, bx, by, bb = [r.type.broadcastable for r in inputs]
for (int i = 0; i < zM; ++i) if bz != (False,False): raise ValueError(Gemm.E_rank, bz)
{ if bx != (False,False): raise ValueError(Gemm.E_rank, bx)
for (int j = 0; j < zN; ++j) if by != (False,False): raise ValueError(Gemm.E_rank, by)
if len(ba): raise ValueError(Gemm.E_scalar, ba)
if len(bb): raise ValueError(Gemm.E_scalar, bb)
output = z.type()
return Apply(self, inputs, [output])
def perform(self, node, (z, a, x, y, b), (zout, )):
assert a.shape == ()
assert b.shape == ()
if z.shape == ():
z.itemset(z*a + b*numpy.dot(x,y))
zout[0] = z
else:
if b == 0.0:
if a == 1.0:
z[:] = numpy.dot(x,y)
elif a == -1.0:
z[:] = -numpy.dot(x,y)
else:
z[:] = a * numpy.dot(x,y)
elif b == 1.0:
if a == 1.0:
z += numpy.dot(x,y)
elif a == -1.0:
z -= numpy.dot(x,y)
else:
z += a * numpy.dot(x,y)
else:
z *= b
z += a * numpy.dot(x,y)
zout[0] = z
setup_z_Nz_Sz = """
if (%(_zout)s != %(_z)s)
{ {
Tz zij = 0.0; if (%(_zout)s)
for (int k = 0; k < xN; ++k)
{ {
zij += x[i*xm+k*xn] * y[k*ym+j*yn]; Py_DECREF(%(_zout)s);
} }
z[i * zm + j * zn] *= b; %(_zout)s = %(_z)s;
z[i * zm + j * zn] += a * zij; Py_INCREF(%(_zout)s);
} }
} Nz = %(_z)s->dimensions;
} Sz = %(_z)s->strides;
""" """
case_float_ab_constants = """
#define REAL float
float a = (%(_a)s->descr->type_num == PyArray_FLOAT)
? (REAL)(((float*)%(_a)s->data)[0])
: (REAL)(((double*)%(_a)s->data)[0]);
float b = (%(_b)s->descr->type_num == PyArray_FLOAT) ?
(REAL)(((float*)%(_b)s->data)[0])
: (REAL)(((double*)%(_b)s->data)[0]);
#undef REAL
"""
case_double_ab_constants = """
#define REAL double
double a = (%(_a)s->descr->type_num == PyArray_FLOAT)
? (REAL)(((float*)%(_a)s->data)[0])
: (REAL)(((double*)%(_a)s->data)[0]);
double b = (%(_b)s->descr->type_num == PyArray_FLOAT) ?
(REAL)(((float*)%(_b)s->data)[0])
: (REAL)(((double*)%(_b)s->data)[0]);
#undef REAL
"""
def c_code(self, node, name, (_z, _a, _x, _y, _b), (_zout, ), sub):
full_code = self.build_gemm_call() % dict(locals(), **sub)
return full_code
gemm = Gemm()
pprint.assign(gemm, FunctionPrinter('gemm'))
class Dot22(GemmRelated):
"""Compute a matrix-matrix product.
This is a specialization of the more general Dot()
"""
def make_node(self, x, y):
assert _is_real_matrix(x)
assert y.type == x.type #makes sure y is a matrix
bz = [False, False]
outputs = [T.tensor(x.type.dtype, bz)]
return Apply(self, [x,y], outputs)
def perform(self, node, (x, y), (z, )):
try:
z[0] = numpy.asarray(numpy.dot(x, y))
except ValueError, e:
# The error raised by numpy has no shape information, we mean to add that
e.args = e.args + (x.shape, y.shape)
raise
def __str__(self):
return "_dot22"
setup_z_Nz_Sz = """
if ((NULL == %(_z)s)
|| (%(_z)s->dimensions[0] != %(_x)s->dimensions[0])
|| (%(_z)s->dimensions[1] != %(_y)s->dimensions[1]))
{
if (NULL != %(_z)s) Py_XDECREF(%(_z)s);
npy_intp dims[2];
dims[0] = %(_x)s->dimensions[0];
dims[1] = %(_y)s->dimensions[1];
%(_z)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, type_num_%(_x)s);
if(!%(_z)s) {
PyErr_SetString(PyExc_MemoryError, "failed to alloc dot22 output");
%(fail)s
}
}
Nz = %(_z)s->dimensions;
Sz = %(_z)s->strides;
"""
check_ab_double_or_float = ""
case_float_ab_constants = """
float a = 1.0;
float b = 0.0;
"""
case_double_ab_constants = """
double a = 1.0;
double b = 0.0;
"""
def c_code(self, node, name, (_x, _y), (_z, ), sub):
full_code = self.build_gemm_call() % dict(locals(), **sub)
return full_code
_dot22 = Dot22()
@local_optimizer([T.dot])
def local_dot_to_dot22(node):
if node.op == T.dot:
x,y = node.inputs
if _is_real_matrix(x) and y.type == x.type:
return [_dot22(*node.inputs)]
else:
return False
if JOSEPHS_BUG_SOLVED:
register_specialize(local_dot_to_dot22)
def _is_a(node, op, maxclients=None):
return node.owner \
and node.owner.op == op \
and len(node.clients) <= maxclients if maxclients is not None else True
def _as_scalar(res):
"""Return None or a TensorResult whose type is in T.float_scalar_types"""
if res.owner and isinstance(res.owner.op, T.DimShuffle):
return _as_scalar(res.owner.inputs[0])
elif res.type in T.float_scalar_types:
return res
elif isinstance(res, T.Constant) and res.data.size == 1:
return res.data.flatten()[0]
else:
return None
def _is_real_matrix(res):
return res.type in T.float_matrix_types \
and res.broadcastable[0] == False \
and res.broadcastable[1] == False #cope with tuple vs. list
def _as_isolated_scalar_times_matrix(res):
if _is_a(res, T.mul, 1):
if len(res.owner.inputs) == 2:
L, R = res.owner.inputs
sL = _as_scalar(L)
sR = _as_scalar(R)
if (sL is not None) and _is_real_matrix(R):
return (sL, R)
if (sR is not None) and _is_real_matrix(L):
return (sR, L)
else:
scalars = []
matrices = []
for input in res.owner.inputs:
scalar_input = _as_scalar(input)
if scalar_input is not None:
scalars.append(scalar_input)
elif _is_real_matrix(input):
matrices.append(input)
else:
return None
if len(matrices) == 1:
rval = (T.mul(*scalars), matrices[0])
return rval
def beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True):
#print 'BETA L + ALPHA M', beta, L, alpha, M, recurse_flip
#EXPRESSION: (beta * L) + (alpha * M)
if _is_a(M, _dot22, 1):
Ml, Mr = M.owner.inputs
rval = [gemm(L, alpha, Ml, Mr, beta)]
return rval
if _is_a(M, gemm, 1):
#EXPRESSION: (beta * L) + (alpha * (gemm(G, a, u, v, b)))
#EXPRESSION: (beta * L) + alpha * (b * G) + alpha * a * dot(u, v)
G, a, u, v, b = M.owner.inputs
#print 'GEMM', G, L
if _is_a(G, _dot22, 1):
#EXPRESSION: (beta * L) + (alpha * (gemm(dot(x,y), a, u, v, b)))
x, y = G.owner.inputs
#EXPRESSION: (beta * L) + (alpha * ((b*dot(x,y) + (a * dot(u, v)))))
#EXPRESSION: (beta * L) + (alpha*b*dot(x,y)) + (alpha * a * dot(u, v))
#print 'GEMM 1', G, L
rval = [gemm(gemm(L, alpha * b, x, y, beta), alpha * a, u, v, 1.0)]
return rval
elif G is L:
#EXPRESSION: (beta * L) + (alpha*b*L) + (alpha * a * dot(u, v))
rval = [gemm(L, alpha*a, u, v, alpha * b + beta)]
#print 'GEMM 2', rval
return rval
elif 1.0 != alpha:
#at the very least, move the alpha inside the gemm
rval = [beta * L + gemm(G, alpha * a, u, v, alpha * b)]
#print 'GEMM 3', G, L
return rval
if recurse_flip:
return beta_L_plus_alpha_M(alpha, M, beta, L, recurse_flip = False)
else:
return False
@local_optimizer([T.sub])
def local_sub_to_gemm(node):
if node.op == T.sub:
L, R = node.inputs
if not _is_real_matrix(L):
return False
if not _is_real_matrix(R):
return False
tmp = _as_isolated_scalar_times_matrix(L)
try:
sL, mL = tmp
except:
sL, mL = 1.0, L
tmp = _as_isolated_scalar_times_matrix(R)
try:
sR, mR = tmp
except:
sR, mR = 1.0, R
rval = beta_L_plus_alpha_M(sL, mL, -sR, mR)
return rval
return False
if JOSEPHS_BUG_SOLVED:
register_specialize(local_sub_to_gemm)
@local_optimizer([T.add])
def local_add_to_gemm(node):
"""This is a massive beast for recognizing all the ways that a subtraction could be
replaced by a GEMM
It depends on `local_transposed_dot` to canonicalize the graph a bit by swapping
dot(a,b).T -> dot(b.T, a.T)
"""
if node.op == T.add:
sM_list = []
for input in node.inputs:
tmp = _as_isolated_scalar_times_matrix(input)
if tmp:
sM_list.append(tmp)
elif _is_real_matrix(input):
sM_list.append((1.0, input))
if len(sM_list) == 2:
sL, mL = sM_list[0]
sR, mR = sM_list[1]
return beta_L_plus_alpha_M(sL, mL, sR, mR)
else:
for i in xrange(len(sM_list) - 1):
for j in xrange(i+1, len(sM_list)):
sL, mL = sM_list[i]
sR, mR = sM_list[j]
rval = beta_L_plus_alpha_M(sL, mL, sR, mR)
if rval:
assert len(rval) == 1
inputs_without_ij = \
[input for k, input in enumerate(node.inputs) if k not in (i,j)]
return [T.add( *(inputs_without_ij + rval))]
return False
if JOSEPHS_BUG_SOLVED:
register_specialize(local_add_to_gemm)
""" Header text for the C and Fortran BLAS interfaces.
There is no standard name or location for this header, so we just insert it ourselves into the C code
"""
def cblas_header_text():
"""C header for the cblas interface."""
return """
//#include <stddef.h>
#undef __BEGIN_DECLS
#undef __END_DECLS
#ifdef __cplusplus
#define __BEGIN_DECLS extern "C" {
#define __END_DECLS }
#else
#define __BEGIN_DECLS /* empty */
#define __END_DECLS /* empty */
#endif
__BEGIN_DECLS
#define MOD %
/*
* Enumerated and derived types
*/
#define CBLAS_INDEX size_t /* this may vary between platforms */
enum CBLAS_ORDER {CblasRowMajor=101, CblasColMajor=102};
enum CBLAS_TRANSPOSE {CblasNoTrans=111, CblasTrans=112, CblasConjTrans=113};
enum CBLAS_UPLO {CblasUpper=121, CblasLower=122};
enum CBLAS_DIAG {CblasNonUnit=131, CblasUnit=132};
enum CBLAS_SIDE {CblasLeft=141, CblasRight=142};
float cblas_sdsdot(const int N, const float alpha, const float *X,
const int incX, const float *Y, const int incY);
double cblas_dsdot(const int N, const float *X, const int incX, const float *Y,
const int incY);
float cblas_sdot(const int N, const float *X, const int incX,
const float *Y, const int incY);
double cblas_ddot(const int N, const double *X, const int incX,
const double *Y, const int incY);
/*
* Functions having prefixes Z and C only
*/
void cblas_cdotu_sub(const int N, const void *X, const int incX,
const void *Y, const int incY, void *dotu);
void cblas_cdotc_sub(const int N, const void *X, const int incX,
const void *Y, const int incY, void *dotc);
void cblas_zdotu_sub(const int N, const void *X, const int incX,
const void *Y, const int incY, void *dotu);
void cblas_zdotc_sub(const int N, const void *X, const int incX,
const void *Y, const int incY, void *dotc);
/*
* Functions having prefixes S D SC DZ
*/
float cblas_snrm2(const int N, const float *X, const int incX);
float cblas_sasum(const int N, const float *X, const int incX);
double cblas_dnrm2(const int N, const double *X, const int incX);
double cblas_dasum(const int N, const double *X, const int incX);
float cblas_scnrm2(const int N, const void *X, const int incX);
float cblas_scasum(const int N, const void *X, const int incX);
double cblas_dznrm2(const int N, const void *X, const int incX);
double cblas_dzasum(const int N, const void *X, const int incX);
/*
* Functions having standard 4 prefixes (S D C Z)
*/
CBLAS_INDEX cblas_isamax(const int N, const float *X, const int incX);
CBLAS_INDEX cblas_idamax(const int N, const double *X, const int incX);
CBLAS_INDEX cblas_icamax(const int N, const void *X, const int incX);
CBLAS_INDEX cblas_izamax(const int N, const void *X, const int incX);
/*
* ===========================================================================
* Prototypes for level 1 BLAS routines
* ===========================================================================
*/
/*
* Routines with standard 4 prefixes (s, d, c, z)
*/
void cblas_sswap(const int N, float *X, const int incX,
float *Y, const int incY);
void cblas_scopy(const int N, const float *X, const int incX,
float *Y, const int incY);
void cblas_saxpy(const int N, const float alpha, const float *X,
const int incX, float *Y, const int incY);
void cblas_dswap(const int N, double *X, const int incX,
double *Y, const int incY);
void cblas_dcopy(const int N, const double *X, const int incX,
double *Y, const int incY);
void cblas_daxpy(const int N, const double alpha, const double *X,
const int incX, double *Y, const int incY);
void cblas_cswap(const int N, void *X, const int incX,
void *Y, const int incY);
void cblas_ccopy(const int N, const void *X, const int incX,
void *Y, const int incY);
void cblas_caxpy(const int N, const void *alpha, const void *X,
const int incX, void *Y, const int incY);
void cblas_zswap(const int N, void *X, const int incX,
void *Y, const int incY);
void cblas_zcopy(const int N, const void *X, const int incX,
void *Y, const int incY);
void cblas_zaxpy(const int N, const void *alpha, const void *X,
const int incX, void *Y, const int incY);
/*
* Routines with S and D prefix only
*/
void cblas_srotg(float *a, float *b, float *c, float *s);
void cblas_srotmg(float *d1, float *d2, float *b1, const float b2, float *P);
void cblas_srot(const int N, float *X, const int incX,
float *Y, const int incY, const float c, const float s);
void cblas_srotm(const int N, float *X, const int incX,
float *Y, const int incY, const float *P);
void cblas_drotg(double *a, double *b, double *c, double *s);
void cblas_drotmg(double *d1, double *d2, double *b1, const double b2, double *P);
void cblas_drot(const int N, double *X, const int incX,
double *Y, const int incY, const double c, const double s);
void cblas_drotm(const int N, double *X, const int incX,
double *Y, const int incY, const double *P);
/*
* Routines with S D C Z CS and ZD prefixes
*/
void cblas_sscal(const int N, const float alpha, float *X, const int incX);
void cblas_dscal(const int N, const double alpha, double *X, const int incX);
void cblas_cscal(const int N, const void *alpha, void *X, const int incX);
void cblas_zscal(const int N, const void *alpha, void *X, const int incX);
void cblas_csscal(const int N, const float alpha, void *X, const int incX);
void cblas_zdscal(const int N, const double alpha, void *X, const int incX);
/*
* ===========================================================================
* Prototypes for level 2 BLAS
* ===========================================================================
*/
/*
* Routines with standard 4 prefixes (S, D, C, Z)
*/
void cblas_sgemv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const float alpha, const float *A, const int lda,
const float *X, const int incX, const float beta,
float *Y, const int incY);
void cblas_sgbmv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const int KL, const int KU, const float alpha,
const float *A, const int lda, const float *X,
const int incX, const float beta, float *Y, const int incY);
void cblas_strmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const float *A, const int lda,
float *X, const int incX);
void cblas_stbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const float *A, const int lda,
float *X, const int incX);
void cblas_stpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const float *Ap, float *X, const int incX);
void cblas_strsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const float *A, const int lda, float *X,
const int incX);
void cblas_stbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const float *A, const int lda,
float *X, const int incX);
void cblas_stpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const float *Ap, float *X, const int incX);
void cblas_dgemv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const double alpha, const double *A, const int lda,
const double *X, const int incX, const double beta,
double *Y, const int incY);
void cblas_dgbmv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const int KL, const int KU, const double alpha,
const double *A, const int lda, const double *X,
const int incX, const double beta, double *Y, const int incY);
void cblas_dtrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const double *A, const int lda,
double *X, const int incX);
void cblas_dtbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const double *A, const int lda,
double *X, const int incX);
void cblas_dtpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const double *Ap, double *X, const int incX);
void cblas_dtrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const double *A, const int lda, double *X,
const int incX);
void cblas_dtbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const double *A, const int lda,
double *X, const int incX);
void cblas_dtpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const double *Ap, double *X, const int incX);
void cblas_cgemv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *X, const int incX, const void *beta,
void *Y, const int incY);
void cblas_cgbmv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const int KL, const int KU, const void *alpha,
const void *A, const int lda, const void *X,
const int incX, const void *beta, void *Y, const int incY);
void cblas_ctrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *A, const int lda,
void *X, const int incX);
void cblas_ctbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const void *A, const int lda,
void *X, const int incX);
void cblas_ctpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *Ap, void *X, const int incX);
void cblas_ctrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *A, const int lda, void *X,
const int incX);
void cblas_ctbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const void *A, const int lda,
void *X, const int incX);
void cblas_ctpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *Ap, void *X, const int incX);
void cblas_zgemv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *X, const int incX, const void *beta,
void *Y, const int incY);
void cblas_zgbmv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const int KL, const int KU, const void *alpha,
const void *A, const int lda, const void *X,
const int incX, const void *beta, void *Y, const int incY);
void cblas_ztrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *A, const int lda,
void *X, const int incX);
void cblas_ztbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const void *A, const int lda,
void *X, const int incX);
void cblas_ztpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *Ap, void *X, const int incX);
void cblas_ztrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *A, const int lda, void *X,
const int incX);
void cblas_ztbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const void *A, const int lda,
void *X, const int incX);
void cblas_ztpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *Ap, void *X, const int incX);
/*
* Routines with S and D prefixes only
*/
void cblas_ssymv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *A,
const int lda, const float *X, const int incX,
const float beta, float *Y, const int incY);
void cblas_ssbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const int K, const float alpha, const float *A,
const int lda, const float *X, const int incX,
const float beta, float *Y, const int incY);
void cblas_sspmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *Ap,
const float *X, const int incX,
const float beta, float *Y, const int incY);
void cblas_sger(const enum CBLAS_ORDER order, const int M, const int N,
const float alpha, const float *X, const int incX,
const float *Y, const int incY, float *A, const int lda);
void cblas_ssyr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *X,
const int incX, float *A, const int lda);
void cblas_sspr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *X,
const int incX, float *Ap);
void cblas_ssyr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *X,
const int incX, const float *Y, const int incY, float *A,
const int lda);
void cblas_sspr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *X,
const int incX, const float *Y, const int incY, float *A);
void cblas_dsymv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *A,
const int lda, const double *X, const int incX,
const double beta, double *Y, const int incY);
void cblas_dsbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const int K, const double alpha, const double *A,
const int lda, const double *X, const int incX,
const double beta, double *Y, const int incY);
void cblas_dspmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *Ap,
const double *X, const int incX,
const double beta, double *Y, const int incY);
void cblas_dger(const enum CBLAS_ORDER order, const int M, const int N,
const double alpha, const double *X, const int incX,
const double *Y, const int incY, double *A, const int lda);
void cblas_dsyr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *X,
const int incX, double *A, const int lda);
void cblas_dspr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *X,
const int incX, double *Ap);
void cblas_dsyr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *X,
const int incX, const double *Y, const int incY, double *A,
const int lda);
void cblas_dspr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *X,
const int incX, const double *Y, const int incY, double *A);
/*
* Routines with C and Z prefixes only
*/
void cblas_chemv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const void *alpha, const void *A,
const int lda, const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_chbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const int K, const void *alpha, const void *A,
const int lda, const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_chpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const void *alpha, const void *Ap,
const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_cgeru(const enum CBLAS_ORDER order, const int M, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_cgerc(const enum CBLAS_ORDER order, const int M, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_cher(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const void *X, const int incX,
void *A, const int lda);
void cblas_chpr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const void *X,
const int incX, void *A);
void cblas_cher2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_chpr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *Ap);
void cblas_zhemv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const void *alpha, const void *A,
const int lda, const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_zhbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const int K, const void *alpha, const void *A,
const int lda, const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_zhpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const void *alpha, const void *Ap,
const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_zgeru(const enum CBLAS_ORDER order, const int M, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_zgerc(const enum CBLAS_ORDER order, const int M, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_zher(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const void *X, const int incX,
void *A, const int lda);
void cblas_zhpr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const void *X,
const int incX, void *A);
void cblas_zher2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_zhpr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *Ap);
/*
* ===========================================================================
* Prototypes for level 3 BLAS
* ===========================================================================
*/
/*
* Routines with standard 4 prefixes (S, D, C, Z)
*/
void cblas_sgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const float alpha, const float *A,
const int lda, const float *B, const int ldb,
const float beta, float *C, const int ldc);
void cblas_ssymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const float alpha, const float *A, const int lda,
const float *B, const int ldb, const float beta,
float *C, const int ldc);
void cblas_ssyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const float alpha, const float *A, const int lda,
const float beta, float *C, const int ldc);
void cblas_ssyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const float alpha, const float *A, const int lda,
const float *B, const int ldb, const float beta,
float *C, const int ldc);
void cblas_strmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const float alpha, const float *A, const int lda,
float *B, const int ldb);
void cblas_strsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const float alpha, const float *A, const int lda,
float *B, const int ldb);
void cblas_dgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const double alpha, const double *A,
const int lda, const double *B, const int ldb,
const double beta, double *C, const int ldc);
void cblas_dsymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const double alpha, const double *A, const int lda,
const double *B, const int ldb, const double beta,
double *C, const int ldc);
void cblas_dsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const double alpha, const double *A, const int lda,
const double beta, double *C, const int ldc);
void cblas_dsyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const double alpha, const double *A, const int lda,
const double *B, const int ldb, const double beta,
double *C, const int ldc);
void cblas_dtrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const double alpha, const double *A, const int lda,
double *B, const int ldb);
void cblas_dtrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const double alpha, const double *A, const int lda,
double *B, const int ldb);
void cblas_cgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const void *alpha, const void *A,
const int lda, const void *B, const int ldb,
const void *beta, void *C, const int ldc);
void cblas_csymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_csyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *beta, void *C, const int ldc);
void cblas_csyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_ctrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const void *alpha, const void *A, const int lda,
void *B, const int ldb);
void cblas_ctrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const void *alpha, const void *A, const int lda,
void *B, const int ldb);
void cblas_zgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const void *alpha, const void *A,
const int lda, const void *B, const int ldb,
const void *beta, void *C, const int ldc);
void cblas_zsymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_zsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *beta, void *C, const int ldc);
void cblas_zsyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_ztrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const void *alpha, const void *A, const int lda,
void *B, const int ldb);
void cblas_ztrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const void *alpha, const void *A, const int lda,
void *B, const int ldb);
/*
* Routines with prefixes C and Z only
*/
void cblas_chemm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_cherk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const float alpha, const void *A, const int lda,
const float beta, void *C, const int ldc);
void cblas_cher2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const float beta,
void *C, const int ldc);
void cblas_zhemm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_zherk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const double alpha, const void *A, const int lda,
const double beta, void *C, const int ldc);
void cblas_zher2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const double beta,
void *C, const int ldc);
void cblas_xerbla(int p, const char *rout, const char *form, ...);
__END_DECLS
"""
def blas_header_text():
"""C header for the fortran blas interface"""
return """
extern "C"
{
void xerbla_(char*, void *);
/***********/
/* Level 1 */
/***********/
/* Single Precision */
void srot_(const int*, float *, const int*, float *, const int*, const float *, const float *);
void srotg_(float *,float *,float *,float *);
void srotm_( const int*, float *, const int*, float *, const int*, const float *);
void srotmg_(float *,float *,float *,const float *, float *);
void sswap_( const int*, float *, const int*, float *, const int*);
void scopy_( const int*, const float *, const int*, float *, const int*);
void saxpy_( const int*, const float *, const float *, const int*, float *, const int*);
void sdot_sub_(const int*, const float *, const int*, const float *, const int*, float *);
void sdsdot_sub_( const int*, const float *, const float *, const int*, const float *, const int*, float *);
void sscal_( const int*, const float *, float *, const int*);
void snrm2_sub_( const int*, const float *, const int*, float *);
void sasum_sub_( const int*, const float *, const int*, float *);
void isamax_sub_( const int*, const float * , const int*, const int*);
/* Double Precision */
void drot_(const int*, double *, const int*, double *, const int*, const double *, const double *);
void drotg_(double *,double *,double *,double *);
void drotm_( const int*, double *, const int*, double *, const int*, const double *);
void drotmg_(double *,double *,double *,const double *, double *);
void dswap_( const int*, double *, const int*, double *, const int*);
void dcopy_( const int*, const double *, const int*, double *, const int*);
void daxpy_( const int*, const double *, const double *, const int*, double *, const int*);
void dswap_( const int*, double *, const int*, double *, const int*);
void dsdot_sub_(const int*, const float *, const int*, const float *, const int*, double *);
void ddot_sub_( const int*, const double *, const int*, const double *, const int*, double *);
void dscal_( const int*, const double *, double *, const int*);
void dnrm2_sub_( const int*, const double *, const int*, double *);
void dasum_sub_( const int*, const double *, const int*, double *);
void idamax_sub_( const int*, const double * , const int*, const int*);
/* Single Complex Precision */
void cswap_( const int*, void *, const int*, void *, const int*);
void ccopy_( const int*, const void *, const int*, void *, const int*);
void caxpy_( const int*, const void *, const void *, const int*, void *, const int*);
void cswap_( const int*, void *, const int*, void *, const int*);
void cdotc_sub_( const int*, const void *, const int*, const void *, const int*, void *);
void cdotu_sub_( const int*, const void *, const int*, const void *, const int*, void *);
void cscal_( const int*, const void *, void *, const int*);
void icamax_sub_( const int*, const void *, const int*, const int*);
void csscal_( const int*, const float *, void *, const int*);
void scnrm2_sub_( const int*, const void *, const int*, float *);
void scasum_sub_( const int*, const void *, const int*, float *);
/* Double Complex Precision */
void zswap_( const int*, void *, const int*, void *, const int*);
void zcopy_( const int*, const void *, const int*, void *, const int*);
void zaxpy_( const int*, const void *, const void *, const int*, void *, const int*);
void zswap_( const int*, void *, const int*, void *, const int*);
void zdotc_sub_( const int*, const void *, const int*, const void *, const int*, void *);
void zdotu_sub_( const int*, const void *, const int*, const void *, const int*, void *);
void zdscal_( const int*, const double *, void *, const int*);
void zscal_( const int*, const void *, void *, const int*);
void dznrm2_sub_( const int*, const void *, const int*, double *);
void dzasum_sub_( const int*, const void *, const int*, double *);
void izamax_sub_( const int*, const void *, const int*, const int*);
/***********/
/* Level 2 */
/***********/
/* Single Precision */
void sgemv_(char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void sgbmv_(char*, const int*, const int*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ssymv_(char*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ssbmv_(char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void sspmv_(char*, const int*, const float *, const float *, const float *, const int*, const float *, float *, const int*);
void strmv_( char*, char*, char*, const int*, const float *, const int*, float *, const int*);
void stbmv_( char*, char*, char*, const int*, const int*, const float *, const int*, float *, const int*);
void strsv_( char*, char*, char*, const int*, const float *, const int*, float *, const int*);
void stbsv_( char*, char*, char*, const int*, const int*, const float *, const int*, float *, const int*);
void stpmv_( char*, char*, char*, const int*, const float *, float *, const int*);
void stpsv_( char*, char*, char*, const int*, const float *, float *, const int*);
void sger_( const int*, const int*, const float *, const float *, const int*, const float *, const int*, float *, const int*);
void ssyr_(char*, const int*, const float *, const float *, const int*, float *, const int*);
void sspr_(char*, const int*, const float *, const float *, const int*, float *);
void sspr2_(char*, const int*, const float *, const float *, const int*, const float *, const int*, float *);
void ssyr2_(char*, const int*, const float *, const float *, const int*, const float *, const int*, float *, const int*);
/* Double Precision */
void dgemv_(char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dgbmv_(char*, const int*, const int*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dsymv_(char*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dsbmv_(char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dspmv_(char*, const int*, const double *, const double *, const double *, const int*, const double *, double *, const int*);
void dtrmv_( char*, char*, char*, const int*, const double *, const int*, double *, const int*);
void dtbmv_( char*, char*, char*, const int*, const int*, const double *, const int*, double *, const int*);
void dtrsv_( char*, char*, char*, const int*, const double *, const int*, double *, const int*);
void dtbsv_( char*, char*, char*, const int*, const int*, const double *, const int*, double *, const int*);
void dtpmv_( char*, char*, char*, const int*, const double *, double *, const int*);
void dtpsv_( char*, char*, char*, const int*, const double *, double *, const int*);
void dger_( const int*, const int*, const double *, const double *, const int*, const double *, const int*, double *, const int*);
void dsyr_(char*, const int*, const double *, const double *, const int*, double *, const int*);
void dspr_(char*, const int*, const double *, const double *, const int*, double *);
void dspr2_(char*, const int*, const double *, const double *, const int*, const double *, const int*, double *);
void dsyr2_(char*, const int*, const double *, const double *, const int*, const double *, const int*, double *, const int*);
/* Single Complex Precision */
void cgemv_(char*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void cgbmv_(char*, const int*, const int*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void chemv_(char*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void chbmv_(char*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void chpmv_(char*, const int*, const void *, const void *, const void *, const int*, const void *, void *, const int*);
void ctrmv_( char*, char*, char*, const int*, const void *, const int*, void *, const int*);
void ctbmv_( char*, char*, char*, const int*, const int*, const void *, const int*, void *, const int*);
void ctpmv_( char*, char*, char*, const int*, const void *, void *, const int*);
void ctrsv_( char*, char*, char*, const int*, const void *, const int*, void *, const int*);
void ctbsv_( char*, char*, char*, const int*, const int*, const void *, const int*, void *, const int*);
void ctpsv_( char*, char*, char*, const int*, const void *, void *,const int*);
void cgerc_( const int*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*);
void cgeru_( const int*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*);
void cher_(char*, const int*, const float *, const void *, const int*, void *, const int*);
void cher2_(char*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*);
void chpr_(char*, const int*, const float *, const void *, const int*, void *);
void chpr2_(char*, const int*, const float *, const void *, const int*, const void *, const int*, void *);
/* Double Complex Precision */
void zgemv_(char*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void zgbmv_(char*, const int*, const int*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void zhemv_(char*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void zhbmv_(char*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void zhpmv_(char*, const int*, const void *, const void *, const void *, const int*, const void *, void *, const int*);
void ztrmv_( char*, char*, char*, const int*, const void *, const int*, void *, const int*);
void ztbmv_( char*, char*, char*, const int*, const int*, const void *, const int*, void *, const int*);
void ztpmv_( char*, char*, char*, const int*, const void *, void *, const int*);
void ztrsv_( char*, char*, char*, const int*, const void *, const int*, void *, const int*);
void ztbsv_( char*, char*, char*, const int*, const int*, const void *, const int*, void *, const int*);
void ztpsv_( char*, char*, char*, const int*, const void *, void *,const int*);
void zgerc_( const int*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*);
void zgeru_( const int*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*);
void zher_(char*, const int*, const double *, const void *, const int*, void *, const int*);
void zher2_(char*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*);
void zhpr_(char*, const int*, const double *, const void *, const int*, void *);
void zhpr2_(char*, const int*, const double *, const void *, const int*, const void *, const int*, void *);
/***********/
/* Level 3 */
/***********/
/* Single Precision */
void sgemm_(char*, char*, const int*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ssymm_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ssyrk_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, float *, const int*);
void ssyr2k_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void strmm_(char*, char*, char*, char*, const int*, const int*, const float *, const float *, const int*, float *, const int*);
void strsm_(char*, char*, char*, char*, const int*, const int*, const float *, const float *, const int*, float *, const int*);
/* Double Precision */
void dgemm_(char*, char*, const int*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dsymm_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dsyrk_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, double *, const int*);
void dsyr2k_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dtrmm_(char*, char*, char*, char*, const int*, const int*, const double *, const double *, const int*, double *, const int*);
void dtrsm_(char*, char*, char*, char*, const int*, const int*, const double *, const double *, const int*, double *, const int*);
/* Single Complex Precision */
void cgemm_(char*, char*, const int*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void csymm_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void chemm_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void csyrk_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, float *, const int*);
void cherk_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, float *, const int*);
void csyr2k_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void cher2k_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ctrmm_(char*, char*, char*, char*, const int*, const int*, const float *, const float *, const int*, float *, const int*);
void ctrsm_(char*, char*, char*, char*, const int*, const int*, const float *, const float *, const int*, float *, const int*);
/* Double Complex Precision */
void zgemm_(char*, char*, const int*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void zsymm_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void zhemm_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void zsyrk_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, double *, const int*);
void zherk_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, double *, const int*);
void zsyr2k_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void zher2k_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void ztrmm_(char*, char*, char*, char*, const int*, const int*, const double *, const double *, const int*, double *, const int*);
void ztrsm_(char*, char*, char*, char*, const int*, const int*, const double *, const double *, const int*, double *, const int*);
}
"""
def ____gemm_code(check_ab, a_init, b_init):
mod = '%'
return """
const char * error_string = NULL;
int type_num = _x->descr->type_num;
int type_size = _x->descr->elsize; // in bytes
npy_intp* Nx = _x->dimensions;
npy_intp* Ny = _y->dimensions;
npy_intp* Nz = _z->dimensions;
npy_intp* Sx = _x->strides;
npy_intp* Sy = _y->strides;
npy_intp* Sz = _z->strides;
size_t sx_0, sx_1, sy_0, sy_1, sz_0, sz_1;
int unit = 0;
if (_x->nd != 2) goto _dot_execute_fallback;
if (_y->nd != 2) goto _dot_execute_fallback;
if (_z->nd != 2) goto _dot_execute_fallback;
%(check_ab)s
if ((_x->descr->type_num != PyArray_DOUBLE)
&& (_x->descr->type_num != PyArray_FLOAT))
goto _dot_execute_fallback;
if ((_y->descr->type_num != PyArray_DOUBLE)
&& (_y->descr->type_num != PyArray_FLOAT))
goto _dot_execute_fallback;
if ((_y->descr->type_num != PyArray_DOUBLE)
&& (_y->descr->type_num != PyArray_FLOAT))
goto _dot_execute_fallback;
if ((_x->descr->type_num != _y->descr->type_num)
||(_x->descr->type_num != _z->descr->type_num))
goto _dot_execute_fallback;
if ((Nx[0] != Nz[0]) || (Nx[1] != Ny[0]) || (Ny[1] != Nz[1]))
{
error_string = "Input dimensions do not agree";
goto _dot_execute_fail;
}
if ((Sx[0] < 1) || (Sx[1] < 1) || (Sx[0] %(mod)s type_size) || (Sx[1] %(mod)s type_size)
|| (Sy[0] < 1) || (Sy[1] < 1) || (Sy[0] %(mod)s type_size) || (Sy[1] %(mod)s type_size)
|| (Sz[0] < 1) || (Sz[1] < 1) || (Sz[0] %(mod)s type_size) || (Sz[1] %(mod)s type_size))
{
goto _dot_execute_fallback;
}
/*
encode the stride structure of _x,_y,_z into a single integer
*/
unit |= ((Sx[1] == type_size) ? 0x0 : (Sx[0] == type_size) ? 0x1 : 0x2) << 0;
unit |= ((Sy[1] == type_size) ? 0x0 : (Sy[0] == type_size) ? 0x1 : 0x2) << 4;
unit |= ((Sz[1] == type_size) ? 0x0 : (Sz[0] == type_size) ? 0x1 : 0x2) << 8;
/* create appropriate strides for malformed matrices that are row or column
* vectors
*/
sx_0 = (Nx[0] > 1) ? Sx[0]/type_size : Nx[1];
sx_1 = (Nx[1] > 1) ? Sx[1]/type_size : Nx[0];
sy_0 = (Ny[0] > 1) ? Sy[0]/type_size : Ny[1];
sy_1 = (Ny[1] > 1) ? Sy[1]/type_size : Ny[0];
sz_0 = (Nz[0] > 1) ? Sz[0]/type_size : Nz[1];
sz_1 = (Nz[1] > 1) ? Sz[1]/type_size : Nz[0];
switch (type_num)
{
case PyArray_FLOAT:
{
#define REAL float
float a = %(a_init)s;
float b = %(b_init)s;
float* x = (float*)PyArray_DATA(_x);
float* y = (float*)PyArray_DATA(_y);
float* z = (float*)PyArray_DATA(_z);
switch(unit)
{
case 0x000: cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_0); break;
case 0x001: cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_0); break;
case 0x010: cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_0); break;
case 0x011: cblas_sgemm(CblasRowMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_0); break;
case 0x100: cblas_sgemm(CblasColMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_1); break;
case 0x101: cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_1); break;
case 0x110: cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_1); break;
case 0x111: cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_1); break;
default: goto _dot_execute_fallback;
};
#undef REAL
}
break;
case PyArray_DOUBLE:
{
#define REAL double
double a = %(a_init)s;
double b = %(b_init)s;
double* x = (double*)PyArray_DATA(_x);
double* y = (double*)PyArray_DATA(_y);
double* z = (double*)PyArray_DATA(_z);
switch(unit)
{
case 0x000: cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_0); break;
case 0x001: cblas_dgemm(CblasRowMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_0); break;
case 0x010: cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_0); break;
case 0x011: cblas_dgemm(CblasRowMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_0); break;
case 0x100: cblas_dgemm(CblasColMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_1); break;
case 0x101: cblas_dgemm(CblasColMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_1); break;
case 0x110: cblas_dgemm(CblasColMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_1); break;
case 0x111: cblas_dgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_1); break;
default: goto _dot_execute_fallback;
};
#undef REAL
}
break;
}
return 0; //success!
_dot_execute_fallback:
PyErr_SetString(PyExc_NotImplementedError,
"dot->execute() fallback");
return -1;
_dot_execute_fail:
if (error_string == NULL)
PyErr_SetString(PyExc_ValueError,
"dot->execute() cant run on these inputs");
return -1;
/* v 1 */
""" % locals()
# currently unused, preferring the fallback method (throwing
# NotImplementedError) for when gemm won't work.
_templated_memaligned_gemm = """
template <typename Ta, typename Tx, typename Ty, typename Tb, typename Tz>
int general_gemm(int zM, int zN, int xN,.
Ta a,
Tx * x, int xm, int xn,
Tx * y, int ym, int yn,
Tb b,
Tz * z, int zm, int zn)
{
for (int i = 0; i < zM; ++i)
{
for (int j = 0; j < zN; ++j)
{
Tz zij = 0.0;
for (int k = 0; k < xN; ++k)
{
zij += x[i*xm+k*xn] * y[k*ym+j*yn];
}
z[i * zm + j * zn] *= b;
z[i * zm + j * zn] += a * zij;
}
}
}
"""
...@@ -103,16 +103,18 @@ class DimShuffle(Op): ...@@ -103,16 +103,18 @@ class DimShuffle(Op):
for i, b in enumerate(input_broadcastable): for i, b in enumerate(input_broadcastable):
if i not in new_order: if i not in new_order:
# we want to drop this dimension because it's not a value in new_order # we want to drop this dimension because it's not a value in new_order
if b == 1: if b == 1: # 1 aka True
self.drop.append(i) self.drop.append(i)
else: else:
# we cannot drop non-broadcastable dimensions # we cannot drop non-broadcastable dimensions
raise NotImplementedError("You cannot drop a non-broadcastable dimension.") raise ValueError("You cannot drop a non-broadcastable dimension.")
else: else:
i2j[i] = j i2j[i] = j
j += 1 j += 1
# transposition of non-broadcastable dimensions # transposition of non-broadcastable dimensions
# This is how the dimensions will be permuted, without accounting for the extra
# 'x' broadcastable dimensions to insert.
self.shuffle = [i2j[x] for x in new_order if x != 'x'] self.shuffle = [i2j[x] for x in new_order if x != 'x']
# list of dimensions of the output that are broadcastable and were not in the original input # list of dimensions of the output that are broadcastable and were not in the original input
...@@ -144,7 +146,8 @@ class DimShuffle(Op): ...@@ -144,7 +146,8 @@ class DimShuffle(Op):
and self.input_broadcastable == other.input_broadcastable and self.input_broadcastable == other.input_broadcastable
def __hash__(self): def __hash__(self):
return hash(self.inplace) ^ hash(self.new_order) ^ hash(self.input_broadcastable) return hash(type(self)) ^ hash(self.inplace) \
^ hash(self.new_order) ^ hash(self.input_broadcastable)
def __str__(self): def __str__(self):
if self.inplace: if self.inplace:
...@@ -175,13 +178,78 @@ class DimShuffle(Op): ...@@ -175,13 +178,78 @@ class DimShuffle(Op):
storage[0] = res storage[0] = res
def c_code(self, node, name, (input,), (res,), sub):
def statements(lst):
return ';\n'.join(lst) + ';'
nd_in = len(self.input_broadcastable)
nd_out = len(self.new_order)
check_input_nd = [('if (%(input)s->nd != ' + str(nd_in) + ')'
'{PyErr_SetString(PyExc_NotImplementedError, "input nd"); %(fail)s;}')]
clear_output = ['if (%(res)s) {Py_XDECREF(%(res)s);}']
shape_statements = ['npy_intp dimensions[%i]'%nd_out]
shape_statements += [('dimensions['+str(i)+'] = %(input)s->dimensions['+str(o)+']')
if o != 'x' else
('dimensions['+str(i)+'] = 1')
for i, o in enumerate(self.new_order)]
strides_statements = ['npy_intp strides[%i]'%nd_out]
strides_statements += [('strides['+str(i)+'] = %(input)s->strides['+str(o)+']')
if o != 'x' else
('strides['+str(i)+'] = 0')
for i, o in enumerate(self.new_order)]
if self.inplace:
get_base = ['{ PyArrayObject * base = %(input)s', 'Py_INCREF((PyObject*)base)']
else:
get_base = [('{ PyArrayObject * base = (PyArrayObject*)PyArray_FromAny((PyObject*)%(input)s, NULL,'
'0, 0, NPY_ALIGNED|NPY_ENSURECOPY, NULL)')]
alloc_output = [('%(res)s = (PyArrayObject*)PyArray_New(&PyArray_Type, '
'' + str(nd_out) + ', dimensions, '
'PyArray_TYPE(base), strides, '
'base->data, base->descr->elsize, '
'PyArray_FLAGS(base), NULL)'),
'%(res)s->base = (PyObject*)base',
'}']
full_code = statements(check_input_nd
+ clear_output
+ shape_statements
+ strides_statements
+ get_base
+ alloc_output)
if 0:
print 'C_CODE'
print ''
print self
print "IN BROAD", self.input_broadcastable
print "NEW ORDER", self.new_order
print "SHUFFLE", self.shuffle
print "AUGMENT", self.augment
print '------------'
print ''
print full_code
if 0:
import sys
sys.exit()
return full_code % dict(locals(), **sub)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
gz = as_tensor(gz) gz = as_tensor(gz)
grad_order = ['x'] * len(x.type.broadcastable) grad_order = ['x'] * len(x.type.broadcastable)
for i, v in enumerate(self.new_order): for i, v in enumerate(self.new_order):
if v != 'x': if v != 'x':
grad_order[v] = i grad_order[v] = i
return DimShuffle(gz.type.broadcastable, grad_order)(gz), return [DimShuffle(gz.type.broadcastable, grad_order, inplace=True)(Elemwise(scalar.identity)(gz))]
......
from basic import _scal_elemwise, _transpose_inplace from .basic import _scal_elemwise #, _transpose_inplace
from .. import scalar as scal from .. import scalar as scal
import elemwise import elemwise
from .. import printing from .. import printing
...@@ -183,9 +183,11 @@ pprint.assign(div_inplace, printing.OperatorPrinter('/=', -1, 'left')) ...@@ -183,9 +183,11 @@ pprint.assign(div_inplace, printing.OperatorPrinter('/=', -1, 'left'))
pprint.assign(pow_inplace, printing.OperatorPrinter('**=', 1, 'right')) pprint.assign(pow_inplace, printing.OperatorPrinter('**=', 1, 'right'))
transpose_inplace = _transpose_inplace def transpose_inplace(x, **kwargs):
"""WRITEME""" """Perform a transpose on a tensor without copying the underlying storage"""
dims = range(x.ndim-1, -1, -1)
return elemwise.DimShuffle(x.broadcastable, dims, inplace=True)(x)
pprint.assign(transpose_inplace, printing.MemberPrinter('T')) #pprint.assign(transpose_inplace, printing.MemberPrinter('T'))
...@@ -203,6 +203,7 @@ class SoftmaxWithBias(gof.Op): ...@@ -203,6 +203,7 @@ class SoftmaxWithBias(gof.Op):
for (j = 0; j < Nx[1]; ++j) for (j = 0; j < Nx[1]; ++j)
{ {
double row_ij = x_i[j * Sx] + b_i[j * Sb]; double row_ij = x_i[j * Sx] + b_i[j * Sb];
// std::cout << "1" << row_ij << "\\n";
row_max_j = (row_ij > row_max) ? j : row_max_j; row_max_j = (row_ij > row_max) ? j : row_max_j;
row_max = (row_ij > row_max) ? row_ij : row_max; row_max = (row_ij > row_max) ? row_ij : row_max;
} }
...@@ -210,13 +211,23 @@ class SoftmaxWithBias(gof.Op): ...@@ -210,13 +211,23 @@ class SoftmaxWithBias(gof.Op):
for (j = 0; j < Nx[1]; ++j) for (j = 0; j < Nx[1]; ++j)
{ {
double row_ij = x_i[j * Sx] + b_i[j * Sb]; double row_ij = x_i[j * Sx] + b_i[j * Sb];
// std::cout << "2" << row_ij << "\\n";
double sm_ij = exp(row_ij - row_max); double sm_ij = exp(row_ij - row_max);
// std::cout << "3" << sm_ij << "\\n";
sum += sm_ij; sum += sm_ij;
sm_i[j * Ssm] = sm_ij; sm_i[j * Ssm] = sm_ij;
} }
if ( (0.0 == sum) || (std::isinf(sum))) if (std::isinf(sum))
{ {
//that was our best... //that was our best...
PyErr_SetString(PyExc_ValueError, "softmax is impossible (inf)!");
%(fail)s;
}
if (0.0 == sum)
{
//that was our best...
PyErr_SetString(PyExc_ValueError, "softmax is impossible (zero)!");
%(fail)s; %(fail)s;
} }
...@@ -600,6 +611,7 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op): ...@@ -600,6 +611,7 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op):
} }
if (y_i >= %(dx)s->dimensions[1]) if (y_i >= %(dx)s->dimensions[1])
{ {
PyErr_SetString(PyExc_ValueError, "y_i >= dx dimensions[1]");
%(fail)s; %(fail)s;
} }
dx_i[y_i * Sdx] -= dnll_i; dx_i[y_i * Sdx] -= dnll_i;
......
"""Tensor optimizations addressing the ops in basic.py
"""
# TODO: intelligent merge for mul/add # TODO: intelligent merge for mul/add
# TODO: 0*x -> 0 # TODO: 0*x -> 0
...@@ -30,28 +31,6 @@ def in2out(*local_opts, **kwargs): ...@@ -30,28 +31,6 @@ def in2out(*local_opts, **kwargs):
**kwargs) **kwargs)
# gemm: (d,a,b,c,s) -> d = d*s + a*dot(b,c)
# Transforms d -= a * dot(b, c) into gemm(d, -a, b, c, 1.0)
gemm_pattern_1 = gof.PatternSub((T.sub,
'd',
(T.mul,
dict(pattern = (T.DimShuffle((), ['x', 'x'], inplace = True), 'a'),
allow_multiple_clients = True),
(T.dot, 'b', 'c'))),
(T.gemm, 'd', (T.neg, 'a'), 'b', 'c', T.constant(1.0)),
allow_multiple_clients = False)
# gemm: (d,a,b,c,s) -> d = d*s + a*dot(b,c)
# Transforms dot(a, b) into gemm(zeros(2)(hstack(shape(a)[:1], shape(b)[1:])), 1.0, a, b, 1.0)
# The construction of the 'gemm' node may fail if, for example, a and b are not both matrices.
dot_to_gemm = gof.PatternSub((T.dot, 'a', 'b'),
(T.gemm, (T.Zeros(2),
(T.stack,
(T.Subtensor([slice(0, 1)]), (T.shape, 'a')),
(T.Subtensor([slice(1, 2)]), (T.shape, 'b')))),
T.constant(1.0), 'a', 'b', T.constant(1.0)),
allow_multiple_clients = False)
def _insert_inplace_optimizer(env): def _insert_inplace_optimizer(env):
""" """
...@@ -91,12 +70,6 @@ def _insert_inplace_optimizer(env): ...@@ -91,12 +70,6 @@ def _insert_inplace_optimizer(env):
break break
insert_inplace_optimizer = gof.optimizer(_insert_inplace_optimizer) insert_inplace_optimizer = gof.optimizer(_insert_inplace_optimizer)
inplace_optimizer = gof.InplaceOptimizer(
gof.SeqOptimizer(out2in(gemm_pattern_1),
insert_inplace_optimizer,
failure_callback = gof.warn))
compile.optdb.register('inplace_opt', inplace_optimizer, 99, 'fast_run', 'inplace')
def register_canonicalize(lopt, *tags, **kwargs): def register_canonicalize(lopt, *tags, **kwargs):
name = (kwargs and kwargs.pop('name')) or lopt.__name__ name = (kwargs and kwargs.pop('name')) or lopt.__name__
...@@ -216,6 +189,13 @@ register_canonicalize(local_shape_lift_dot) ...@@ -216,6 +189,13 @@ register_canonicalize(local_shape_lift_dot)
################ ################
def encompasses_broadcastable(b1, b2): def encompasses_broadcastable(b1, b2):
"""
Returns True if the broadcastable patterns b1 and b2 are such that b2 is
broadcasted to b1's shape and not the opposite.
:param b1: the broadcastable attribute of a tensor type
:param b2: the broadcastable attribute of a tensor type
"""
if len(b1) < len(b2): if len(b1) < len(b2):
return False return False
b1 = b1[-len(b2):] b1 = b1[-len(b2):]
...@@ -330,6 +310,7 @@ def local_fill_cut(node): ...@@ -330,6 +310,7 @@ def local_fill_cut(node):
register_canonicalize(local_fill_cut) register_canonicalize(local_fill_cut)
register_canonicalize(gof.OpRemove(T.tensor_copy), name='remove_tensor_copy' )
@gof.local_optimizer([None, T.fill]) @gof.local_optimizer([None, T.fill])
def local_fill_sink(node): def local_fill_sink(node):
...@@ -524,9 +505,30 @@ class Canonizer(gof.LocalOptimizer): ...@@ -524,9 +505,30 @@ class Canonizer(gof.LocalOptimizer):
return False return False
new = self.merge_num_denum(num, denum) new = self.merge_num_denum(num, denum)
if new.type != out.type: if new.dtype != out.dtype:
#new = T.fill(out, new) #new = T.fill(out, new)
new = T.fill(out, T.Elemwise(scalar.Identity(scalar.specific_out(getattr(scalar, out.type.dtype))))(new)) elem_op = T.Elemwise(scalar.Identity(scalar.specific_out(getattr(scalar, out.type.dtype))))
new = T.fill(out, elem_op(new))
if new.broadcastable != out.broadcastable:
#this case is tricky... we need to provide exactly the same kind of broadcastable
#pattern, but only if legal...
dlen = len(new.broadcastable) - len(out.broadcastable)
if dlen > 0:
#try to take the leading ranks of new.broadcastable, which should be broadcastable
# ranks
#if this means skipping over nonbroadcastable ranks, then DimShuffle will fail
dimshuffle_op = T.DimShuffle(new.broadcastable,
range(dlen, len(new.broadcastable)))
new = dimshuffle_op(new)
elif dlen < 0:
#we have to boost up a scalar or something
dimshuffle_op = T.DimShuffle(new.broadcastable,
['x' for x in range(-dlen)] + range(0, len(new.broadcastable)))
new = dimshuffle_op(new)
# if our if's above worked, this should be true. OTW investigate.
assert new.type == out.type
return [new] return [new]
def __str__(self): def __str__(self):
...@@ -550,6 +552,7 @@ def local_neg_to_mul(node): ...@@ -550,6 +552,7 @@ def local_neg_to_mul(node):
return [-1 * node.inputs[0]] return [-1 * node.inputs[0]]
else: else:
return False return False
register_canonicalize(local_neg_to_mul)
@gof.local_optimizer([T.mul]) @gof.local_optimizer([T.mul])
def local_mul_to_neg(node): def local_mul_to_neg(node):
...@@ -557,6 +560,7 @@ def local_mul_to_neg(node): ...@@ -557,6 +560,7 @@ def local_mul_to_neg(node):
return [-local_mul_canonizer.merge_num_denum(node.inputs[1:], [])] return [-local_mul_canonizer.merge_num_denum(node.inputs[1:], [])]
else: else:
return False return False
register_specialize(local_mul_to_neg)
@gof.local_optimizer([T.div]) @gof.local_optimizer([T.div])
def local_div_to_inv(node): def local_div_to_inv(node):
...@@ -564,10 +568,120 @@ def local_div_to_inv(node): ...@@ -564,10 +568,120 @@ def local_div_to_inv(node):
return [T.inv(local_mul_canonizer.merge_num_denum(node.inputs[1:], []))] return [T.inv(local_mul_canonizer.merge_num_denum(node.inputs[1:], []))]
else: else:
return False return False
register_canonicalize(local_neg_to_mul)
register_specialize(local_mul_to_neg)
register_specialize(local_div_to_inv) register_specialize(local_div_to_inv)
@gof.local_optimizer([T.inv])
def local_inv_canon(node):
if node.op == T.inv:
return [T.pow(node.inputs[0], -1.0)]
else:
return False
register_canonicalize(local_inv_canon)
@gof.local_optimizer([T.pow])
def local_pow_canonicalize(node):
if node.op == T.pow:
if N.all(local_mul_canonizer.get_constant(node.inputs[1]) == 1.0):
return [T.fill(node.inputs[1], node.inputs[0])]
if N.all(local_mul_canonizer.get_constant(node.inputs[1]) == 0.0):
#extra fills here are to make sure the size of the output stays constant.
return [T.fill(node.inputs[0], T.fill(node.inputs[1], 1.0))]
else:
return False
register_canonicalize(local_pow_canonicalize)
@gof.local_optimizer([T.pow])
def local_pow_specialize(node):
#here, we are past the point of canonicalization, so we don't want to put in un-necessary fills.
if node.op == T.pow:
#the idea here is that we have pow(x, y)
xsym = node.inputs[0]
ysym = node.inputs[1]
y = local_mul_canonizer.get_constant(ysym)
if (y is not None) \
and encompasses_broadcastable(xsym.type.broadcastable, ysym.type.broadcastable):
if N.all(y == 2.0):
return [T.sqr(xsym)]
if N.all(y == 1.0):
return [xsym]
if N.all(y == 0.0):
return [T.fill(xsym, 1.0)]
if N.all(y == 0.5):
return [T.sqrt(xsym)]
if N.all(y == -0.5):
return [T.inv(T.sqrt(xsym))]
if N.all(y == -1.0):
return [T.inv(xsym)]
if N.all(y == -2.0):
return [T.inv(T.sqr(xsym))]
else:
return False
register_specialize(local_pow_specialize)
@gof.local_optimizer([T.mul])
def local_mul_specialize(node):
#here, we are past the point of canonicalization, so we don't want to put in un-necessary fills.
if node.op == T.mul:
#the idea here is that we have pow(x, y)
neg = False
new_inputs = []
for input in node.inputs:
y = local_mul_canonizer.get_constant(input)
if N.all(y == 1.0):
continue
elif N.all(y == -1.0):
neg ^= True #toggles
elif N.all(y == 0.0):
return [input]
else:
new_inputs.append(input)
if len(new_inputs) < len(node.inputs):
if len(new_inputs) == 0:
newval = -y.flatten()[0] if neg else y.flatten()[0]
return [T.TensorConstant(T.Tensor(dtype=node.outputs[0].type.dtype,
broadcastable = [True] * node.outputs[0].ndim), N.asarray(newval))]
if len(new_inputs) == 1:
return [-new_inputs[0]] if neg else new_inputs
else:
return [-T.mul(*new_inputs)] if neg else \
[T.mul(*new_inputs)]
else:
return False
register_specialize(local_mul_specialize)
if 0: #TODO: replace this with a c version of any InplaceDimShuffle
class _TransposeInplace(T.Op):
view_map = {0: [0]}
def make_node(self, input):
return T.Apply(self, [input],
[T.tensor(dtype = input.type.dtype,
broadcastable = reversed(input.type.broadcastable))])
def perform(self, node, (x, ), (z, )):
z[0] = x.T
def c_code(self, node, name, (x, ), (z, ), sub):
return """
PyArrayObject* transposed = (PyArrayObject*)PyArray_Transpose(%(x)s, NULL);
if (%(z)s) {
Py_XDECREF(%(z)s);
}
%(z)s = transposed;
""" % locals()
def __str__(self):
return "_TransposeInplace"
_transpose_inplace = _TransposeInplace()
@gof.local_optimizer([T.DimShuffle([False,False],[1,0],inplace=True)])
def local_dimshuffle_transposeinplace(node):
if node.op == T.DimShuffle([False,False],[1,0],inplace=True):
return [_transpose_inplace(node.inputs[0])]
return False
register_specialize(local_dimshuffle_transposeinplace)
register_canonicalize(local_mul_canonizer, name = 'local_mul_canonizer') register_canonicalize(local_mul_canonizer, name = 'local_mul_canonizer')
...@@ -724,8 +838,10 @@ def constant_folding(node): ...@@ -724,8 +838,10 @@ def constant_folding(node):
register_canonicalize(constant_folding) register_canonicalize(constant_folding)
inplace_matrix_transpose = T.DimShuffle([False,False], [1,0], inplace=True)
local_transposed_dot = gof.PatternSub((inplace_matrix_transpose, (T.dot, 'x', 'y')),
(T.dot, (inplace_matrix_transpose, 'y'), (inplace_matrix_transpose, 'x')))
register_canonicalize(local_transposed_dot, name='local_transposed_dot')
# def _math_optimizer(): # def _math_optimizer():
......
...@@ -662,56 +662,6 @@ class T_max_and_argmax(unittest.TestCase): ...@@ -662,56 +662,6 @@ class T_max_and_argmax(unittest.TestCase):
self.failUnless(i.shape == (2,3)) self.failUnless(i.shape == (2,3))
class T_transpose(unittest.TestCase):
def test0(self):
n = as_tensor(numpy.ones(()))
t = transpose(n)
self.failUnless(t.owner.op == inplace.transpose_inplace)
f = function([n], t)
tval = f(n.data)
self.failUnless(tval.shape == n.data.shape)
#test aliasing
tval += 55.0
self.failUnless(n.data == 1.0)
def test1(self):
n = as_tensor(numpy.ones(5))
t = transpose(n)
self.failUnless(t.owner.op == inplace.transpose_inplace)
f = function([n], t)
tval = f(n.data)
self.failUnless(tval.shape == n.data.shape)
#test aliasing
tval += 55.0
self.failUnless(n.data[0] == 1.0)
def test2(self):
n = as_tensor(numpy.ones((5,3)))
t = transpose(n)
self.failUnless(t.owner.op == inplace.transpose_inplace)
f = function([n], t)
tval = f(n.data)
self.failUnless(tval.shape == (3,5))
#test aliasing
tval += 55.0
self.failUnless(n.data[0,0] == 1.0)
def test3(self):
"""Test transpose of tensor, inplace version"""
n = as_tensor(numpy.ones((5,3,2)))
t = inplace.transpose_inplace(n)
self.failUnless(t.owner.op == inplace.transpose_inplace)
f = function([n], t)
tval = f(n.data)
self.failUnless(tval.shape == (2,3,5))
#test aliasing
tval += 55.0
self.failUnless(n.data[0,0,0] == 56.0)
def test_grad(self):
verify_grad(self, inplace.transpose_inplace, [numpy.random.rand(2, 3)])
verify_grad(self, inplace.transpose_inplace, [numpy.ones(3)])
class T_subtensor(unittest.TestCase): class T_subtensor(unittest.TestCase):
def setUp(self): def setUp(self):
Subtensor.debug = False Subtensor.debug = False
...@@ -1406,179 +1356,6 @@ class t_dot(unittest.TestCase): ...@@ -1406,179 +1356,6 @@ class t_dot(unittest.TestCase):
#verify_grad(self, dot, [self.rand(), self.rand(2)]) #verify_grad(self, dot, [self.rand(), self.rand(2)])
#verify_grad(self, dot, [self.rand(), self.rand(2,5)]) #verify_grad(self, dot, [self.rand(), self.rand(2,5)])
class t_gemm(unittest.TestCase):
def setUp(self):
numpy.random.seed(44)
_approx_eq.debug = 0
Gemm.debug = False
@staticmethod
def _gemm(z,a,x,y,b):
assert a.shape == ()
assert b.shape == ()
return b * z + a * numpy.dot(x,y)
@staticmethod
def rand(*args):
return numpy.random.rand(*args)
def cmp(self, z, a, x, y, b):
def cmp_linker(z, a, x, y, b, l):
z,a,x,y,b = [numpy.asarray(p) for p in z,a,x,y,b]
z_orig = z.copy()
tz,ta,tx,ty,tb = [as_tensor(p).type() for p in z,a,x,y,b]
f = function([tz,ta,tx,ty,tb], gemm(tz,ta,tx,ty,tb), mode=compile.Mode(optimizer = None, linker = l))
new_z = f(z,a,x,y,b)
z_after = self._gemm(z_orig, a, x, y, b)
self.failUnless(z is new_z)
#print z_orig, z_after, z, type(z_orig), type(z_after), type(z)
#_approx_eq.debug = 1
self.failUnless(_approx_eq(z_after, z))
if a == 0.0 and b == 1.0:
return
else:
self.failIf(numpy.all(z_orig == z))
cmp_linker(copy(z), a, x, y, b, 'c|py')
cmp_linker(copy(z), a, x, y, b, 'c')
cmp_linker(copy(z), a, x, y, b, 'py')
def test0a(self):
Gemm.debug = True
try:
g = gemm([1.], 1., [1.], [1.], 1.)
except ValueError, e:
if e[0] is Gemm.E_rank:
return
self.fail()
def test0(self):
try:
self.cmp(1., 0., 1.0, 1.0, 1.0)
except ValueError, e:
if e[0] is Gemm.E_rank:
return
self.fail()
def test2(self):
try:
self.cmp(2., 1.0, [3,2,1.], [[1],[2],[3.]], 1.0)
except ValueError, e:
self.failUnless(e[0] == Gemm.E_rank)
return
self.fail()
def test4(self):
self.cmp(self.rand(3,4), 1.0, self.rand(3,5), self.rand(5,4), 0.0)
def test5(self): self.cmp(self.rand(3,4), 1.0,
self.rand(3,5), self.rand(5,4), 1.0)
def test6(self): self.cmp(self.rand(3,4), 1.0,
self.rand(3,5), self.rand(5,4), -1.0)
def test7(self): self.cmp(self.rand(3,4), 0.0,
self.rand(3,5), self.rand(5,4), 0.0)
def test8(self): self.cmp(self.rand(3,4), 0.0,
self.rand(3,5), self.rand(5,4), 0.6)
def test9(self): self.cmp(self.rand(3,4), 0.0,
self.rand(3,5), self.rand(5,4), -1.0)
def test10(self):
_approx_eq.debug = 1
self.cmp(self.rand(3,4), -1.0, self.rand(3,5), self.rand(5,4), 0.0)
def test11(self): self.cmp(self.rand(3,4), -1.0,
self.rand(3,5), self.rand(5,4), 1.0)
def test12(self): self.cmp(self.rand(3,4), -1.0,
self.rand(3,5), self.rand(5,4), -1.0)
def test_destroy_map0(self):
"""test that only first input can be overwritten"""
Z = as_tensor(self.rand(2,2))
try:
gemm(Z, 1.0, Z, Z, 1.0)
except ValueError, e:
if e[0] == Gemm.E_z_uniq:
return
self.fail()
def test_destroy_map1(self):
"""test that only first input can be overwritten"""
Z = as_tensor(self.rand(2,2))
A = as_tensor(self.rand(2,2))
try:
gemm(Z, 1.0, A, inplace.transpose_inplace(Z), 1.0)
except ValueError, e:
if e[0] == Gemm.E_z_uniq:
return
self.fail()
def test_destroy_map2(self):
"""test that only first input can be overwritten"""
Z = as_tensor(self.rand(2,2))
A = as_tensor(self.rand(2,2))
try:
gemm(Z, 1.0, inplace.transpose_inplace(Z), A, 1.0)
except ValueError, e:
if e[0] == Gemm.E_z_uniq:
return
self.fail()
def test_destroy_map3(self):
"""test that only first input can be overwritten"""
Z = as_tensor(self.rand(2,2))
A = as_tensor(self.rand(2,2))
try:
gemm(Z, 1.0, Z, A, 1.0)
except ValueError, e:
if e[0] == Gemm.E_z_uniq:
return
self.fail()
def test_destroy_map4(self):
"""test that dot args can be aliased"""
Z = value(self.rand(2,2))
A = value(self.rand(2,2))
eval_outputs([gemm(Z, 1.0, A, A, 1.0)])
eval_outputs([gemm(Z, 1.0, A, A.T, 1.0)])
def test_transposes(self):
# three square matrices which are not contiguous
A = self.rand(4,5)[:,:4]
B = self.rand(4,5)[:,:4]
C = self.rand(4,5)[:,:4]
def t(z,x,y,a=1.0, b=0.0,l='c|py',dt='float64'):
z,a,x,y,b = [numpy.asarray(p,dtype=dt) for p in z,a,x,y,b]
z_orig = z.copy()
z_after = self._gemm(z, a, x, y, b)
tz,ta,tx,ty,tb = [value(p) for p in z,a,x,y,b]
f = function([tz,ta,tx,ty,tb], gemm(tz,ta,tx,ty,tb), mode = compile.Mode(optimizer = None, linker=l))
f(z, a, x, y, b)
self.failUnless(_approx_eq(z_after, z), (z_orig, z_after, z, z_after - z))
f(z.T, a, y.T, x.T, b)
self.failUnless(_approx_eq(z_after, z))
t(C,A,B)
t(C.T, A, B)
t(C, A.T, B, dt='float32')
t(C, A, B.T)
t(C.T, A.T, B)
t(C, A.T, B.T, dt='float32')
t(C.T, A, B.T)
t(C.T, A.T, B.T, dt='float32')
t(C, A[:,:2], B[:2, :])
t(C.T, A[:,:2], B[:2, :], dt='float32')
t(C, A[:2,:].T, B[:2, :])
t(C.T, A[:2,:].T, B[:2, :], dt='float32')
t(C, A[:2,:].T, B[:, :2].T)
t(C.T, A[:2,:].T, B[:, :2].T)
try:
t(C.T, A[:2,:], B[:, :2].T)
except ValueError, e:
if e[0].find('aligned') >= 0:
return
self.fail()
class T_tensorfromscalar(unittest.TestCase): class T_tensorfromscalar(unittest.TestCase):
def test0(self): def test0(self):
s = scal.constant(56) s = scal.constant(56)
......
import theano.tensor as T
from ...gof import Env
import numpy
from theano.tensor.blas import *
from theano.tensor.blas import _as_scalar, _dot22, _is_real_matrix
from unittest import TestCase
from copy import copy
from theano import In, Out
from .test_basic import (_approx_eq, as_tensor, function,
compile, value, constant, inplace, eval_outputs)
class t_gemm(TestCase):
"""This test suite is supposed to establish that gemm works as it is supposed to."""
def setUp(self):
numpy.random.seed(44)
_approx_eq.debug = 0
Gemm.debug = False
@staticmethod
def _gemm(z,a,x,y,b):
assert a.shape == ()
assert b.shape == ()
return b * z + a * numpy.dot(x,y)
@staticmethod
def rand(*args):
return numpy.random.rand(*args)
def cmp(self, z, a, x, y, b):
def cmp_linker(z, a, x, y, b, l):
z,a,x,y,b = [numpy.asarray(p) for p in z,a,x,y,b]
z_orig = z.copy()
tz,ta,tx,ty,tb = [as_tensor(p).type() for p in z,a,x,y,b]
f = function([tz,ta,tx,ty,tb], gemm(tz,ta,tx,ty,tb), mode=compile.Mode(optimizer = None, linker = l))
new_z = f(z,a,x,y,b)
z_after = self._gemm(z_orig, a, x, y, b)
self.failUnless(z is new_z)
#print z_orig, z_after, z, type(z_orig), type(z_after), type(z)
#_approx_eq.debug = 1
self.failUnless(_approx_eq(z_after, z))
if a == 0.0 and b == 1.0:
return
else:
self.failIf(numpy.all(z_orig == z))
cmp_linker(copy(z), a, x, y, b, 'c|py')
cmp_linker(copy(z), a, x, y, b, 'c')
cmp_linker(copy(z), a, x, y, b, 'py')
def test0a(self):
Gemm.debug = True
try:
g = gemm([1.], 1., [1.], [1.], 1.)
except ValueError, e:
if e[0] is Gemm.E_rank:
return
self.fail()
def test0(self):
try:
self.cmp(1., 0., 1.0, 1.0, 1.0)
except ValueError, e:
if e[0] is Gemm.E_rank:
return
self.fail()
def test2(self):
try:
self.cmp(2., 1.0, [3,2,1.], [[1],[2],[3.]], 1.0)
except ValueError, e:
self.failUnless(e[0] == Gemm.E_rank)
return
self.fail()
def test4(self):
self.cmp(self.rand(3,4), 1.0, self.rand(3,5), self.rand(5,4), 0.0)
def test5(self): self.cmp(self.rand(3,4), 1.0,
self.rand(3,5), self.rand(5,4), 1.0)
def test6(self): self.cmp(self.rand(3,4), 1.0,
self.rand(3,5), self.rand(5,4), -1.0)
def test7(self): self.cmp(self.rand(3,4), 0.0,
self.rand(3,5), self.rand(5,4), 0.0)
def test8(self): self.cmp(self.rand(3,4), 0.0,
self.rand(3,5), self.rand(5,4), 0.6)
def test9(self): self.cmp(self.rand(3,4), 0.0,
self.rand(3,5), self.rand(5,4), -1.0)
def test10(self):
_approx_eq.debug = 1
self.cmp(self.rand(3,4), -1.0, self.rand(3,5), self.rand(5,4), 0.0)
def test11(self): self.cmp(self.rand(3,4), -1.0,
self.rand(3,5), self.rand(5,4), 1.0)
def test12(self): self.cmp(self.rand(3,4), -1.0,
self.rand(3,5), self.rand(5,4), -1.0)
def test_destroy_map0(self):
"""test that only first input can be overwritten"""
Z = as_tensor(self.rand(2,2))
try:
gemm(Z, 1.0, Z, Z, 1.0)
except ValueError, e:
if e[0] == Gemm.E_z_uniq:
return
self.fail()
def test_destroy_map1(self):
"""test that only first input can be overwritten"""
Z = as_tensor(self.rand(2,2))
A = as_tensor(self.rand(2,2))
try:
gemm(Z, 1.0, A, inplace.transpose_inplace(Z), 1.0)
except ValueError, e:
if e[0] == Gemm.E_z_uniq:
return
self.fail()
def test_destroy_map2(self):
"""test that only first input can be overwritten"""
Z = as_tensor(self.rand(2,2))
A = as_tensor(self.rand(2,2))
try:
gemm(Z, 1.0, inplace.transpose_inplace(Z), A, 1.0)
except ValueError, e:
if e[0] == Gemm.E_z_uniq:
return
self.fail()
def test_destroy_map3(self):
"""test that only first input can be overwritten"""
Z = as_tensor(self.rand(2,2))
A = as_tensor(self.rand(2,2))
try:
gemm(Z, 1.0, Z, A, 1.0)
except ValueError, e:
if e[0] == Gemm.E_z_uniq:
return
self.fail()
def test_destroy_map4(self):
"""test that dot args can be aliased"""
Z = value(self.rand(2,2))
A = value(self.rand(2,2))
eval_outputs([gemm(Z, 1.0, A, A, 1.0)])
eval_outputs([gemm(Z, 1.0, A, A.T, 1.0)])
def test_transposes(self):
# three square matrices which are not contiguous
A = self.rand(4,5)[:,:4]
B = self.rand(4,5)[:,:4]
C = self.rand(4,5)[:,:4]
def t(z,x,y,a=1.0, b=0.0,l='c|py',dt='float64'):
z,a,x,y,b = [numpy.asarray(p,dtype=dt) for p in z,a,x,y,b]
z_orig = z.copy()
z_after = self._gemm(z, a, x, y, b)
tz,ta,tx,ty,tb = [value(p) for p in z,a,x,y,b]
f = function([tz,ta,tx,ty,tb], gemm(tz,ta,tx,ty,tb), mode = compile.Mode(optimizer = None, linker=l))
f(z, a, x, y, b)
self.failUnless(_approx_eq(z_after, z), (z_orig, z_after, z, z_after - z))
f(z.T, a, y.T, x.T, b)
self.failUnless(_approx_eq(z_after, z))
t(C,A,B)
t(C.T, A, B)
t(C, A.T, B, dt='float32')
t(C, A, B.T)
t(C.T, A.T, B)
t(C, A.T, B.T, dt='float32')
t(C.T, A, B.T)
t(C.T, A.T, B.T, dt='float32')
t(C, A[:,:2], B[:2, :])
t(C.T, A[:,:2], B[:2, :], dt='float32')
t(C, A[:2,:].T, B[:2, :])
t(C.T, A[:2,:].T, B[:2, :], dt='float32')
t(C, A[:2,:].T, B[:, :2].T)
t(C.T, A[:2,:].T, B[:, :2].T)
try:
t(C.T, A[:2,:], B[:, :2].T)
except ValueError, e:
if e[0].find('aligned') >= 0:
return
self.fail()
class t_as_scalar(TestCase):
def test0(self):
"""Test that it works on scalar constants"""
a = T.constant(2.5)
b = T.constant(numpy.asarray([[[0.5]]]))
d_a = T.DimShuffle([], [])(a)
d_b = T.DimShuffle([True, True, True], [0,2,1])(b)
d_a2 = T.DimShuffle([], ['x', 'x', 'x'])(a)
self.failUnless(numpy.all(_as_scalar(a) == a))
self.failUnless(numpy.all(_as_scalar(b) == b.data), (b, _as_scalar(b)))
self.failUnless(numpy.all(_as_scalar(d_a) == a))
self.failUnless(numpy.all(_as_scalar(d_b) == b.data))
self.failUnless(numpy.all(_as_scalar(d_a2) == a))
def test1(self):
"""Test that it fails on nonscalar constants"""
a = T.constant(numpy.ones(5))
self.failUnless(None == _as_scalar(a))
self.failUnless(None == _as_scalar(T.DimShuffle([False], [0,'x'])(a)))
def test2(self):
"""Test that it works on scalar variables"""
a = T.dscalar()
d_a = T.DimShuffle([], [])(a)
d_a2 = T.DimShuffle([], ['x', 'x'])(a)
self.failUnless(_as_scalar(a) is a)
self.failUnless(_as_scalar(d_a) is a)
self.failUnless(_as_scalar(d_a2) is a)
def test3(self):
"""Test that it fails on nonscalar variables"""
a = T.dmatrix()
self.failUnless(None == _as_scalar(a))
self.failUnless(None == _as_scalar(T.DimShuffle([False, False], [0,'x', 1])(a)))
class T_real_matrix(TestCase):
def test0(self):
self.failUnless(_is_real_matrix(T.DimShuffle([False,False], [1, 0])(T.dmatrix())))
self.failUnless(not _is_real_matrix(T.DimShuffle([False], ['x', 0])(T.dvector())))
class T_gemm_opt(TestCase):
"""This test suite ensures that Gemm is inserted where it belongs, and that the resulting
functions compute the same things as the originals."""
def XYZab(self):
return T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar()
def just_gemm(self, i, o, ishapes = [(4,3), (3,5), (4,5), (), ()]):
def on_fail():
for node in f.maker.env.toposort():
print 'GRAPH', node
self.fail()
f = function([In(ii, mutable=True) for ii in i],o, mode='FAST_RUN')
for node in f.maker.env.nodes:
if node.op == T.dot: on_fail()
if node.op == _dot22: on_fail()
g = function(i, o, mode='FAST_COMPILE')
for node in g.maker.env.nodes:
if node.op == gemm: on_fail()
rng = numpy.random.RandomState(234)
r0 = f(*[rng.randn(*sh) for sh in ishapes])
rng = numpy.random.RandomState(234)
r1 = g(*[rng.randn(*sh) for sh in ishapes])
if numpy.max(numpy.abs(r0[0] - r1[0])) > 1.0e-8:
self.fail()
def test0(self):
"""Many subgraphs whose dots can be eliminated"""
X,Y,Z,a,b = self.XYZab()
self.just_gemm([X,Y,Z,a,b], [T.dot(X,Y) * a + Z * b])
self.just_gemm([X,Y,Z,a,b], [a * T.dot(X,Y) + b * Z])
self.just_gemm([X,Y,Z,a,b], [b * Z + a * T.dot(X,Y)])
self.just_gemm([X,Y,Z,a,b], [T.dot(X,Y) * a - Z * b])
self.just_gemm([X,Y,Z,a,b], [a * T.dot(X,Y) - b * Z])
self.just_gemm([X,Y,Z,a,b], [b * Z - a * T.dot(X,Y)])
#with transposes (transposes should be pushed through dot in canonicalize)
self.just_gemm([X,Y,Z,a,b], [b * Z.T - a * T.dot(Y.T,X.T)])
self.just_gemm([X,Y,Z,a,b], [b * Z.T + a * b * T.dot(X,Y).T])
#with N multiplications instead of just one
self.just_gemm([X,Y,Z,a,b], [(b * b) * Z * a + (a * a) * T.dot(X,Y) * b])
self.just_gemm([X,Y,Z,a,b], [Z + T.dot(X,Y)])
self.just_gemm([X,Y,Z,a,b], [Z*b + T.dot(X,Y)])
self.just_gemm([X,Y,Z,a,b], [Z + a*b*a*T.dot(X,Y)])
self.just_gemm([X,Y,Z,a,b], [(b * b) * Z * a - (a * a) * T.dot(X,Y) * b])
self.just_gemm([X,Y,Z,a,b], [Z - T.dot(X,Y)])
self.just_gemm([X,Y,Z,a,b], [Z*b - T.dot(X,Y)])
self.just_gemm([X,Y,Z,a,b], [Z - a*b*a*T.dot(X,Y)])
# with > 2 terms in the overall addition
self.just_gemm([X,Y,Z,a,b], [Z + Z + T.dot(X,Y) + Z])
def test_double_gemm(self):
"""This is the pattern that shows up in the autoencoder"""
X,Y,Z,a,b = T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar()
R, S, c = T.dmatrix(), T.dmatrix(), T.dscalar()
self.just_gemm([X,Y,Z,a,b, R, S, c], [Z *c + a * T.dot(X,Y) + b * T.dot(R,S).T],
ishapes=[(4,3), (3,5), (4,5), (), (), (5,9), (9,4), ()])
def wishlist(self):
X,Y,Z,a,b = T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar()
#with >2 additions of the same T.dot(X,Y term
self.just_gemm([X,Y,Z,a,b], [Z + T.dot(X,Y) + T.dot(X,Y)])
self.just_gemm([X,Y,Z,a,b], [(b * b) * Z * a + (a * a) * T.dot(X,Y) + b * T.dot(X,Y)])
def test_vector_stuff(self):
X,Y,Z,a,b = T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar()
u,v = T.dvector(), T.dvector()
f = function([a, u, v], a + T.dot(u,v), mode='FAST_RUN')
self.failIf(gemm in [n.op for n in f.maker.env.nodes])
f = function([a, u, X,Y], a * u + T.dot(X,Y), mode='FAST_RUN')
self.failIf(gemm in [n.op for n in f.maker.env.nodes])
import unittest
def main(modulename):
if 0:
unittest.main()
elif 1:
module = __import__(modulename)
tests = unittest.TestLoader().loadTestsFromModule(module)
tests.debug()
else:
testcases = []
testcases.append(T_function_module)
#<testsuite boilerplate>
testloader = unittest.TestLoader()
suite = unittest.TestSuite()
for testcase in testcases:
suite.addTest(testloader.loadTestsFromTestCase(testcase))
unittest.TextTestRunner(verbosity=2).run(suite)
#</boilerplate>
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论