提交 32105ec9 authored 作者: James Bergstra's avatar James Bergstra

refactored gemm-related code. added gemm-related optimizations and tests. added…

refactored gemm-related code. added gemm-related optimizations and tests. added specializations for mul
上级 f28d5cb9
...@@ -783,148 +783,6 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -783,148 +783,6 @@ class EquilibriumOptimizer(NavigatorOptimizer):
print >> sys.stderr, "WARNING: EquilibriumOptimizer max'ed out" print >> sys.stderr, "WARNING: EquilibriumOptimizer max'ed out"
class _EquilibriumOptimizer(NavigatorOptimizer):
    """Experimental, track-driven variant of EquilibriumOptimizer.

    Each local optimizer advertises "tracks" through ``lopt.tracks()``.  A
    track is a sequence of ops in which ``None`` matches any op.  Tracks are
    indexed by op so that a local optimizer is only queued for a node when
    one of its tracks matches the ops found at and downstream of that node.
    """

    def __init__(self,
                 local_optimizers,
                 failure_callback = None,
                 max_depth = None,
                 max_use_ratio = None):
        """Index the tracks of every local optimizer.

        :param local_optimizers: iterable of objects exposing ``tracks()``.
        :param failure_callback: forwarded to NavigatorOptimizer.__init__.
        :param max_depth: maximal track length allowed, or None for no limit.
        :param max_use_ratio: if not None, each local optimizer may succeed at
            most ``max_use_ratio * len(env.nodes)`` times per `apply` call.

        :raises ValueError: if a track is longer than `max_depth`.
        """
        # The first argument must be this very class: `self` is not an
        # instance of the sibling EquilibriumOptimizer, so naming that class
        # here makes super() raise TypeError.
        super(_EquilibriumOptimizer, self).__init__(
            None,
            ignore_newtrees = False,
            failure_callback = failure_callback)

        self.local_optimizers = local_optimizers
        self.max_depth = max_depth
        self.max_use_ratio = max_use_ratio

        # op -> [(track, position of op in track, local optimizer)]
        self.tracks = defaultdict(list)
        # same as self.tracks, restricted to position 0
        self.tracks0 = defaultdict(list)

        deepest = 0
        for lopt in local_optimizers:
            for track in lopt.tracks():
                deepest = max(deepest, len(track))
                if self.max_depth is not None and deepest > self.max_depth:
                    raise ValueError('One of the local optimizers exceeds the maximal depth.')
                for i, op in enumerate(track):
                    entry = (track, i, lopt)
                    if i == 0:
                        self.tracks0[op].append(entry)
                    self.tracks[op].append(entry)

    def fetch_tracks(self, op):
        """Return the track entries registered for `op` plus the wildcard (None) ones."""
        return self.tracks[op] + self.tracks[None]

    def fetch_tracks0(self, op):
        """Return the position-0 track entries for `op` plus the wildcard (None) ones."""
        return self.tracks0[op] + self.tracks0[None]

    def backtrack(self, node, tasks):
        """Queue in `tasks` the local optimizers whose track matches around `node`.

        Starting from the entries registered for ``node.op``, each track is
        compared one position at a time (towards its index 0) against the ops
        of the client nodes reached from `node`.  When index 0 of a track
        matches the op of some node, its optimizer is appended to
        ``tasks[that_node]``.

        :param node: the node whose neighbourhood is examined.
        :param tasks: mapping node -> list of local optimizers; updated in place.
        """
        candidates = self.fetch_tracks(node.op)
        tracks = []

        def advance(node, depth):
            # Reads the *current* bindings of `candidates` and `tracks` from
            # the enclosing scope.  Returns the candidates that did not match
            # `node` at this depth; matched ones either fire (index 0 was
            # reached) or are parked in `tracks` for the next depth.
            unmatched = []
            for candidate in candidates:
                track, i, lopt = candidate
                if i < depth:
                    pass
                elif track[i-depth] in (None, node.op):
                    if i == depth:
                        tasks[node].append(lopt)
                    else:
                        tracks.append(candidate)
                else:
                    unmatched.append(candidate)
            return unmatched

        depth = 0
        nodes = [node]
        while candidates:
            for node in nodes:
                candidates = advance(node, depth)
            depth += 1
            # Move one level down: gather every client of every output of the
            # current nodes.  Clients that are strings are skipped
            # (presumably markers for graph outputs -- TODO confirm).
            next_nodes = []
            for current in nodes:
                for out in current.outputs:
                    for client, idx in out.clients:
                        if not isinstance(client, str):
                            next_nodes.append(client)
            nodes = next_nodes
            candidates = tracks
            tracks = []

    def apply(self, env):
        """Run the queued local optimizers on `env` until no task is left.

        The queue is seeded with the position-0 tracks of every node in
        ``env.toposort()`` and is kept up to date through the callbacks
        registered with ``self.attach_updater``.
        """
        tasks = defaultdict(list)

        if self.max_use_ratio is not None:
            max_uses = self.max_use_ratio * len(env.nodes)
            runs = defaultdict(int)
        else:
            # max_uses is deliberately left undefined: it is only read when
            # runs is not None.
            runs = None

        def importer(node):
            #print 'IMPORTING', node
            self.backtrack(node, tasks)

        def pruner(node):
            # Forget pending work for `node`; it may have none.
            tasks.pop(node, None)

        def chin(node, i, r, new_r):
            if new_r.owner and not r.clients:
                self.backtrack(new_r.owner, tasks)

#         # == NOT IDEAL == #
#         for node in env.nodes:
#             importer(node)

        for node in env.toposort():
            tasks[node].extend(lopt for track, i, lopt in self.fetch_tracks0(node.op))

        u = self.attach_updater(env, importer, pruner, chin)
        while tasks:
            # Take any one pending node; the loop is left immediately after
            # the pop, so mutating the dict while iterating is safe.
            for node in tasks.iterkeys():
                todo = tasks.pop(node)
                break
            for lopt in todo:
                if runs is not None and runs[lopt] >= max_uses:
                    print >>sys.stderr, 'Warning: optimization exceeded its maximal use ratio: %s, %s' % (lopt, max_uses)
                    continue
                success = self.process_node(env, node, lopt)
                if success:
                    if runs is not None: runs[lopt] += 1
                    # The remaining optimizers queued for this node are dropped.
                    break
        self.detach_updater(env, u)
# def match(self, node, candidates):
# candidates[:] = [candidate
# for candidate in candidates
# if candidate.current.op is None or candidate.current.op == node.op]
# for candidate in candidates:
# if candidate.current.inputs is not None:
# for in1, in2 in zip(candidate.current.inputs, node.inputs):
# if isinstance(in1, str):
# candidate.match[in1] = in2
# for client in node.clients:
# op = node.op
# patterns = self.pattern_base[(depth, op)].union(self.pattern_base[(depth, WILDCARD)])
# if not patterns:
# return patterns
# return self.match(node, depth + 1).intersection(patterns)
# def backtrack(self, node, q):
# for node2, i in node.clients:
# op2 = node2.op
def keep_going(exc, nav, repl_pairs): def keep_going(exc, nav, repl_pairs):
"""WRITEME""" """WRITEME"""
pass pass
...@@ -1002,5 +860,3 @@ class PureThenInplaceOptimizer(Optimizer): ...@@ -1002,5 +860,3 @@ class PureThenInplaceOptimizer(Optimizer):
if 0:
    # Disabled experimental code, kept for reference only.
    class _EquilibriumOptimizer(NavigatorOptimizer):
        """Experimental, track-driven variant of EquilibriumOptimizer.

        Each local optimizer advertises "tracks" through ``lopt.tracks()``.
        A track is a sequence of ops in which ``None`` matches any op.
        Tracks are indexed by op so that a local optimizer is only queued for
        a node when one of its tracks matches the ops found at and downstream
        of that node.
        """

        def __init__(self,
                     local_optimizers,
                     failure_callback = None,
                     max_depth = None,
                     max_use_ratio = None):
            """Index the tracks of every local optimizer.

            :param local_optimizers: iterable of objects exposing ``tracks()``.
            :param failure_callback: forwarded to NavigatorOptimizer.__init__.
            :param max_depth: maximal track length allowed, or None for no
                limit.
            :param max_use_ratio: if not None, each local optimizer may
                succeed at most ``max_use_ratio * len(env.nodes)`` times per
                `apply` call.

            :raises ValueError: if a track is longer than `max_depth`.
            """
            # The first argument must be this very class: `self` is not an
            # instance of the sibling EquilibriumOptimizer, so naming that
            # class here makes super() raise TypeError.
            super(_EquilibriumOptimizer, self).__init__(
                None,
                ignore_newtrees = False,
                failure_callback = failure_callback)

            self.local_optimizers = local_optimizers
            self.max_depth = max_depth
            self.max_use_ratio = max_use_ratio

            # op -> [(track, position of op in track, local optimizer)]
            self.tracks = defaultdict(list)
            # same as self.tracks, restricted to position 0
            self.tracks0 = defaultdict(list)

            deepest = 0
            for lopt in local_optimizers:
                for track in lopt.tracks():
                    deepest = max(deepest, len(track))
                    if self.max_depth is not None and deepest > self.max_depth:
                        raise ValueError('One of the local optimizers exceeds the maximal depth.')
                    for i, op in enumerate(track):
                        entry = (track, i, lopt)
                        if i == 0:
                            self.tracks0[op].append(entry)
                        self.tracks[op].append(entry)

        def fetch_tracks(self, op):
            """Return the track entries registered for `op` plus the wildcard (None) ones."""
            return self.tracks[op] + self.tracks[None]

        def fetch_tracks0(self, op):
            """Return the position-0 track entries for `op` plus the wildcard (None) ones."""
            return self.tracks0[op] + self.tracks0[None]

        def backtrack(self, node, tasks):
            """Queue in `tasks` the local optimizers whose track matches around `node`.

            Starting from the entries registered for ``node.op``, each track
            is compared one position at a time (towards its index 0) against
            the ops of the client nodes reached from `node`.  When index 0 of
            a track matches the op of some node, its optimizer is appended to
            ``tasks[that_node]``.

            :param node: the node whose neighbourhood is examined.
            :param tasks: mapping node -> list of local optimizers; updated
                in place.
            """
            candidates = self.fetch_tracks(node.op)
            tracks = []

            def advance(node, depth):
                # Reads the *current* bindings of `candidates` and `tracks`
                # from the enclosing scope.  Returns the candidates that did
                # not match `node` at this depth; matched ones either fire
                # (index 0 was reached) or are parked in `tracks` for the
                # next depth.
                unmatched = []
                for candidate in candidates:
                    track, i, lopt = candidate
                    if i < depth:
                        pass
                    elif track[i-depth] in (None, node.op):
                        if i == depth:
                            tasks[node].append(lopt)
                        else:
                            tracks.append(candidate)
                    else:
                        unmatched.append(candidate)
                return unmatched

            depth = 0
            nodes = [node]
            while candidates:
                for node in nodes:
                    candidates = advance(node, depth)
                depth += 1
                # Move one level down: gather every client of every output of
                # the current nodes.  Clients that are strings are skipped
                # (presumably markers for graph outputs -- TODO confirm).
                next_nodes = []
                for current in nodes:
                    for out in current.outputs:
                        for client, idx in out.clients:
                            if not isinstance(client, str):
                                next_nodes.append(client)
                nodes = next_nodes
                candidates = tracks
                tracks = []

        def apply(self, env):
            """Run the queued local optimizers on `env` until no task is left.

            The queue is seeded with the position-0 tracks of every node in
            ``env.toposort()`` and is kept up to date through the callbacks
            registered with ``self.attach_updater``.
            """
            tasks = defaultdict(list)

            if self.max_use_ratio is not None:
                max_uses = self.max_use_ratio * len(env.nodes)
                runs = defaultdict(int)
            else:
                # max_uses is deliberately left undefined: it is only read
                # when runs is not None.
                runs = None

            def importer(node):
                #print 'IMPORTING', node
                self.backtrack(node, tasks)

            def pruner(node):
                # Forget pending work for `node`; it may have none.
                tasks.pop(node, None)

            def chin(node, i, r, new_r):
                if new_r.owner and not r.clients:
                    self.backtrack(new_r.owner, tasks)

#             # == NOT IDEAL == #
#             for node in env.nodes:
#                 importer(node)

            for node in env.toposort():
                tasks[node].extend(lopt for track, i, lopt in self.fetch_tracks0(node.op))

            u = self.attach_updater(env, importer, pruner, chin)
            while tasks:
                # Take any one pending node; the loop is left immediately
                # after the pop, so mutating the dict while iterating is safe.
                for node in tasks.iterkeys():
                    todo = tasks.pop(node)
                    break
                for lopt in todo:
                    if runs is not None and runs[lopt] >= max_uses:
                        print >>sys.stderr, 'Warning: optimization exceeded its maximal use ratio: %s, %s' % (lopt, max_uses)
                        continue
                    success = self.process_node(env, node, lopt)
                    if success:
                        if runs is not None: runs[lopt] += 1
                        # The remaining optimizers queued for this node are
                        # dropped.
                        break
            self.detach_updater(env, u)
# def match(self, node, candidates):
# candidates[:] = [candidate
# for candidate in candidates
# if candidate.current.op is None or candidate.current.op == node.op]
# for candidate in candidates:
# if candidate.current.inputs is not None:
# for in1, in2 in zip(candidate.current.inputs, node.inputs):
# if isinstance(in1, str):
# candidate.match[in1] = in2
# for client in node.clients:
# op = node.op
# patterns = self.pattern_base[(depth, op)].union(self.pattern_base[(depth, WILDCARD)])
# if not patterns:
# return patterns
# return self.match(node, depth + 1).intersection(patterns)
# def backtrack(self, node, q):
# for node2, i in node.clients:
# op2 = node2.op
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
from basic import * from basic import *
import opt import opt
import blas
import raw_random import raw_random
from raw_random import \ from raw_random import \
......
...@@ -13,7 +13,6 @@ from copy import copy ...@@ -13,7 +13,6 @@ from copy import copy
from .. import gof from .. import gof
from ..gof import Result, Op, utils, AbstractFunctionError, Type, Constant, Apply, Value from ..gof import Result, Op, utils, AbstractFunctionError, Type, Constant, Apply, Value
import blas # for gemm, dot
from .. import gradient from .. import gradient
import elemwise import elemwise
...@@ -399,6 +398,8 @@ scalars, fscalars, dscalars, iscalars, lscalars = _multi(scalar, fscalar, dscala ...@@ -399,6 +398,8 @@ scalars, fscalars, dscalars, iscalars, lscalars = _multi(scalar, fscalar, dscala
int_types = bscalar, wscalar, iscalar, lscalar int_types = bscalar, wscalar, iscalar, lscalar
float_types = fscalar, dscalar float_types = fscalar, dscalar
int_scalar_types = int_types
float_scalar_types = float_types
fvector = Tensor('float32', (False, )) fvector = Tensor('float32', (False, ))
dvector = Tensor('float64', (False, )) dvector = Tensor('float64', (False, ))
...@@ -1700,6 +1701,10 @@ pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, MakeVector), ...@@ -1700,6 +1701,10 @@ pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, MakeVector),
######################### #########################
# Linalg : Dot # Linalg : Dot
######################### #########################
#
# For BLAS-related ops see blas.py
#
# TODO: Dotinv should go here, Eigs, Svd, etc.
class Dot(Op): class Dot(Op):
"""Compute matrix-matrix, matrix-vector products and vector inner-products. """Compute matrix-matrix, matrix-vector products and vector inner-products.
...@@ -1793,249 +1798,6 @@ class Outer(Op): ...@@ -1793,249 +1798,6 @@ class Outer(Op):
return "outer" return "outer"
outer = Outer() outer = Outer()
class Gemm(Op):
    """In-place version of matrix-matrix multiplication (with accumulation):

    When a and b are scalars and x, y, and z are matrices, then

        gemm(z,a,x,y,b)

    is similar to

        b*z + a*dot(x,y)

    The difference between the two is that the top form is destructive on z,
    whereas the bottom form is not.  Gemm works in-place on the storage
    associated with z, and the L{Result} returned by Gemm has a storage that
    will be aliased to the storage of the z argument. Because of this in-place
    computation, an L{Apply} of this op will destroy the L{Result} z on
    which it operates.  (See L{DestructiveOps} for an explanation of what
    destroying means in the context of theano graphs. See L{BlasLapackSupport} for
    more optimized linear algebra operations.)
    """
    E_rank = 'gemm only works for rank 2'
    E_scalar = 'gemm requires scalar argument'
    E_z_uniq = 'argument z aliased to x or y'
    # output 0 overwrites input 0 (z)
    destroy_map = {0: [0]}

    def make_node(self, *inputs):
        """Build the Apply node for ``gemm(z, a, x, y, b)``.

        :raises TypeError: if the number of inputs is not 5.
        :raises ValueError: if z shares a view root with x or y, if z, x or y
            is not of rank 2, or if a or b is not a scalar.
        """
        inputs = [as_tensor(i) for i in inputs]
        if len(inputs) != 5:
            raise TypeError("Wrong number of inputs for %s (expected 5, got %s)" % (self, len(inputs)))
        z, a, x, y, b = inputs
        # z is overwritten, so it must not alias either operand of the product
        zr, xr, yr = [set(gof.view_roots(i)) for i in (z, x, y)]
        if zr.intersection(xr):
            raise ValueError(Gemm.E_z_uniq, (z, x))
        if zr.intersection(yr):
            raise ValueError(Gemm.E_z_uniq, (z, y))
        bz, ba, bx, by, bb = [r.type.broadcastable for r in inputs]
        if len(bz) != 2: raise ValueError(Gemm.E_rank, len(bz))
        if len(bx) != 2: raise ValueError(Gemm.E_rank, len(bx))
        if len(by) != 2: raise ValueError(Gemm.E_rank, len(by))
        if len(ba): raise ValueError(Gemm.E_scalar, ba)
        if len(bb): raise ValueError(Gemm.E_scalar, bb)
        output = z.type()
        return Apply(self, inputs, [output])

    def perform(self, node, inp, out):
        """NumPy implementation: overwrite z with ``b*z + a*dot(x,y)``.

        :param inp: the five input values ``(z, a, x, y, b)``.
        :param out: one-element sequence of output storage cells; cell 0 is
            set to the (modified) z.
        """
        z, a, x, y, b = inp
        zout, = out
        assert a.shape == ()
        assert b.shape == ()
        if z.shape == ():
            # b scales z and a scales the product, exactly as in the array
            # branches below and in the class docstring.
            z.itemset(b*z + a*numpy.dot(x,y))
        elif b == 0.0:
            # z is overwritten without being read
            if a == 1.0:
                z[:] = numpy.dot(x,y)
            elif a == -1.0:
                z[:] = -numpy.dot(x,y)
            else:
                z[:] = a * numpy.dot(x,y)
        elif b == 1.0:
            if a == 1.0:
                z += numpy.dot(x,y)
            elif a == -1.0:
                z -= numpy.dot(x,y)
            else:
                z += a * numpy.dot(x,y)
        else:
            z *= b
            z += a * numpy.dot(x,y)
        zout[0] = z

    def grad(self, inp, grads):
        """Not implemented for this op."""
        raise NotImplementedError()

    def c_support_code(self):
        """Return the BLAS prototypes plus a ``MOD`` macro standing for C's ``%``.

        ``MOD`` lets the template in `c_code` avoid a literal percent sign,
        which would clash with Python string formatting.
        """
        #return blas.cblas_header_text()
        mod_str = """
        #ifndef MOD
        #define MOD %
        #endif
        """
        return blas.blas_proto() + mod_str

    def c_headers(self):
        """Headers required by the generated C code."""
        return ['<iostream>']

    def c_libraries(self):
        """Libraries to link against, as reported by ``blas.ldflags()``."""
        return blas.ldflags()

    def c_code(self, node, name, inp, out, sub):
        """Return the C implementation, which calls ``sgemm_`` / ``dgemm_``.

        The generated code checks ranks, dtypes (float or double only),
        shapes and strides, then encodes the layout of x, y and z in `unit`
        (one hex digit each: 0 when the stride of dimension 1 equals the
        element size, 1 when the stride of dimension 0 does, 2 otherwise) and
        picks the matching transpose flags and operand order.

        :param inp: C variable names of ``(z, a, x, y, b)``.
        :param out: one-element sequence holding the C name of the output.
        :param sub: extra template substitutions; must provide ``fail``.
        """
        # These names are looked up through locals() by the template below.
        _z, _a, _x, _y, _b = inp
        _zout, = out
        return """
        int unit = 0;

        int type_num = %(_x)s->descr->type_num;
        int type_size = %(_x)s->descr->elsize; // in bytes

        npy_intp* Nx = %(_x)s->dimensions;
        npy_intp* Ny = %(_y)s->dimensions;
        npy_intp* Nz = %(_z)s->dimensions;

        npy_intp* Sx = %(_x)s->strides;
        npy_intp* Sy = %(_y)s->strides;
        npy_intp* Sz = %(_z)s->strides;

        //strides for x, y, z in dimensions 0, 1
        int sx_0, sx_1, sy_0, sy_1, sz_0, sz_1;

        if (%(_zout)s != %(_z)s)
        {
            if (%(_zout)s)
            {
                Py_DECREF(%(_zout)s);
            }
            %(_zout)s = %(_z)s;
            Py_INCREF(%(_zout)s);
        }

        if (%(_x)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(x) != 2"); %(fail)s;}
        if (%(_y)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(y) != 2"); %(fail)s;}
        if (%(_z)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(z) != 2"); %(fail)s;}

        if ((%(_a)s->descr->type_num != PyArray_DOUBLE)
            && (%(_a)s->descr->type_num != PyArray_FLOAT))
        {PyErr_SetString(PyExc_NotImplementedError, "type(a) is not double or float"); %(fail)s;}

        if ((%(_b)s->descr->type_num != PyArray_DOUBLE)
            && (%(_b)s->descr->type_num != PyArray_FLOAT))
        {PyErr_SetString(PyExc_NotImplementedError, "type(b) is not double or float"); %(fail)s;}

        if ((%(_x)s->descr->type_num != PyArray_DOUBLE)
            && (%(_x)s->descr->type_num != PyArray_FLOAT))
        {PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float"); %(fail)s;}

        if ((%(_y)s->descr->type_num != PyArray_DOUBLE)
            && (%(_y)s->descr->type_num != PyArray_FLOAT))
        {PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float"); %(fail)s;}

        if ((%(_z)s->descr->type_num != PyArray_DOUBLE)
            && (%(_z)s->descr->type_num != PyArray_FLOAT))
        {PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float"); %(fail)s;}

        if ((%(_x)s->descr->type_num != %(_y)s->descr->type_num)
            ||(%(_x)s->descr->type_num != %(_z)s->descr->type_num))
        { PyErr_SetString(PyExc_NotImplementedError, "type(x), type(y), type(z) are not all the same"); %(fail)s; }

        if ((Nx[0] != Nz[0]) || (Nx[1] != Ny[0]) || (Ny[1] != Nz[1]))
        {
            PyErr_SetString(PyExc_ValueError, "Input dimensions do not agree");
            %(fail)s;
        }
        if ((Sx[0] < 1) || (Sx[1] < 1) || (Sx[0] MOD type_size) || (Sx[1] MOD type_size)
           || (Sy[0] < 1) || (Sy[1] < 1) || (Sy[0] MOD type_size) || (Sy[1] MOD type_size)
           || (Sz[0] < 1) || (Sz[1] < 1) || (Sz[0] MOD type_size) || (Sz[1] MOD type_size))
        {
            PyErr_SetString(PyExc_ValueError, "stride is not multiple of element size"); %(fail)s;
        }

        /*
        encode the stride structure of _x,_y,_z into a single integer
        */
        unit |= ((Sx[1] == type_size) ? 0x0 : (Sx[0] == type_size) ? 0x1 : 0x2) << 8;
        unit |= ((Sy[1] == type_size) ? 0x0 : (Sy[0] == type_size) ? 0x1 : 0x2) << 4;
        unit |= ((Sz[1] == type_size) ? 0x0 : (Sz[0] == type_size) ? 0x1 : 0x2) << 0;

        /* create appropriate strides for malformed matrices that are row or column
         * vectors
         */
        sx_0 = (Nx[0] > 1) ? Sx[0]/type_size : Nx[1];
        sx_1 = (Nx[1] > 1) ? Sx[1]/type_size : Nx[0];
        sy_0 = (Ny[0] > 1) ? Sy[0]/type_size : Ny[1];
        sy_1 = (Ny[1] > 1) ? Sy[1]/type_size : Ny[0];
        sz_0 = (Nz[0] > 1) ? Sz[0]/type_size : Nz[1];
        sz_1 = (Nz[1] > 1) ? Sz[1]/type_size : Nz[0];

        switch (type_num)
        {
            case PyArray_FLOAT:
            {
                #define REAL float
                float a = (%(_a)s->descr->type_num == PyArray_FLOAT)
                ? (REAL)(((float*)%(_a)s->data)[0])
                : (REAL)(((double*)%(_a)s->data)[0]);
                float b = (%(_b)s->descr->type_num == PyArray_FLOAT) ?
                (REAL)(((float*)%(_b)s->data)[0])
                : (REAL)(((double*)%(_b)s->data)[0]);

                float* x = (float*)PyArray_DATA(%(_x)s);
                float* y = (float*)PyArray_DATA(%(_y)s);
                float* z = (float*)PyArray_DATA(%(_z)s);
                char N = 'N';
                char T = 'T';
                int Nz0 = Nz[0], Nz1 = Nz[1], Nx1 = Nx[1];
                //std::cerr << (unit/256) MOD 16 << (unit / 16) MOD 16 << unit MOD 16<< '\\n';
                switch(unit)
                {
                    case 0x000: sgemm_(&N, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_0, &b, z, &sz_0); break;
                    case 0x100: sgemm_(&N, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_1, &b, z, &sz_0); break;
                    case 0x010: sgemm_(&T, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_0, &b, z, &sz_0); break;
                    case 0x110: sgemm_(&T, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_1, &b, z, &sz_0); break;
                    case 0x001: sgemm_(&T, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_0, &b, z, &sz_1); break;
                    case 0x101: sgemm_(&N, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_0, &b, z, &sz_1); break;
                    case 0x011: sgemm_(&T, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_1, &b, z, &sz_1); break;
                    case 0x111: sgemm_(&N, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_1, &b, z, &sz_1); break;
                    default: PyErr_SetString(PyExc_ValueError, "some matrix has no unit stride"); %(fail)s;
                };
                #undef REAL
            }
            break;
            case PyArray_DOUBLE:
            {
                #define REAL double

                double a = (%(_a)s->descr->type_num == PyArray_FLOAT)
                ? (REAL)(((float*)%(_a)s->data)[0])
                : (REAL)(((double*)%(_a)s->data)[0]);
                double b = (%(_b)s->descr->type_num == PyArray_FLOAT) ?
                (REAL)(((float*)%(_b)s->data)[0])
                : (REAL)(((double*)%(_b)s->data)[0]);
                double* x = (double*)PyArray_DATA(%(_x)s);
                double* y = (double*)PyArray_DATA(%(_y)s);
                double* z = (double*)PyArray_DATA(%(_z)s);
                char N = 'N';
                char T = 'T';
                int Nz0 = Nz[0], Nz1 = Nz[1], Nx1 = Nx[1];
                //std::cerr << (unit/256) MOD 16 << (unit / 16) MOD 16 << unit MOD 16<< '\\n';
                switch(unit)
                {
                    case 0x000: dgemm_(&N, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_0, &b, z, &sz_0); break;
                    case 0x100: dgemm_(&N, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_1, &b, z, &sz_0); break;
                    case 0x010: dgemm_(&T, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_0, &b, z, &sz_0); break;
                    case 0x110: dgemm_(&T, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_1, &b, z, &sz_0); break;
                    case 0x001: dgemm_(&T, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_0, &b, z, &sz_1); break;
                    case 0x101: dgemm_(&N, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_0, &b, z, &sz_1); break;
                    case 0x011: dgemm_(&T, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_1, &b, z, &sz_1); break;
                    case 0x111: dgemm_(&N, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_1, &b, z, &sz_1); break;
                    default: PyErr_SetString(PyExc_ValueError, "some matrix has no unit stride"); %(fail)s;
                };
                #undef REAL
            }
            break;
        }

        """ % dict(locals(), **sub)
gemm = Gemm()
pprint.assign(gemm, printing.FunctionPrinter('gemm'))
######################### #########################
# Gradient # Gradient
######################### #########################
......
"""Ops and optimizations for using BLAS function calls to evaluate linear algebra expressions"""
import os, sys import os, sys
import scipy.weave as weave import numpy
from ..gof import utils
from ..gof import (utils, Op, Apply, view_roots, PatternSub,
InplaceOptimizer, SeqOptimizer, warn, local_optimizer)
""" from ..printing import pprint, FunctionPrinter
File: omega/blas.py from .opt import register_specialize, out2in, insert_inplace_optimizer
This file is in omega's core because it consists mostly of optimizations of the import basic as T
graphs that can be constructed from omega/core.py. The optimizations provided from ..tensor import as_tensor
by this file are aimed at the goal of inserting gemm Ops in place of more
fine-grained motifs of iadd, isub, scale, and dot. #NB: this clobbers the builtin 'compile' symbol
""" from .. import compile #to register the optimizer built by this file
def cblas_header_text(): from .blas_headers import cblas_header_text, blas_header_text
"""C header for the cblas interface"""
return """
//#include <stddef.h>
#undef __BEGIN_DECLS
#undef __END_DECLS
#ifdef __cplusplus
#define __BEGIN_DECLS extern "C" {
#define __END_DECLS }
#else
#define __BEGIN_DECLS /* empty */
#define __END_DECLS /* empty */
#endif
__BEGIN_DECLS
#define MOD %
/*
* Enumerated and derived types
*/
#define CBLAS_INDEX size_t /* this may vary between platforms */
enum CBLAS_ORDER {CblasRowMajor=101, CblasColMajor=102};
enum CBLAS_TRANSPOSE {CblasNoTrans=111, CblasTrans=112, CblasConjTrans=113};
enum CBLAS_UPLO {CblasUpper=121, CblasLower=122};
enum CBLAS_DIAG {CblasNonUnit=131, CblasUnit=132};
enum CBLAS_SIDE {CblasLeft=141, CblasRight=142};
float cblas_sdsdot(const int N, const float alpha, const float *X,
const int incX, const float *Y, const int incY);
double cblas_dsdot(const int N, const float *X, const int incX, const float *Y,
const int incY);
float cblas_sdot(const int N, const float *X, const int incX,
const float *Y, const int incY);
double cblas_ddot(const int N, const double *X, const int incX,
const double *Y, const int incY);
/*
* Functions having prefixes Z and C only
*/
void cblas_cdotu_sub(const int N, const void *X, const int incX,
const void *Y, const int incY, void *dotu);
void cblas_cdotc_sub(const int N, const void *X, const int incX,
const void *Y, const int incY, void *dotc);
void cblas_zdotu_sub(const int N, const void *X, const int incX,
const void *Y, const int incY, void *dotu);
void cblas_zdotc_sub(const int N, const void *X, const int incX,
const void *Y, const int incY, void *dotc);
/*
* Functions having prefixes S D SC DZ
*/
float cblas_snrm2(const int N, const float *X, const int incX);
float cblas_sasum(const int N, const float *X, const int incX);
double cblas_dnrm2(const int N, const double *X, const int incX);
double cblas_dasum(const int N, const double *X, const int incX);
float cblas_scnrm2(const int N, const void *X, const int incX);
float cblas_scasum(const int N, const void *X, const int incX);
double cblas_dznrm2(const int N, const void *X, const int incX);
double cblas_dzasum(const int N, const void *X, const int incX);
/*
* Functions having standard 4 prefixes (S D C Z)
*/
CBLAS_INDEX cblas_isamax(const int N, const float *X, const int incX);
CBLAS_INDEX cblas_idamax(const int N, const double *X, const int incX);
CBLAS_INDEX cblas_icamax(const int N, const void *X, const int incX);
CBLAS_INDEX cblas_izamax(const int N, const void *X, const int incX);
/*
* ===========================================================================
* Prototypes for level 1 BLAS routines
* ===========================================================================
*/
/*
* Routines with standard 4 prefixes (s, d, c, z)
*/
void cblas_sswap(const int N, float *X, const int incX,
float *Y, const int incY);
void cblas_scopy(const int N, const float *X, const int incX,
float *Y, const int incY);
void cblas_saxpy(const int N, const float alpha, const float *X,
const int incX, float *Y, const int incY);
void cblas_dswap(const int N, double *X, const int incX,
double *Y, const int incY);
void cblas_dcopy(const int N, const double *X, const int incX,
double *Y, const int incY);
void cblas_daxpy(const int N, const double alpha, const double *X,
const int incX, double *Y, const int incY);
void cblas_cswap(const int N, void *X, const int incX,
void *Y, const int incY);
void cblas_ccopy(const int N, const void *X, const int incX,
void *Y, const int incY);
void cblas_caxpy(const int N, const void *alpha, const void *X,
const int incX, void *Y, const int incY);
void cblas_zswap(const int N, void *X, const int incX,
void *Y, const int incY);
void cblas_zcopy(const int N, const void *X, const int incX,
void *Y, const int incY);
void cblas_zaxpy(const int N, const void *alpha, const void *X,
const int incX, void *Y, const int incY);
/*
* Routines with S and D prefix only
*/
void cblas_srotg(float *a, float *b, float *c, float *s);
void cblas_srotmg(float *d1, float *d2, float *b1, const float b2, float *P);
void cblas_srot(const int N, float *X, const int incX,
float *Y, const int incY, const float c, const float s);
void cblas_srotm(const int N, float *X, const int incX,
float *Y, const int incY, const float *P);
void cblas_drotg(double *a, double *b, double *c, double *s);
void cblas_drotmg(double *d1, double *d2, double *b1, const double b2, double *P);
void cblas_drot(const int N, double *X, const int incX,
double *Y, const int incY, const double c, const double s);
void cblas_drotm(const int N, double *X, const int incX,
double *Y, const int incY, const double *P);
/*
* Routines with S D C Z CS and ZD prefixes
*/
void cblas_sscal(const int N, const float alpha, float *X, const int incX);
void cblas_dscal(const int N, const double alpha, double *X, const int incX);
void cblas_cscal(const int N, const void *alpha, void *X, const int incX);
void cblas_zscal(const int N, const void *alpha, void *X, const int incX);
void cblas_csscal(const int N, const float alpha, void *X, const int incX);
void cblas_zdscal(const int N, const double alpha, void *X, const int incX);
/*
* ===========================================================================
* Prototypes for level 2 BLAS
* ===========================================================================
*/
/*
* Routines with standard 4 prefixes (S, D, C, Z)
*/
void cblas_sgemv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const float alpha, const float *A, const int lda,
const float *X, const int incX, const float beta,
float *Y, const int incY);
void cblas_sgbmv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const int KL, const int KU, const float alpha,
const float *A, const int lda, const float *X,
const int incX, const float beta, float *Y, const int incY);
void cblas_strmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const float *A, const int lda,
float *X, const int incX);
void cblas_stbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const float *A, const int lda,
float *X, const int incX);
void cblas_stpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const float *Ap, float *X, const int incX);
void cblas_strsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const float *A, const int lda, float *X,
const int incX);
void cblas_stbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const float *A, const int lda,
float *X, const int incX);
void cblas_stpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const float *Ap, float *X, const int incX);
void cblas_dgemv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const double alpha, const double *A, const int lda,
const double *X, const int incX, const double beta,
double *Y, const int incY);
void cblas_dgbmv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const int KL, const int KU, const double alpha,
const double *A, const int lda, const double *X,
const int incX, const double beta, double *Y, const int incY);
void cblas_dtrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const double *A, const int lda,
double *X, const int incX);
void cblas_dtbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const double *A, const int lda,
double *X, const int incX);
void cblas_dtpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const double *Ap, double *X, const int incX);
void cblas_dtrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const double *A, const int lda, double *X,
const int incX);
void cblas_dtbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const double *A, const int lda,
double *X, const int incX);
void cblas_dtpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const double *Ap, double *X, const int incX);
void cblas_cgemv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *X, const int incX, const void *beta,
void *Y, const int incY);
void cblas_cgbmv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const int KL, const int KU, const void *alpha,
const void *A, const int lda, const void *X,
const int incX, const void *beta, void *Y, const int incY);
void cblas_ctrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *A, const int lda,
void *X, const int incX);
void cblas_ctbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const void *A, const int lda,
void *X, const int incX);
void cblas_ctpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *Ap, void *X, const int incX);
void cblas_ctrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *A, const int lda, void *X,
const int incX);
void cblas_ctbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const void *A, const int lda,
void *X, const int incX);
void cblas_ctpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *Ap, void *X, const int incX);
void cblas_zgemv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *X, const int incX, const void *beta,
void *Y, const int incY);
void cblas_zgbmv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const int KL, const int KU, const void *alpha,
const void *A, const int lda, const void *X,
const int incX, const void *beta, void *Y, const int incY);
void cblas_ztrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *A, const int lda,
void *X, const int incX);
void cblas_ztbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const void *A, const int lda,
void *X, const int incX);
void cblas_ztpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *Ap, void *X, const int incX);
void cblas_ztrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *A, const int lda, void *X,
const int incX);
void cblas_ztbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const void *A, const int lda,
void *X, const int incX);
void cblas_ztpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *Ap, void *X, const int incX);
/*
* Routines with S and D prefixes only
*/
void cblas_ssymv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *A,
const int lda, const float *X, const int incX,
const float beta, float *Y, const int incY);
void cblas_ssbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const int K, const float alpha, const float *A,
const int lda, const float *X, const int incX,
const float beta, float *Y, const int incY);
void cblas_sspmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *Ap,
const float *X, const int incX,
const float beta, float *Y, const int incY);
void cblas_sger(const enum CBLAS_ORDER order, const int M, const int N,
const float alpha, const float *X, const int incX,
const float *Y, const int incY, float *A, const int lda);
void cblas_ssyr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *X,
const int incX, float *A, const int lda);
void cblas_sspr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *X,
const int incX, float *Ap);
void cblas_ssyr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *X,
const int incX, const float *Y, const int incY, float *A,
const int lda);
void cblas_sspr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *X,
const int incX, const float *Y, const int incY, float *A);
void cblas_dsymv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *A,
const int lda, const double *X, const int incX,
const double beta, double *Y, const int incY);
void cblas_dsbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const int K, const double alpha, const double *A,
const int lda, const double *X, const int incX,
const double beta, double *Y, const int incY);
void cblas_dspmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *Ap,
const double *X, const int incX,
const double beta, double *Y, const int incY);
void cblas_dger(const enum CBLAS_ORDER order, const int M, const int N,
const double alpha, const double *X, const int incX,
const double *Y, const int incY, double *A, const int lda);
void cblas_dsyr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *X,
const int incX, double *A, const int lda);
void cblas_dspr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *X,
const int incX, double *Ap);
void cblas_dsyr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *X,
const int incX, const double *Y, const int incY, double *A,
const int lda);
void cblas_dspr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *X,
const int incX, const double *Y, const int incY, double *A);
/*
* Routines with C and Z prefixes only
*/
void cblas_chemv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const void *alpha, const void *A,
const int lda, const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_chbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const int K, const void *alpha, const void *A,
const int lda, const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_chpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const void *alpha, const void *Ap,
const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_cgeru(const enum CBLAS_ORDER order, const int M, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_cgerc(const enum CBLAS_ORDER order, const int M, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_cher(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const void *X, const int incX,
void *A, const int lda);
void cblas_chpr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const void *X,
const int incX, void *A);
void cblas_cher2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_chpr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *Ap);
void cblas_zhemv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const void *alpha, const void *A,
const int lda, const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_zhbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const int K, const void *alpha, const void *A,
const int lda, const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_zhpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const void *alpha, const void *Ap,
const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_zgeru(const enum CBLAS_ORDER order, const int M, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_zgerc(const enum CBLAS_ORDER order, const int M, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_zher(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const void *X, const int incX,
void *A, const int lda);
void cblas_zhpr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const void *X,
const int incX, void *A);
void cblas_zher2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_zhpr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *Ap);
/*
* ===========================================================================
* Prototypes for level 3 BLAS
* ===========================================================================
*/
/*
* Routines with standard 4 prefixes (S, D, C, Z)
*/
void cblas_sgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const float alpha, const float *A,
const int lda, const float *B, const int ldb,
const float beta, float *C, const int ldc);
void cblas_ssymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const float alpha, const float *A, const int lda,
const float *B, const int ldb, const float beta,
float *C, const int ldc);
void cblas_ssyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const float alpha, const float *A, const int lda,
const float beta, float *C, const int ldc);
void cblas_ssyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const float alpha, const float *A, const int lda,
const float *B, const int ldb, const float beta,
float *C, const int ldc);
void cblas_strmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const float alpha, const float *A, const int lda,
float *B, const int ldb);
void cblas_strsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const float alpha, const float *A, const int lda,
float *B, const int ldb);
void cblas_dgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const double alpha, const double *A,
const int lda, const double *B, const int ldb,
const double beta, double *C, const int ldc);
void cblas_dsymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const double alpha, const double *A, const int lda,
const double *B, const int ldb, const double beta,
double *C, const int ldc);
void cblas_dsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const double alpha, const double *A, const int lda,
const double beta, double *C, const int ldc);
void cblas_dsyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const double alpha, const double *A, const int lda,
const double *B, const int ldb, const double beta,
double *C, const int ldc);
void cblas_dtrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const double alpha, const double *A, const int lda,
double *B, const int ldb);
void cblas_dtrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const double alpha, const double *A, const int lda,
double *B, const int ldb);
void cblas_cgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const void *alpha, const void *A,
const int lda, const void *B, const int ldb,
const void *beta, void *C, const int ldc);
void cblas_csymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_csyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *beta, void *C, const int ldc);
void cblas_csyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_ctrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const void *alpha, const void *A, const int lda,
void *B, const int ldb);
void cblas_ctrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const void *alpha, const void *A, const int lda,
void *B, const int ldb);
void cblas_zgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const void *alpha, const void *A,
const int lda, const void *B, const int ldb,
const void *beta, void *C, const int ldc);
void cblas_zsymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_zsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *beta, void *C, const int ldc);
void cblas_zsyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_ztrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const void *alpha, const void *A, const int lda,
void *B, const int ldb);
void cblas_ztrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const void *alpha, const void *A, const int lda,
void *B, const int ldb);
/*
* Routines with prefixes C and Z only
*/
void cblas_chemm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_cherk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const float alpha, const void *A, const int lda,
const float beta, void *C, const int ldc);
void cblas_cher2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const float beta,
void *C, const int ldc);
void cblas_zhemm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_zherk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const double alpha, const void *A, const int lda,
const double beta, void *C, const int ldc);
void cblas_zher2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const double beta,
void *C, const int ldc);
void cblas_xerbla(int p, const char *rout, const char *form, ...);
__END_DECLS
"""
def blas_proto():
    """Return the C prototypes for the Fortran BLAS interface.

    The returned text is a single ``extern "C" { ... }`` block declaring the
    trailing-underscore BLAS symbols (``sgemm_``, ``dgemv_``, ...), grouped
    into Level 1, Level 2 and Level 3 sections, each split by precision
    (single, double, single complex, double complex).

    The text is a constant: it takes no parameters and performs no string
    substitution, so it can be concatenated into generated C/C++ support code
    as-is.

    NOTE: a few prototypes (``dswap_``, ``cswap_``, ``zswap_``) are declared
    twice. The declarations are identical, and the text is intentionally kept
    verbatim rather than de-duplicated.

    :returns: the header text as a ``str`` beginning with a newline.
    """
    return """
extern "C"
{
void xerbla_(char*, void *);
/***********/
/* Level 1 */
/***********/
/* Single Precision */
void srot_(const int*, float *, const int*, float *, const int*, const float *, const float *);
void srotg_(float *,float *,float *,float *);
void srotm_( const int*, float *, const int*, float *, const int*, const float *);
void srotmg_(float *,float *,float *,const float *, float *);
void sswap_( const int*, float *, const int*, float *, const int*);
void scopy_( const int*, const float *, const int*, float *, const int*);
void saxpy_( const int*, const float *, const float *, const int*, float *, const int*);
void sdot_sub_(const int*, const float *, const int*, const float *, const int*, float *);
void sdsdot_sub_( const int*, const float *, const float *, const int*, const float *, const int*, float *);
void sscal_( const int*, const float *, float *, const int*);
void snrm2_sub_( const int*, const float *, const int*, float *);
void sasum_sub_( const int*, const float *, const int*, float *);
void isamax_sub_( const int*, const float * , const int*, const int*);
/* Double Precision */
void drot_(const int*, double *, const int*, double *, const int*, const double *, const double *);
void drotg_(double *,double *,double *,double *);
void drotm_( const int*, double *, const int*, double *, const int*, const double *);
void drotmg_(double *,double *,double *,const double *, double *);
void dswap_( const int*, double *, const int*, double *, const int*);
void dcopy_( const int*, const double *, const int*, double *, const int*);
void daxpy_( const int*, const double *, const double *, const int*, double *, const int*);
void dswap_( const int*, double *, const int*, double *, const int*);
void dsdot_sub_(const int*, const float *, const int*, const float *, const int*, double *);
void ddot_sub_( const int*, const double *, const int*, const double *, const int*, double *);
void dscal_( const int*, const double *, double *, const int*);
void dnrm2_sub_( const int*, const double *, const int*, double *);
void dasum_sub_( const int*, const double *, const int*, double *);
void idamax_sub_( const int*, const double * , const int*, const int*);
/* Single Complex Precision */
void cswap_( const int*, void *, const int*, void *, const int*);
void ccopy_( const int*, const void *, const int*, void *, const int*);
void caxpy_( const int*, const void *, const void *, const int*, void *, const int*);
void cswap_( const int*, void *, const int*, void *, const int*);
void cdotc_sub_( const int*, const void *, const int*, const void *, const int*, void *);
void cdotu_sub_( const int*, const void *, const int*, const void *, const int*, void *);
void cscal_( const int*, const void *, void *, const int*);
void icamax_sub_( const int*, const void *, const int*, const int*);
void csscal_( const int*, const float *, void *, const int*);
void scnrm2_sub_( const int*, const void *, const int*, float *);
void scasum_sub_( const int*, const void *, const int*, float *);
/* Double Complex Precision */
void zswap_( const int*, void *, const int*, void *, const int*);
void zcopy_( const int*, const void *, const int*, void *, const int*);
void zaxpy_( const int*, const void *, const void *, const int*, void *, const int*);
void zswap_( const int*, void *, const int*, void *, const int*);
void zdotc_sub_( const int*, const void *, const int*, const void *, const int*, void *);
void zdotu_sub_( const int*, const void *, const int*, const void *, const int*, void *);
void zdscal_( const int*, const double *, void *, const int*);
void zscal_( const int*, const void *, void *, const int*);
void dznrm2_sub_( const int*, const void *, const int*, double *);
void dzasum_sub_( const int*, const void *, const int*, double *);
void izamax_sub_( const int*, const void *, const int*, const int*);
/***********/
/* Level 2 */
/***********/
/* Single Precision */
void sgemv_(char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void sgbmv_(char*, const int*, const int*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ssymv_(char*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ssbmv_(char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void sspmv_(char*, const int*, const float *, const float *, const float *, const int*, const float *, float *, const int*);
void strmv_( char*, char*, char*, const int*, const float *, const int*, float *, const int*);
void stbmv_( char*, char*, char*, const int*, const int*, const float *, const int*, float *, const int*);
void strsv_( char*, char*, char*, const int*, const float *, const int*, float *, const int*);
void stbsv_( char*, char*, char*, const int*, const int*, const float *, const int*, float *, const int*);
void stpmv_( char*, char*, char*, const int*, const float *, float *, const int*);
void stpsv_( char*, char*, char*, const int*, const float *, float *, const int*);
void sger_( const int*, const int*, const float *, const float *, const int*, const float *, const int*, float *, const int*);
void ssyr_(char*, const int*, const float *, const float *, const int*, float *, const int*);
void sspr_(char*, const int*, const float *, const float *, const int*, float *);
void sspr2_(char*, const int*, const float *, const float *, const int*, const float *, const int*, float *);
void ssyr2_(char*, const int*, const float *, const float *, const int*, const float *, const int*, float *, const int*);
/* Double Precision */
void dgemv_(char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dgbmv_(char*, const int*, const int*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dsymv_(char*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dsbmv_(char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dspmv_(char*, const int*, const double *, const double *, const double *, const int*, const double *, double *, const int*);
void dtrmv_( char*, char*, char*, const int*, const double *, const int*, double *, const int*);
void dtbmv_( char*, char*, char*, const int*, const int*, const double *, const int*, double *, const int*);
void dtrsv_( char*, char*, char*, const int*, const double *, const int*, double *, const int*);
void dtbsv_( char*, char*, char*, const int*, const int*, const double *, const int*, double *, const int*);
void dtpmv_( char*, char*, char*, const int*, const double *, double *, const int*);
void dtpsv_( char*, char*, char*, const int*, const double *, double *, const int*);
void dger_( const int*, const int*, const double *, const double *, const int*, const double *, const int*, double *, const int*);
void dsyr_(char*, const int*, const double *, const double *, const int*, double *, const int*);
void dspr_(char*, const int*, const double *, const double *, const int*, double *);
void dspr2_(char*, const int*, const double *, const double *, const int*, const double *, const int*, double *);
void dsyr2_(char*, const int*, const double *, const double *, const int*, const double *, const int*, double *, const int*);
/* Single Complex Precision */
void cgemv_(char*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void cgbmv_(char*, const int*, const int*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void chemv_(char*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void chbmv_(char*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void chpmv_(char*, const int*, const void *, const void *, const void *, const int*, const void *, void *, const int*);
void ctrmv_( char*, char*, char*, const int*, const void *, const int*, void *, const int*);
void ctbmv_( char*, char*, char*, const int*, const int*, const void *, const int*, void *, const int*);
void ctpmv_( char*, char*, char*, const int*, const void *, void *, const int*);
void ctrsv_( char*, char*, char*, const int*, const void *, const int*, void *, const int*);
void ctbsv_( char*, char*, char*, const int*, const int*, const void *, const int*, void *, const int*);
void ctpsv_( char*, char*, char*, const int*, const void *, void *,const int*);
void cgerc_( const int*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*);
void cgeru_( const int*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*);
void cher_(char*, const int*, const float *, const void *, const int*, void *, const int*);
void cher2_(char*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*);
void chpr_(char*, const int*, const float *, const void *, const int*, void *);
void chpr2_(char*, const int*, const float *, const void *, const int*, const void *, const int*, void *);
/* Double Complex Precision */
void zgemv_(char*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void zgbmv_(char*, const int*, const int*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void zhemv_(char*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void zhbmv_(char*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void zhpmv_(char*, const int*, const void *, const void *, const void *, const int*, const void *, void *, const int*);
void ztrmv_( char*, char*, char*, const int*, const void *, const int*, void *, const int*);
void ztbmv_( char*, char*, char*, const int*, const int*, const void *, const int*, void *, const int*);
void ztpmv_( char*, char*, char*, const int*, const void *, void *, const int*);
void ztrsv_( char*, char*, char*, const int*, const void *, const int*, void *, const int*);
void ztbsv_( char*, char*, char*, const int*, const int*, const void *, const int*, void *, const int*);
void ztpsv_( char*, char*, char*, const int*, const void *, void *,const int*);
void zgerc_( const int*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*);
void zgeru_( const int*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*);
void zher_(char*, const int*, const double *, const void *, const int*, void *, const int*);
void zher2_(char*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*);
void zhpr_(char*, const int*, const double *, const void *, const int*, void *);
void zhpr2_(char*, const int*, const double *, const void *, const int*, const void *, const int*, void *);
/***********/
/* Level 3 */
/***********/
/* Single Precision */
void sgemm_(char*, char*, const int*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ssymm_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ssyrk_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, float *, const int*);
void ssyr2k_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void strmm_(char*, char*, char*, char*, const int*, const int*, const float *, const float *, const int*, float *, const int*);
void strsm_(char*, char*, char*, char*, const int*, const int*, const float *, const float *, const int*, float *, const int*);
/* Double Precision */
void dgemm_(char*, char*, const int*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dsymm_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dsyrk_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, double *, const int*);
void dsyr2k_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dtrmm_(char*, char*, char*, char*, const int*, const int*, const double *, const double *, const int*, double *, const int*);
void dtrsm_(char*, char*, char*, char*, const int*, const int*, const double *, const double *, const int*, double *, const int*);
/* Single Complex Precision */
void cgemm_(char*, char*, const int*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void csymm_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void chemm_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void csyrk_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, float *, const int*);
void cherk_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, float *, const int*);
void csyr2k_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void cher2k_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ctrmm_(char*, char*, char*, char*, const int*, const int*, const float *, const float *, const int*, float *, const int*);
void ctrsm_(char*, char*, char*, char*, const int*, const int*, const float *, const float *, const int*, float *, const int*);
/* Double Complex Precision */
void zgemm_(char*, char*, const int*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void zsymm_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void zhemm_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void zsyrk_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, double *, const int*);
void zherk_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, double *, const int*);
void zsyr2k_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void zher2k_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void ztrmm_(char*, char*, char*, char*, const int*, const int*, const double *, const double *, const int*, double *, const int*);
void ztrsm_(char*, char*, char*, char*, const int*, const int*, const double *, const double *, const int*, double *, const int*);
}
"""
@utils.memoize @utils.memoize
def ldflags(): def ldflags():
...@@ -819,68 +41,105 @@ def ldflags(): ...@@ -819,68 +41,105 @@ def ldflags():
#print "blas linking against", rval #print "blas linking against", rval
return rval return rval
def gemm_code(check_ab, a_init, b_init): class GemmRelated(Op):
mod = '%' """Base class for Gemm and Dot22
return """
const char * error_string = NULL; This class provides a kind of templated gemm Op.
"""
int type_num = _x->descr->type_num; def c_support_code(self):
int type_size = _x->descr->elsize; // in bytes #return cblas_header_text()
mod_str = """
npy_intp* Nx = _x->dimensions; #ifndef MOD
npy_intp* Ny = _y->dimensions; #define MOD %
npy_intp* Nz = _z->dimensions; #endif
"""
npy_intp* Sx = _x->strides; return blas_header_text() + mod_str
npy_intp* Sy = _y->strides; def c_headers(self):
npy_intp* Sz = _z->strides; # std.cout doesn't require the '%' symbol to print stuff...
# so it works much better with python's string-substitution stuff.
size_t sx_0, sx_1, sy_0, sy_1, sz_0, sz_1; return ['<iostream>']
def c_libraries(self):
return ldflags()
declare_NS = """
int unit = 0; int unit = 0;
if (_x->nd != 2) goto _dot_execute_fallback; int type_num = %(_x)s->descr->type_num;
if (_y->nd != 2) goto _dot_execute_fallback; int type_size = %(_x)s->descr->elsize; // in bytes
if (_z->nd != 2) goto _dot_execute_fallback;
npy_intp* Nx = %(_x)s->dimensions;
%(check_ab)s npy_intp* Ny = %(_y)s->dimensions;
npy_intp* Nz = 0; //%(_z)s->dimensions;
if ((_x->descr->type_num != PyArray_DOUBLE)
&& (_x->descr->type_num != PyArray_FLOAT)) npy_intp* Sx = %(_x)s->strides;
goto _dot_execute_fallback; npy_intp* Sy = %(_y)s->strides;
npy_intp* Sz = 0; //%(_z)s->strides;
if ((_y->descr->type_num != PyArray_DOUBLE)
&& (_y->descr->type_num != PyArray_FLOAT)) //strides for x, y, z in dimensions 0, 1
goto _dot_execute_fallback; int sx_0, sx_1, sy_0, sy_1, sz_0, sz_1;
"""
if ((_y->descr->type_num != PyArray_DOUBLE)
&& (_y->descr->type_num != PyArray_FLOAT)) #setup_z_Nz_Sz = None
goto _dot_execute_fallback;
check_xyz_rank2 = """
if ((_x->descr->type_num != _y->descr->type_num) if (%(_x)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(x) != 2"); %(fail)s;}
||(_x->descr->type_num != _z->descr->type_num)) if (%(_y)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(y) != 2"); %(fail)s;}
goto _dot_execute_fallback; if (%(_z)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(z) != 2"); %(fail)s;}
"""
check_xyz_double_or_float = """
if ((%(_x)s->descr->type_num != PyArray_DOUBLE)
&& (%(_x)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float"); %(fail)s;}
if ((%(_y)s->descr->type_num != PyArray_DOUBLE)
&& (%(_y)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float"); %(fail)s;}
if ((%(_z)s->descr->type_num != PyArray_DOUBLE)
&& (%(_z)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float"); %(fail)s;}
if ((%(_x)s->descr->type_num != %(_y)s->descr->type_num)
||(%(_x)s->descr->type_num != %(_z)s->descr->type_num))
{ PyErr_SetString(PyExc_NotImplementedError, "type(z), type(y), type(z) are not all the same"); %(fail)s; }
"""
#it is not necessary that a or b have the same type as x,y,z
check_ab_double_or_float = """
if ((%(_a)s->descr->type_num != PyArray_DOUBLE)
&& (%(_a)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(a) is not double or float"); %(fail)s;}
if ((%(_b)s->descr->type_num != PyArray_DOUBLE)
&& (%(_b)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(b) is not double or float"); %(fail)s;}
"""
check_dims_strides = """
if ((Nx[0] != Nz[0]) || (Nx[1] != Ny[0]) || (Ny[1] != Nz[1])) if ((Nx[0] != Nz[0]) || (Nx[1] != Ny[0]) || (Ny[1] != Nz[1]))
{ {
error_string = "Input dimensions do not agree"; PyErr_SetString(PyExc_ValueError, "Input dimensions do not agree");
goto _dot_execute_fail; %(fail)s;
} }
if ((Sx[0] < 1) || (Sx[1] < 1) || (Sx[0] %(mod)s type_size) || (Sx[1] %(mod)s type_size) if ((Sx[0] < 1) || (Sx[1] < 1) || (Sx[0] MOD type_size) || (Sx[1] MOD type_size)
|| (Sy[0] < 1) || (Sy[1] < 1) || (Sy[0] %(mod)s type_size) || (Sy[1] %(mod)s type_size) || (Sy[0] < 1) || (Sy[1] < 1) || (Sy[0] MOD type_size) || (Sy[1] MOD type_size)
|| (Sz[0] < 1) || (Sz[1] < 1) || (Sz[0] %(mod)s type_size) || (Sz[1] %(mod)s type_size)) || (Sz[0] < 1) || (Sz[1] < 1) || (Sz[0] MOD type_size) || (Sz[1] MOD type_size))
{ {
goto _dot_execute_fallback; PyErr_SetString(PyExc_ValueError, "stride is not multiple of element size"); %(fail)s;
} }
"""
encode_strides_in_unit = """
/* /*
encode the stride structure of _x,_y,_z into a single integer encode the stride structure of _x,_y,_z into a single integer
*/ */
unit |= ((Sx[1] == type_size) ? 0x0 : (Sx[0] == type_size) ? 0x1 : 0x2) << 0; unit |= ((Sx[1] == type_size) ? 0x0 : (Sx[0] == type_size) ? 0x1 : 0x2) << 8;
unit |= ((Sy[1] == type_size) ? 0x0 : (Sy[0] == type_size) ? 0x1 : 0x2) << 4; unit |= ((Sy[1] == type_size) ? 0x0 : (Sy[0] == type_size) ? 0x1 : 0x2) << 4;
unit |= ((Sz[1] == type_size) ? 0x0 : (Sz[0] == type_size) ? 0x1 : 0x2) << 8; unit |= ((Sz[1] == type_size) ? 0x0 : (Sz[0] == type_size) ? 0x1 : 0x2) << 0;
"""
compute_strides = """
/* create appropriate strides for malformed matrices that are row or column /* create appropriate strides for malformed matrices that are row or column
* vectors * vectors
*/ */
...@@ -890,100 +149,410 @@ def gemm_code(check_ab, a_init, b_init): ...@@ -890,100 +149,410 @@ def gemm_code(check_ab, a_init, b_init):
sy_1 = (Ny[1] > 1) ? Sy[1]/type_size : Ny[0]; sy_1 = (Ny[1] > 1) ? Sy[1]/type_size : Ny[0];
sz_0 = (Nz[0] > 1) ? Sz[0]/type_size : Nz[1]; sz_0 = (Nz[0] > 1) ? Sz[0]/type_size : Nz[1];
sz_1 = (Nz[1] > 1) ? Sz[1]/type_size : Nz[0]; sz_1 = (Nz[1] > 1) ? Sz[1]/type_size : Nz[0];
"""
begin_switch_typenum = """
switch (type_num) switch (type_num)
{ {
"""
case_float = """
case PyArray_FLOAT: case PyArray_FLOAT:
{ {
#define REAL float """
float a = %(a_init)s;
float b = %(b_init)s; #case_float_ab_constants = None
float* x = (float*)PyArray_DATA(_x); case_float_gemm = """
float* y = (float*)PyArray_DATA(_y); float* x = (float*)PyArray_DATA(%(_x)s);
float* z = (float*)PyArray_DATA(_z); float* y = (float*)PyArray_DATA(%(_y)s);
float* z = (float*)PyArray_DATA(%(_z)s);
char N = 'N';
char T = 'T';
int Nz0 = Nz[0], Nz1 = Nz[1], Nx1 = Nx[1];
//std::cerr << (unit/256) MOD 16 << (unit / 16) MOD 16 << unit MOD 16<< '\\n';
switch(unit) switch(unit)
{ {
case 0x000: cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_0); break; case 0x000: sgemm_(&N, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_0, &b, z, &sz_0); break;
case 0x001: cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_0); break; case 0x100: sgemm_(&N, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_1, &b, z, &sz_0); break;
case 0x010: cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_0); break; case 0x010: sgemm_(&T, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_0, &b, z, &sz_0); break;
case 0x011: cblas_sgemm(CblasRowMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_0); break; case 0x110: sgemm_(&T, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_1, &b, z, &sz_0); break;
case 0x100: cblas_sgemm(CblasColMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_1); break; case 0x001: sgemm_(&T, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_0, &b, z, &sz_1); break;
case 0x101: cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_1); break; case 0x101: sgemm_(&N, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_0, &b, z, &sz_1); break;
case 0x110: cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_1); break; case 0x011: sgemm_(&T, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_1, &b, z, &sz_1); break;
case 0x111: cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_1); break; case 0x111: sgemm_(&N, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_1, &b, z, &sz_1); break;
default: goto _dot_execute_fallback; default: PyErr_SetString(PyExc_ValueError, "some matrix has no unit stride"); %(fail)s;
}; };
#undef REAL """
case_double = """
} }
break; break;
case PyArray_DOUBLE: case PyArray_DOUBLE:
{ {
#define REAL double """
double a = %(a_init)s;
double b = %(b_init)s; #case_double_ab_constants = None
double* x = (double*)PyArray_DATA(_x); case_double_gemm = """
double* y = (double*)PyArray_DATA(_y); double* x = (double*)PyArray_DATA(%(_x)s);
double* z = (double*)PyArray_DATA(_z); double* y = (double*)PyArray_DATA(%(_y)s);
double* z = (double*)PyArray_DATA(%(_z)s);
char N = 'N';
char T = 'T';
int Nz0 = Nz[0], Nz1 = Nz[1], Nx1 = Nx[1];
//std::cerr << (unit/256) MOD 16 << (unit / 16) MOD 16 << unit MOD 16<< '\\n';
switch(unit) switch(unit)
{ {
case 0x000: cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_0); break; case 0x000: dgemm_(&N, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_0, &b, z, &sz_0); break;
case 0x001: cblas_dgemm(CblasRowMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_0); break; case 0x100: dgemm_(&N, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_1, &b, z, &sz_0); break;
case 0x010: cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_0); break; case 0x010: dgemm_(&T, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_0, &b, z, &sz_0); break;
case 0x011: cblas_dgemm(CblasRowMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_0); break; case 0x110: dgemm_(&T, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_1, &b, z, &sz_0); break;
case 0x100: cblas_dgemm(CblasColMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_1); break; case 0x001: dgemm_(&T, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_0, &b, z, &sz_1); break;
case 0x101: cblas_dgemm(CblasColMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_1); break; case 0x101: dgemm_(&N, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_0, &b, z, &sz_1); break;
case 0x110: cblas_dgemm(CblasColMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_1); break; case 0x011: dgemm_(&T, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_1, &b, z, &sz_1); break;
case 0x111: cblas_dgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_1); break; case 0x111: dgemm_(&N, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_1, &b, z, &sz_1); break;
default: goto _dot_execute_fallback; default: PyErr_SetString(PyExc_ValueError, "some matrix has no unit stride"); %(fail)s;
}; };
#undef REAL """
end_switch_typenum = """
} }
break; break;
} }
"""
def build_gemm_call(self):
    """Assemble the C source template for one gemm call.

    The template is the concatenation, in the fixed order below, of the
    C code fragments stored as string attributes on this object.  Every
    fragment must be defined by the time this is called; a missing one
    raises AttributeError.

    Returns the unformatted template string.  The %(...)s placeholders are
    filled in by the caller.
    """
    fragments = (
        self.declare_NS,
        self.setup_z_Nz_Sz,
        self.check_xyz_rank2,
        self.check_xyz_double_or_float,
        self.check_ab_double_or_float,
        self.check_dims_strides,
        self.encode_strides_in_unit,
        self.compute_strides,
        self.begin_switch_typenum,
        self.case_float,
        self.case_float_ab_constants,
        self.case_float_gemm,
        self.case_double,
        self.case_double_ab_constants,
        self.case_double_gemm,
        self.end_switch_typenum,
    )
    # str.join builds the result in one pass and does not rely on the
    # `reduce` builtin.
    return ''.join(fragments)
class Gemm(GemmRelated):
"""In-place version of matrix-matrix multiplication (with accumulation):
When a and b are scalars and x, y, and z are matrices, then
gemm(z,a,x,y,b)
is similar to
return 0; //success! b*z + a*dot(x,y)
_dot_execute_fallback: The difference between the two is that the top form is destructive on z,
PyErr_SetString(PyExc_NotImplementedError, whereas the bottom form is not. Gemm works in-place on the storage
"dot->execute() fallback"); associated with z, and the L{Result} returned by Gemm has a storage that
return -1; will be aliased to the storage of the z argument. Because of this in-place
computation, an L{Apply} of this op will destroy the L{Result} z on
_dot_execute_fail: which it operates. (See L{DestructiveOps} for an explanation of what
if (error_string == NULL) destroying means in the context of theano graphs. See L{BlasLapackSupport} for
PyErr_SetString(PyExc_ValueError, more optimized linear algebra operations.)
"dot->execute() cant run on these inputs");
return -1; """
E_rank = 'gemm only works for rank 2'
/* v 1 */ E_scalar = 'gemm requires scalar argument'
""" % locals() E_z_uniq = 'argument z aliased to x or y'
destroy_map = {0: [0]}
# currently unused, preferring the fallback method (throwing def make_node(self, *inputs):
# NotImplementedError) for when gemm won't work. inputs = map(as_tensor, inputs)
_templated_memaligned_gemm = """ if len(inputs) != 5:
template <typename Ta, typename Tx, typename Ty, typename Tb, typename Tz> raise TypeError("Wrong number of inputs for %s (expected 5, got %s)" % (self, len(inputs)))
int general_gemm(int zM, int zN, int xN,. z, a, x, y, b = inputs
Ta a, zr, xr, yr = [set(view_roots(i)) for i in z,x,y]
Tx * x, int xm, int xn, if zr.intersection(xr):
Tx * y, int ym, int yn, raise ValueError(Gemm.E_z_uniq, (z, x))
Tb b, if zr.intersection(yr):
Tz * z, int zm, int zn) raise ValueError(Gemm.E_z_uniq, (z, y))
{ bz, ba, bx, by, bb = [r.type.broadcastable for r in inputs]
for (int i = 0; i < zM; ++i) if len(bz) != 2: raise ValueError(Gemm.E_rank, len(bz))
{ if len(bx) != 2: raise ValueError(Gemm.E_rank, len(bx))
for (int j = 0; j < zN; ++j) if len(by) != 2: raise ValueError(Gemm.E_rank, len(by))
if len(ba): raise ValueError(Gemm.E_scalar, ba)
if len(bb): raise ValueError(Gemm.E_scalar, bb)
output = z.type()
return Apply(self, inputs, [output])
def perform(self, node, (z, a, x, y, b), (zout, )):
assert a.shape == ()
assert b.shape == ()
if z.shape == ():
z.itemset(z*a + b*numpy.dot(x,y))
zout[0] = z
else:
if b == 0.0:
if a == 1.0:
z[:] = numpy.dot(x,y)
elif a == -1.0:
z[:] = -numpy.dot(x,y)
else:
z[:] = a * numpy.dot(x,y)
elif b == 1.0:
if a == 1.0:
z += numpy.dot(x,y)
elif a == -1.0:
z -= numpy.dot(x,y)
else:
z += a * numpy.dot(x,y)
else:
z *= b
z += a * numpy.dot(x,y)
zout[0] = z
setup_z_Nz_Sz = """
if (%(_zout)s != %(_z)s)
{ {
Tz zij = 0.0; if (%(_zout)s)
for (int k = 0; k < xN; ++k)
{ {
zij += x[i*xm+k*xn] * y[k*ym+j*yn]; Py_DECREF(%(_zout)s);
} }
z[i * zm + j * zn] *= b; %(_zout)s = %(_z)s;
z[i * zm + j * zn] += a * zij; Py_INCREF(%(_zout)s);
} }
} Nz = %(_z)s->dimensions;
} Sz = %(_z)s->strides;
""" """
# C fragment for the float case: declares the scalars a and b, reading each
# from element 0 of its array as float or double according to that array's
# type_num, and casting the value to float.
case_float_ab_constants = """
#define REAL float
float a = (%(_a)s->descr->type_num == PyArray_FLOAT)
? (REAL)(((float*)%(_a)s->data)[0])
: (REAL)(((double*)%(_a)s->data)[0]);
float b = (%(_b)s->descr->type_num == PyArray_FLOAT) ?
(REAL)(((float*)%(_b)s->data)[0])
: (REAL)(((double*)%(_b)s->data)[0]);
#undef REAL
"""
# C fragment for the double case: same as above, but a and b are declared
# as double and the values are cast to double.
case_double_ab_constants = """
#define REAL double
double a = (%(_a)s->descr->type_num == PyArray_FLOAT)
? (REAL)(((float*)%(_a)s->data)[0])
: (REAL)(((double*)%(_a)s->data)[0]);
double b = (%(_b)s->descr->type_num == PyArray_FLOAT) ?
(REAL)(((float*)%(_b)s->data)[0])
: (REAL)(((double*)%(_b)s->data)[0]);
#undef REAL
"""
# Return the C code for this op: the gemm template with its %(...)s
# placeholders filled in.  locals() supplies the C variable names bound by the
# signature (_z, _a, _x, _y, _b, _zout); entries of `sub` are added on top and
# win on a key collision.  The templates also use %(fail)s, which is not a
# local here -- presumably it is provided through `sub`; confirm with caller.
def c_code(self, node, name, (_z, _a, _x, _y, _b), (_zout, ), sub):
full_code = self.build_gemm_call() % dict(locals(), **sub)
return full_code
gemm = Gemm()
pprint.assign(gemm, FunctionPrinter('gemm'))
class Dot22(GemmRelated):
"""Compute a matrix-matrix product.
This is a specialization of the more general Dot()
It takes two inputs x and y; make_node requires x's type to be one of
T.float_matrix_types and y's type to equal x's.  The C implementation reuses
the gemm template inherited from GemmRelated with the scalars fixed to
a = 1.0 and b = 0.0.
"""
# Build the Apply node.  The single output has x's dtype; its broadcastable
# pattern is taken from x's first dimension and y's second dimension.
def make_node(self, x, y):
assert x.type in T.float_matrix_types #makes sure x is a matrix
assert y.type == x.type #makes sure y is a matrix
bz = [x.type.broadcastable[0], y.type.broadcastable[1]]
outputs = [T.tensor(x.type.dtype, bz)]
return Apply(self, [x,y], outputs)
# Python implementation: store numpy.dot(x, y) in the output cell z[0].
# A ValueError from numpy is re-raised with the two input shapes appended
# to its args.
def perform(self, node, (x, y), (z, )):
try:
z[0] = numpy.asarray(numpy.dot(x, y))
except ValueError, e:
# The error raised by numpy has no shape information, we mean to add that
e.args = e.args + (x.shape, y.shape)
raise
def __str__(self):
return "_dot22"
# C fragment: allocate a new output array when the current one is NULL or its
# dimensions differ from (rows of x, columns of y), releasing the old one
# first; then point Nz and Sz at the output's dimensions and strides.
setup_z_Nz_Sz = """
if ((NULL == %(_z)s)
|| (%(_z)s->dimensions[0] != %(_x)s->dimensions[0])
|| (%(_z)s->dimensions[1] != %(_y)s->dimensions[1]))
{
if (NULL != %(_z)s) Py_XDECREF(%(_z)s);
npy_intp dims[2];
dims[0] = %(_x)s->dimensions[0];
dims[1] = %(_y)s->dimensions[1];
%(_z)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, type_num_%(_x)s);
if(!%(_z)s) {
PyErr_SetString(PyExc_MemoryError, "failed to alloc dot22 output");
%(fail)s
}
}
Nz = %(_z)s->dimensions;
Sz = %(_z)s->strides;
"""
# This op has no a or b inputs, so the type check on them is an empty fragment.
check_ab_double_or_float = ""
# The scalars are compile-time constants here: a = 1 and b = 0.
case_float_ab_constants = """
float a = 1.0;
float b = 0.0;
"""
case_double_ab_constants = """
double a = 1.0;
double b = 0.0;
"""
# Return the C code for this op: the gemm template with its %(...)s
# placeholders filled from locals() (which supplies _x, _y, _z) plus the
# entries of `sub`, which win on a key collision.
def c_code(self, node, name, (_x, _y), (_z, ), sub):
full_code = self.build_gemm_call() % dict(locals(), **sub)
return full_code
_dot22 = Dot22()
@local_optimizer([T.dot])
def local_dot_to_dot22(node):
    """Replace dot(x, y) by _dot22(x, y) when _dot22 can accept the inputs.

    Returns a one-element list holding the replacement, or False when the
    node is not a T.dot application or its inputs are not two float matrices
    of the same type.
    """
    if node.op == T.dot:
        x, y = node.inputs
        # Dot22.make_node asserts exactly these two conditions.  Checking
        # them here leaves other dots (e.g. on vectors) untouched rather
        # than letting the assertion fire inside the optimizer.
        if x.type in T.float_matrix_types and y.type == x.type:
            return [_dot22(x, y)]
    return False
register_specialize(local_dot_to_dot22)
def _is_a(node, op, maxclients=None):
return node.owner \
and node.owner.op == op \
and len(node.clients) <= maxclients if maxclients is not None else True
def _as_scalar(res):
    """Return a scalar view of `res`, or None if it is not a scalar.

    DimShuffle applications are looked through first.  The underlying result
    is returned as-is when its type is in T.float_scalar_types.  For a
    T.Constant holding exactly one element, that element of its data is
    returned instead.
    """
    # Walk down through any chain of DimShuffle ops to the wrapped input.
    while res.owner and isinstance(res.owner.op, T.DimShuffle):
        res = res.owner.inputs[0]
    if res.type in T.float_scalar_types:
        return res
    if isinstance(res, T.Constant) and res.data.size == 1:
        return res.data.flatten()[0]
    return None
def _as_isolated_scalar_times_matrix(res):
    """Split `res` into a (scalar, matrix) pair when it is scalar * matrix.

    `res` must be the result of a T.mul with at most one client.  With two
    factors, the pair is returned when one factor is a scalar (per
    _as_scalar) and the other has a type in T.float_matrix_types.  With any
    other number of factors, every factor must be a scalar or a float matrix
    and exactly one must be a matrix; the scalars are then multiplied
    together with T.mul.  Returns None in all other cases.
    """
    if not _is_a(res, T.mul, 1):
        return None
    factors = res.owner.inputs
    if len(factors) == 2:
        left, right = factors
        left_scalar = _as_scalar(left)
        right_scalar = _as_scalar(right)
        if left_scalar is not None and right.type in T.float_matrix_types:
            return (left_scalar, right)
        if right_scalar is not None and left.type in T.float_matrix_types:
            return (right_scalar, left)
        return None
    scalars = []
    matrices = []
    for factor in factors:
        as_scalar = _as_scalar(factor)
        if as_scalar is not None:
            scalars.append(as_scalar)
            continue
        if factor.type not in T.float_matrix_types:
            # Neither a scalar nor a float matrix: not a pattern we handle.
            return None
        matrices.append(factor)
    if len(matrices) != 1:
        return None
    return (T.mul(*scalars), matrices[0])
def beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True):
    """Try to rewrite (beta * L) + (alpha * M) using gemm.

    Returns a one-element list holding the replacement expression, or False
    when no rewrite applies.  When nothing matches and `recurse_flip` is
    true, the roles of the two terms are swapped and the match is tried once
    more.
    """
    # M = dot(Ml, Mr):  beta*L + alpha*dot(Ml, Mr)
    if _is_a(M, _dot22, 1):
        dot_left, dot_right = M.owner.inputs
        return [gemm(L, alpha, dot_left, dot_right, beta)]

    # M = gemm(G, a, u, v, b):  beta*L + alpha*b*G + alpha*a*dot(u, v)
    if _is_a(M, gemm, 1):
        G, a, u, v, b = M.owner.inputs
        if _is_a(G, _dot22, 1):
            # G = dot(x, y):
            # beta*L + alpha*b*dot(x, y) + alpha*a*dot(u, v)
            x, y = G.owner.inputs
            return [gemm(gemm(L, alpha * b, x, y, beta), alpha * a, u, v, 1.0)]
        if G is L:
            # beta*L + alpha*b*L + alpha*a*dot(u, v)
            return [gemm(L, alpha*a, u, v, alpha * b + beta)]
        if 1.0 != alpha:
            # At the very least, move the alpha inside the gemm.
            return [beta * L + gemm(G, alpha * a, u, v, alpha * b)]

    if not recurse_flip:
        return False
    return beta_L_plus_alpha_M(alpha, M, beta, L, recurse_flip = False)
@local_optimizer([T.sub])
def local_sub_to_gemm(node):
    """Try to replace a subtraction of two float matrices by a gemm.

    Each operand is split into a (scalar, matrix) pair when it is an
    isolated scalar * matrix product, and is treated as 1.0 * operand
    otherwise.  L - R is then handed to beta_L_plus_alpha_M as
    sL*mL + (-sR)*mR.

    Returns whatever beta_L_plus_alpha_M returns, or False when the node is
    not a T.sub or an operand's type is not in T.float_matrix_types.
    """
    if node.op == T.sub:
        L, R = node.inputs
        if L.type not in T.float_matrix_types:
            return False
        if R.type not in T.float_matrix_types:
            return False
        # _as_isolated_scalar_times_matrix returns None or a 2-tuple, so an
        # explicit None test replaces catching the failed unpacking.
        split_L = _as_isolated_scalar_times_matrix(L)
        sL, mL = split_L if split_L is not None else (1.0, L)
        split_R = _as_isolated_scalar_times_matrix(R)
        sR, mR = split_R if split_R is not None else (1.0, R)
        return beta_L_plus_alpha_M(sL, mL, -sR, mR)
    return False
register_specialize(local_sub_to_gemm)
@local_optimizer([T.add])
def local_add_to_gemm(node):
"""This is a massive beast for recognizing all the ways that an addition could be
replaced by a GEMM
Each input is split into a (scalar, matrix) pair when it is an isolated
scalar * matrix product, and is treated as (1.0, input) otherwise.  With two
inputs the pair is handed straight to beta_L_plus_alpha_M; with more, every
pair (i, j) with i < j is tried.
It depends on `local_transposed_dot` to canonicalize the graph a bit by swapping
dot(a,b).T -> dot(b.T, a.T)
"""
if node.op == T.add:
sM_list = []
for input in node.inputs:
tmp = _as_isolated_scalar_times_matrix(input)
sM_list.append(tmp if tmp is not None else (1.0,input))
#print sM_list
if len(node.inputs) == 2:
sL, mL = sM_list[0]
sR, mR = sM_list[1]
return beta_L_plus_alpha_M(sL, mL, sR, mR)
else:
for i in xrange(len(sM_list) - 1):
for j in xrange(i+1, len(sM_list)):
sL, mL = sM_list[i]
sR, mR = sM_list[j]
rval = beta_L_plus_alpha_M(sL, mL, sR, mR)
if rval:
assert len(rval) == 1
# Rebuild the sum from the rewritten pair plus all the other inputs.
inputs_without_ij = \
[input for k, input in enumerate(node.inputs) if k not in (i,j)]
return [T.add( *(inputs_without_ij + rval))]
# NOTE(review): the nesting level of the following `return rval` cannot be
# told from this flattened layout -- confirm against the indented source.
return rval
return False
register_specialize(local_add_to_gemm)
""" Header text for the C and Fortran BLAS interfaces.
There is no standard name or location for this header, so we just insert it ourselves into the C code
"""
def cblas_header_text():
"""C header for the cblas interface."""
return """
//#include <stddef.h>
#undef __BEGIN_DECLS
#undef __END_DECLS
#ifdef __cplusplus
#define __BEGIN_DECLS extern "C" {
#define __END_DECLS }
#else
#define __BEGIN_DECLS /* empty */
#define __END_DECLS /* empty */
#endif
__BEGIN_DECLS
#define MOD %
/*
* Enumerated and derived types
*/
#define CBLAS_INDEX size_t /* this may vary between platforms */
enum CBLAS_ORDER {CblasRowMajor=101, CblasColMajor=102};
enum CBLAS_TRANSPOSE {CblasNoTrans=111, CblasTrans=112, CblasConjTrans=113};
enum CBLAS_UPLO {CblasUpper=121, CblasLower=122};
enum CBLAS_DIAG {CblasNonUnit=131, CblasUnit=132};
enum CBLAS_SIDE {CblasLeft=141, CblasRight=142};
float cblas_sdsdot(const int N, const float alpha, const float *X,
const int incX, const float *Y, const int incY);
double cblas_dsdot(const int N, const float *X, const int incX, const float *Y,
const int incY);
float cblas_sdot(const int N, const float *X, const int incX,
const float *Y, const int incY);
double cblas_ddot(const int N, const double *X, const int incX,
const double *Y, const int incY);
/*
* Functions having prefixes Z and C only
*/
void cblas_cdotu_sub(const int N, const void *X, const int incX,
const void *Y, const int incY, void *dotu);
void cblas_cdotc_sub(const int N, const void *X, const int incX,
const void *Y, const int incY, void *dotc);
void cblas_zdotu_sub(const int N, const void *X, const int incX,
const void *Y, const int incY, void *dotu);
void cblas_zdotc_sub(const int N, const void *X, const int incX,
const void *Y, const int incY, void *dotc);
/*
* Functions having prefixes S D SC DZ
*/
float cblas_snrm2(const int N, const float *X, const int incX);
float cblas_sasum(const int N, const float *X, const int incX);
double cblas_dnrm2(const int N, const double *X, const int incX);
double cblas_dasum(const int N, const double *X, const int incX);
float cblas_scnrm2(const int N, const void *X, const int incX);
float cblas_scasum(const int N, const void *X, const int incX);
double cblas_dznrm2(const int N, const void *X, const int incX);
double cblas_dzasum(const int N, const void *X, const int incX);
/*
* Functions having standard 4 prefixes (S D C Z)
*/
CBLAS_INDEX cblas_isamax(const int N, const float *X, const int incX);
CBLAS_INDEX cblas_idamax(const int N, const double *X, const int incX);
CBLAS_INDEX cblas_icamax(const int N, const void *X, const int incX);
CBLAS_INDEX cblas_izamax(const int N, const void *X, const int incX);
/*
* ===========================================================================
* Prototypes for level 1 BLAS routines
* ===========================================================================
*/
/*
* Routines with standard 4 prefixes (s, d, c, z)
*/
void cblas_sswap(const int N, float *X, const int incX,
float *Y, const int incY);
void cblas_scopy(const int N, const float *X, const int incX,
float *Y, const int incY);
void cblas_saxpy(const int N, const float alpha, const float *X,
const int incX, float *Y, const int incY);
void cblas_dswap(const int N, double *X, const int incX,
double *Y, const int incY);
void cblas_dcopy(const int N, const double *X, const int incX,
double *Y, const int incY);
void cblas_daxpy(const int N, const double alpha, const double *X,
const int incX, double *Y, const int incY);
void cblas_cswap(const int N, void *X, const int incX,
void *Y, const int incY);
void cblas_ccopy(const int N, const void *X, const int incX,
void *Y, const int incY);
void cblas_caxpy(const int N, const void *alpha, const void *X,
const int incX, void *Y, const int incY);
void cblas_zswap(const int N, void *X, const int incX,
void *Y, const int incY);
void cblas_zcopy(const int N, const void *X, const int incX,
void *Y, const int incY);
void cblas_zaxpy(const int N, const void *alpha, const void *X,
const int incX, void *Y, const int incY);
/*
* Routines with S and D prefix only
*/
void cblas_srotg(float *a, float *b, float *c, float *s);
void cblas_srotmg(float *d1, float *d2, float *b1, const float b2, float *P);
void cblas_srot(const int N, float *X, const int incX,
float *Y, const int incY, const float c, const float s);
void cblas_srotm(const int N, float *X, const int incX,
float *Y, const int incY, const float *P);
void cblas_drotg(double *a, double *b, double *c, double *s);
void cblas_drotmg(double *d1, double *d2, double *b1, const double b2, double *P);
void cblas_drot(const int N, double *X, const int incX,
double *Y, const int incY, const double c, const double s);
void cblas_drotm(const int N, double *X, const int incX,
double *Y, const int incY, const double *P);
/*
* Routines with S D C Z CS and ZD prefixes
*/
void cblas_sscal(const int N, const float alpha, float *X, const int incX);
void cblas_dscal(const int N, const double alpha, double *X, const int incX);
void cblas_cscal(const int N, const void *alpha, void *X, const int incX);
void cblas_zscal(const int N, const void *alpha, void *X, const int incX);
void cblas_csscal(const int N, const float alpha, void *X, const int incX);
void cblas_zdscal(const int N, const double alpha, void *X, const int incX);
/*
* ===========================================================================
* Prototypes for level 2 BLAS
* ===========================================================================
*/
/*
* Routines with standard 4 prefixes (S, D, C, Z)
*/
void cblas_sgemv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const float alpha, const float *A, const int lda,
const float *X, const int incX, const float beta,
float *Y, const int incY);
void cblas_sgbmv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const int KL, const int KU, const float alpha,
const float *A, const int lda, const float *X,
const int incX, const float beta, float *Y, const int incY);
void cblas_strmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const float *A, const int lda,
float *X, const int incX);
void cblas_stbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const float *A, const int lda,
float *X, const int incX);
void cblas_stpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const float *Ap, float *X, const int incX);
void cblas_strsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const float *A, const int lda, float *X,
const int incX);
void cblas_stbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const float *A, const int lda,
float *X, const int incX);
void cblas_stpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const float *Ap, float *X, const int incX);
void cblas_dgemv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const double alpha, const double *A, const int lda,
const double *X, const int incX, const double beta,
double *Y, const int incY);
void cblas_dgbmv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const int KL, const int KU, const double alpha,
const double *A, const int lda, const double *X,
const int incX, const double beta, double *Y, const int incY);
void cblas_dtrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const double *A, const int lda,
double *X, const int incX);
void cblas_dtbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const double *A, const int lda,
double *X, const int incX);
void cblas_dtpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const double *Ap, double *X, const int incX);
void cblas_dtrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const double *A, const int lda, double *X,
const int incX);
void cblas_dtbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const double *A, const int lda,
double *X, const int incX);
void cblas_dtpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const double *Ap, double *X, const int incX);
void cblas_cgemv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *X, const int incX, const void *beta,
void *Y, const int incY);
void cblas_cgbmv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const int KL, const int KU, const void *alpha,
const void *A, const int lda, const void *X,
const int incX, const void *beta, void *Y, const int incY);
void cblas_ctrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *A, const int lda,
void *X, const int incX);
void cblas_ctbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const void *A, const int lda,
void *X, const int incX);
void cblas_ctpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *Ap, void *X, const int incX);
void cblas_ctrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *A, const int lda, void *X,
const int incX);
void cblas_ctbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const void *A, const int lda,
void *X, const int incX);
void cblas_ctpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *Ap, void *X, const int incX);
void cblas_zgemv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *X, const int incX, const void *beta,
void *Y, const int incY);
void cblas_zgbmv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const int KL, const int KU, const void *alpha,
const void *A, const int lda, const void *X,
const int incX, const void *beta, void *Y, const int incY);
void cblas_ztrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *A, const int lda,
void *X, const int incX);
void cblas_ztbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const void *A, const int lda,
void *X, const int incX);
void cblas_ztpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *Ap, void *X, const int incX);
void cblas_ztrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *A, const int lda, void *X,
const int incX);
void cblas_ztbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const void *A, const int lda,
void *X, const int incX);
void cblas_ztpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *Ap, void *X, const int incX);
/*
* Routines with S and D prefixes only
*/
void cblas_ssymv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *A,
const int lda, const float *X, const int incX,
const float beta, float *Y, const int incY);
void cblas_ssbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const int K, const float alpha, const float *A,
const int lda, const float *X, const int incX,
const float beta, float *Y, const int incY);
void cblas_sspmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *Ap,
const float *X, const int incX,
const float beta, float *Y, const int incY);
void cblas_sger(const enum CBLAS_ORDER order, const int M, const int N,
const float alpha, const float *X, const int incX,
const float *Y, const int incY, float *A, const int lda);
void cblas_ssyr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *X,
const int incX, float *A, const int lda);
void cblas_sspr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *X,
const int incX, float *Ap);
void cblas_ssyr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *X,
const int incX, const float *Y, const int incY, float *A,
const int lda);
void cblas_sspr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *X,
const int incX, const float *Y, const int incY, float *A);
void cblas_dsymv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *A,
const int lda, const double *X, const int incX,
const double beta, double *Y, const int incY);
void cblas_dsbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const int K, const double alpha, const double *A,
const int lda, const double *X, const int incX,
const double beta, double *Y, const int incY);
void cblas_dspmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *Ap,
const double *X, const int incX,
const double beta, double *Y, const int incY);
void cblas_dger(const enum CBLAS_ORDER order, const int M, const int N,
const double alpha, const double *X, const int incX,
const double *Y, const int incY, double *A, const int lda);
void cblas_dsyr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *X,
const int incX, double *A, const int lda);
void cblas_dspr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *X,
const int incX, double *Ap);
void cblas_dsyr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *X,
const int incX, const double *Y, const int incY, double *A,
const int lda);
void cblas_dspr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *X,
const int incX, const double *Y, const int incY, double *A);
/*
* Routines with C and Z prefixes only
*/
void cblas_chemv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const void *alpha, const void *A,
const int lda, const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_chbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const int K, const void *alpha, const void *A,
const int lda, const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_chpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const void *alpha, const void *Ap,
const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_cgeru(const enum CBLAS_ORDER order, const int M, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_cgerc(const enum CBLAS_ORDER order, const int M, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_cher(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const void *X, const int incX,
void *A, const int lda);
void cblas_chpr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const void *X,
const int incX, void *A);
void cblas_cher2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_chpr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *Ap);
void cblas_zhemv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const void *alpha, const void *A,
const int lda, const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_zhbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const int K, const void *alpha, const void *A,
const int lda, const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_zhpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const void *alpha, const void *Ap,
const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_zgeru(const enum CBLAS_ORDER order, const int M, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_zgerc(const enum CBLAS_ORDER order, const int M, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_zher(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const void *X, const int incX,
void *A, const int lda);
void cblas_zhpr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const void *X,
const int incX, void *A);
void cblas_zher2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_zhpr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *Ap);
/*
* ===========================================================================
* Prototypes for level 3 BLAS
* ===========================================================================
*/
/*
* Routines with standard 4 prefixes (S, D, C, Z)
*/
void cblas_sgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const float alpha, const float *A,
const int lda, const float *B, const int ldb,
const float beta, float *C, const int ldc);
void cblas_ssymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const float alpha, const float *A, const int lda,
const float *B, const int ldb, const float beta,
float *C, const int ldc);
void cblas_ssyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const float alpha, const float *A, const int lda,
const float beta, float *C, const int ldc);
void cblas_ssyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const float alpha, const float *A, const int lda,
const float *B, const int ldb, const float beta,
float *C, const int ldc);
void cblas_strmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const float alpha, const float *A, const int lda,
float *B, const int ldb);
void cblas_strsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const float alpha, const float *A, const int lda,
float *B, const int ldb);
void cblas_dgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const double alpha, const double *A,
const int lda, const double *B, const int ldb,
const double beta, double *C, const int ldc);
void cblas_dsymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const double alpha, const double *A, const int lda,
const double *B, const int ldb, const double beta,
double *C, const int ldc);
void cblas_dsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const double alpha, const double *A, const int lda,
const double beta, double *C, const int ldc);
void cblas_dsyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const double alpha, const double *A, const int lda,
const double *B, const int ldb, const double beta,
double *C, const int ldc);
void cblas_dtrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const double alpha, const double *A, const int lda,
double *B, const int ldb);
void cblas_dtrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const double alpha, const double *A, const int lda,
double *B, const int ldb);
void cblas_cgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const void *alpha, const void *A,
const int lda, const void *B, const int ldb,
const void *beta, void *C, const int ldc);
void cblas_csymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_csyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *beta, void *C, const int ldc);
void cblas_csyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_ctrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const void *alpha, const void *A, const int lda,
void *B, const int ldb);
void cblas_ctrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const void *alpha, const void *A, const int lda,
void *B, const int ldb);
void cblas_zgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const void *alpha, const void *A,
const int lda, const void *B, const int ldb,
const void *beta, void *C, const int ldc);
void cblas_zsymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_zsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *beta, void *C, const int ldc);
void cblas_zsyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_ztrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const void *alpha, const void *A, const int lda,
void *B, const int ldb);
void cblas_ztrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const void *alpha, const void *A, const int lda,
void *B, const int ldb);
/*
* Routines with prefixes C and Z only
*/
void cblas_chemm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_cherk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const float alpha, const void *A, const int lda,
const float beta, void *C, const int ldc);
void cblas_cher2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const float beta,
void *C, const int ldc);
void cblas_zhemm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_zherk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const double alpha, const void *A, const int lda,
const double beta, void *C, const int ldc);
void cblas_zher2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const double beta,
void *C, const int ldc);
void cblas_xerbla(int p, const char *rout, const char *form, ...);
__END_DECLS
"""
def blas_header_text():
    """Return the C++ declarations for the Fortran BLAS interface.

    The returned text is a single ``extern "C"`` block containing one
    prototype per routine (names carry a trailing underscore).  Prototypes
    are grouped by BLAS level (1, 2, 3) and, inside each level, by
    precision: single, double, single complex, double complex.

    Nothing is interpolated into the text, so every call returns the same
    constant string.
    """
    # The declarations are kept exactly as written, including two quirks
    # that are easy to trip over when reading them:
    #   * dswap_, cswap_ and zswap_ are each declared twice (identically);
    #   * the complex Level 3 routines are declared with float*/double*
    #     arguments, whereas the complex Level 1/2 routines use void*.
    # The closing quotes sit at column 0 on purpose so that the returned
    # text ends right after the closing brace's newline.
    header = """
extern "C"
{
void xerbla_(char*, void *);
/***********/
/* Level 1 */
/***********/
/* Single Precision */
void srot_(const int*, float *, const int*, float *, const int*, const float *, const float *);
void srotg_(float *,float *,float *,float *);
void srotm_( const int*, float *, const int*, float *, const int*, const float *);
void srotmg_(float *,float *,float *,const float *, float *);
void sswap_( const int*, float *, const int*, float *, const int*);
void scopy_( const int*, const float *, const int*, float *, const int*);
void saxpy_( const int*, const float *, const float *, const int*, float *, const int*);
void sdot_sub_(const int*, const float *, const int*, const float *, const int*, float *);
void sdsdot_sub_( const int*, const float *, const float *, const int*, const float *, const int*, float *);
void sscal_( const int*, const float *, float *, const int*);
void snrm2_sub_( const int*, const float *, const int*, float *);
void sasum_sub_( const int*, const float *, const int*, float *);
void isamax_sub_( const int*, const float * , const int*, const int*);
/* Double Precision */
void drot_(const int*, double *, const int*, double *, const int*, const double *, const double *);
void drotg_(double *,double *,double *,double *);
void drotm_( const int*, double *, const int*, double *, const int*, const double *);
void drotmg_(double *,double *,double *,const double *, double *);
void dswap_( const int*, double *, const int*, double *, const int*);
void dcopy_( const int*, const double *, const int*, double *, const int*);
void daxpy_( const int*, const double *, const double *, const int*, double *, const int*);
void dswap_( const int*, double *, const int*, double *, const int*);
void dsdot_sub_(const int*, const float *, const int*, const float *, const int*, double *);
void ddot_sub_( const int*, const double *, const int*, const double *, const int*, double *);
void dscal_( const int*, const double *, double *, const int*);
void dnrm2_sub_( const int*, const double *, const int*, double *);
void dasum_sub_( const int*, const double *, const int*, double *);
void idamax_sub_( const int*, const double * , const int*, const int*);
/* Single Complex Precision */
void cswap_( const int*, void *, const int*, void *, const int*);
void ccopy_( const int*, const void *, const int*, void *, const int*);
void caxpy_( const int*, const void *, const void *, const int*, void *, const int*);
void cswap_( const int*, void *, const int*, void *, const int*);
void cdotc_sub_( const int*, const void *, const int*, const void *, const int*, void *);
void cdotu_sub_( const int*, const void *, const int*, const void *, const int*, void *);
void cscal_( const int*, const void *, void *, const int*);
void icamax_sub_( const int*, const void *, const int*, const int*);
void csscal_( const int*, const float *, void *, const int*);
void scnrm2_sub_( const int*, const void *, const int*, float *);
void scasum_sub_( const int*, const void *, const int*, float *);
/* Double Complex Precision */
void zswap_( const int*, void *, const int*, void *, const int*);
void zcopy_( const int*, const void *, const int*, void *, const int*);
void zaxpy_( const int*, const void *, const void *, const int*, void *, const int*);
void zswap_( const int*, void *, const int*, void *, const int*);
void zdotc_sub_( const int*, const void *, const int*, const void *, const int*, void *);
void zdotu_sub_( const int*, const void *, const int*, const void *, const int*, void *);
void zdscal_( const int*, const double *, void *, const int*);
void zscal_( const int*, const void *, void *, const int*);
void dznrm2_sub_( const int*, const void *, const int*, double *);
void dzasum_sub_( const int*, const void *, const int*, double *);
void izamax_sub_( const int*, const void *, const int*, const int*);
/***********/
/* Level 2 */
/***********/
/* Single Precision */
void sgemv_(char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void sgbmv_(char*, const int*, const int*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ssymv_(char*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ssbmv_(char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void sspmv_(char*, const int*, const float *, const float *, const float *, const int*, const float *, float *, const int*);
void strmv_( char*, char*, char*, const int*, const float *, const int*, float *, const int*);
void stbmv_( char*, char*, char*, const int*, const int*, const float *, const int*, float *, const int*);
void strsv_( char*, char*, char*, const int*, const float *, const int*, float *, const int*);
void stbsv_( char*, char*, char*, const int*, const int*, const float *, const int*, float *, const int*);
void stpmv_( char*, char*, char*, const int*, const float *, float *, const int*);
void stpsv_( char*, char*, char*, const int*, const float *, float *, const int*);
void sger_( const int*, const int*, const float *, const float *, const int*, const float *, const int*, float *, const int*);
void ssyr_(char*, const int*, const float *, const float *, const int*, float *, const int*);
void sspr_(char*, const int*, const float *, const float *, const int*, float *);
void sspr2_(char*, const int*, const float *, const float *, const int*, const float *, const int*, float *);
void ssyr2_(char*, const int*, const float *, const float *, const int*, const float *, const int*, float *, const int*);
/* Double Precision */
void dgemv_(char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dgbmv_(char*, const int*, const int*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dsymv_(char*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dsbmv_(char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dspmv_(char*, const int*, const double *, const double *, const double *, const int*, const double *, double *, const int*);
void dtrmv_( char*, char*, char*, const int*, const double *, const int*, double *, const int*);
void dtbmv_( char*, char*, char*, const int*, const int*, const double *, const int*, double *, const int*);
void dtrsv_( char*, char*, char*, const int*, const double *, const int*, double *, const int*);
void dtbsv_( char*, char*, char*, const int*, const int*, const double *, const int*, double *, const int*);
void dtpmv_( char*, char*, char*, const int*, const double *, double *, const int*);
void dtpsv_( char*, char*, char*, const int*, const double *, double *, const int*);
void dger_( const int*, const int*, const double *, const double *, const int*, const double *, const int*, double *, const int*);
void dsyr_(char*, const int*, const double *, const double *, const int*, double *, const int*);
void dspr_(char*, const int*, const double *, const double *, const int*, double *);
void dspr2_(char*, const int*, const double *, const double *, const int*, const double *, const int*, double *);
void dsyr2_(char*, const int*, const double *, const double *, const int*, const double *, const int*, double *, const int*);
/* Single Complex Precision */
void cgemv_(char*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void cgbmv_(char*, const int*, const int*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void chemv_(char*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void chbmv_(char*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void chpmv_(char*, const int*, const void *, const void *, const void *, const int*, const void *, void *, const int*);
void ctrmv_( char*, char*, char*, const int*, const void *, const int*, void *, const int*);
void ctbmv_( char*, char*, char*, const int*, const int*, const void *, const int*, void *, const int*);
void ctpmv_( char*, char*, char*, const int*, const void *, void *, const int*);
void ctrsv_( char*, char*, char*, const int*, const void *, const int*, void *, const int*);
void ctbsv_( char*, char*, char*, const int*, const int*, const void *, const int*, void *, const int*);
void ctpsv_( char*, char*, char*, const int*, const void *, void *,const int*);
void cgerc_( const int*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*);
void cgeru_( const int*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*);
void cher_(char*, const int*, const float *, const void *, const int*, void *, const int*);
void cher2_(char*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*);
void chpr_(char*, const int*, const float *, const void *, const int*, void *);
void chpr2_(char*, const int*, const float *, const void *, const int*, const void *, const int*, void *);
/* Double Complex Precision */
void zgemv_(char*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void zgbmv_(char*, const int*, const int*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void zhemv_(char*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void zhbmv_(char*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void zhpmv_(char*, const int*, const void *, const void *, const void *, const int*, const void *, void *, const int*);
void ztrmv_( char*, char*, char*, const int*, const void *, const int*, void *, const int*);
void ztbmv_( char*, char*, char*, const int*, const int*, const void *, const int*, void *, const int*);
void ztpmv_( char*, char*, char*, const int*, const void *, void *, const int*);
void ztrsv_( char*, char*, char*, const int*, const void *, const int*, void *, const int*);
void ztbsv_( char*, char*, char*, const int*, const int*, const void *, const int*, void *, const int*);
void ztpsv_( char*, char*, char*, const int*, const void *, void *,const int*);
void zgerc_( const int*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*);
void zgeru_( const int*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*);
void zher_(char*, const int*, const double *, const void *, const int*, void *, const int*);
void zher2_(char*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*);
void zhpr_(char*, const int*, const double *, const void *, const int*, void *);
void zhpr2_(char*, const int*, const double *, const void *, const int*, const void *, const int*, void *);
/***********/
/* Level 3 */
/***********/
/* Single Precision */
void sgemm_(char*, char*, const int*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ssymm_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ssyrk_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, float *, const int*);
void ssyr2k_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void strmm_(char*, char*, char*, char*, const int*, const int*, const float *, const float *, const int*, float *, const int*);
void strsm_(char*, char*, char*, char*, const int*, const int*, const float *, const float *, const int*, float *, const int*);
/* Double Precision */
void dgemm_(char*, char*, const int*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dsymm_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dsyrk_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, double *, const int*);
void dsyr2k_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dtrmm_(char*, char*, char*, char*, const int*, const int*, const double *, const double *, const int*, double *, const int*);
void dtrsm_(char*, char*, char*, char*, const int*, const int*, const double *, const double *, const int*, double *, const int*);
/* Single Complex Precision */
void cgemm_(char*, char*, const int*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void csymm_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void chemm_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void csyrk_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, float *, const int*);
void cherk_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, float *, const int*);
void csyr2k_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void cher2k_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ctrmm_(char*, char*, char*, char*, const int*, const int*, const float *, const float *, const int*, float *, const int*);
void ctrsm_(char*, char*, char*, char*, const int*, const int*, const float *, const float *, const int*, float *, const int*);
/* Double Complex Precision */
void zgemm_(char*, char*, const int*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void zsymm_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void zhemm_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void zsyrk_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, double *, const int*);
void zherk_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, double *, const int*);
void zsyr2k_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void zher2k_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void ztrmm_(char*, char*, char*, char*, const int*, const int*, const double *, const double *, const int*, double *, const int*);
void ztrsm_(char*, char*, char*, char*, const int*, const int*, const double *, const double *, const int*, double *, const int*);
}
"""
    return header
def ____gemm_code(check_ab, a_init, b_init):
    """Return the C source of a BLAS gemm kernel over the arrays _x, _y, _z.

    The arguments are C snippets that are pasted verbatim into the template:

    check_ab -- statements emitted right after the rank checks.
    a_init   -- expression initialising the local `a`, which is handed to
                cblas_sgemm/cblas_dgemm as alpha (the scale of dot(x, y)).
    b_init   -- expression initialising the local `b`, which is handed to
                cblas_sgemm/cblas_dgemm as beta (the scale of z).

    The generated code returns 0 on success.  It sets NotImplementedError and
    returns -1 when the inputs are of a kind the kernel does not handle
    (rank != 2, dtype other than float/double, mixed dtypes, strides that are
    non-positive or not a multiple of the element size, no unit stride), and
    sets ValueError and returns -1 when the matrix shapes do not agree.
    """
    # A literal '%' cannot appear in the template because the template is
    # itself %-formatted; the C modulo operator is substituted in instead.
    mod = '%'
    return """
        const char * error_string = NULL;

        int type_num = _x->descr->type_num;
        int type_size = _x->descr->elsize; // in bytes

        npy_intp* Nx = _x->dimensions;
        npy_intp* Ny = _y->dimensions;
        npy_intp* Nz = _z->dimensions;

        npy_intp* Sx = _x->strides;
        npy_intp* Sy = _y->strides;
        npy_intp* Sz = _z->strides;

        size_t sx_0, sx_1, sy_0, sy_1, sz_0, sz_1;

        int unit = 0;

        if (_x->nd != 2) goto _dot_execute_fallback;
        if (_y->nd != 2) goto _dot_execute_fallback;
        if (_z->nd != 2) goto _dot_execute_fallback;

        %(check_ab)s

        if ((_x->descr->type_num != PyArray_DOUBLE)
            && (_x->descr->type_num != PyArray_FLOAT))
            goto _dot_execute_fallback;

        if ((_y->descr->type_num != PyArray_DOUBLE)
            && (_y->descr->type_num != PyArray_FLOAT))
            goto _dot_execute_fallback;

        if ((_z->descr->type_num != PyArray_DOUBLE)
            && (_z->descr->type_num != PyArray_FLOAT))
            goto _dot_execute_fallback;

        if ((_x->descr->type_num != _y->descr->type_num)
            ||(_x->descr->type_num != _z->descr->type_num))
            goto _dot_execute_fallback;

        if ((Nx[0] != Nz[0]) || (Nx[1] != Ny[0]) || (Ny[1] != Nz[1]))
        {
            error_string = "Input dimensions do not agree";
            goto _dot_execute_fail;
        }
        if ((Sx[0] < 1) || (Sx[1] < 1) || (Sx[0] %(mod)s type_size) || (Sx[1] %(mod)s type_size)
           || (Sy[0] < 1) || (Sy[1] < 1) || (Sy[0] %(mod)s type_size) || (Sy[1] %(mod)s type_size)
           || (Sz[0] < 1) || (Sz[1] < 1) || (Sz[0] %(mod)s type_size) || (Sz[1] %(mod)s type_size))
        {
            goto _dot_execute_fallback;
        }

        /*
        encode the stride structure of _x,_y,_z into a single integer
        (one hex digit per matrix: 0 = unit stride along dimension 1,
        1 = unit stride along dimension 0, 2 = no unit stride)
        */
        unit |= ((Sx[1] == type_size) ? 0x0 : (Sx[0] == type_size) ? 0x1 : 0x2) << 0;
        unit |= ((Sy[1] == type_size) ? 0x0 : (Sy[0] == type_size) ? 0x1 : 0x2) << 4;
        unit |= ((Sz[1] == type_size) ? 0x0 : (Sz[0] == type_size) ? 0x1 : 0x2) << 8;

        /* create appropriate strides for malformed matrices that are row or column
         * vectors
         */
        sx_0 = (Nx[0] > 1) ? Sx[0]/type_size : Nx[1];
        sx_1 = (Nx[1] > 1) ? Sx[1]/type_size : Nx[0];
        sy_0 = (Ny[0] > 1) ? Sy[0]/type_size : Ny[1];
        sy_1 = (Ny[1] > 1) ? Sy[1]/type_size : Ny[0];
        sz_0 = (Nz[0] > 1) ? Sz[0]/type_size : Nz[1];
        sz_1 = (Nz[1] > 1) ? Sz[1]/type_size : Nz[0];

        switch (type_num)
        {
            case PyArray_FLOAT:
            {
                #define REAL float
                float a = %(a_init)s;
                float b = %(b_init)s;

                float* x = (float*)PyArray_DATA(_x);
                float* y = (float*)PyArray_DATA(_y);
                float* z = (float*)PyArray_DATA(_z);

                switch(unit)
                {
                    case 0x000: cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_0); break;
                    case 0x001: cblas_sgemm(CblasRowMajor, CblasTrans,   CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_0); break;
                    case 0x010: cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,   Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_0); break;
                    case 0x011: cblas_sgemm(CblasRowMajor, CblasTrans,   CblasTrans,   Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_0); break;
                    case 0x100: cblas_sgemm(CblasColMajor, CblasTrans,   CblasTrans,   Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_1); break;
                    case 0x101: cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans,   Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_1); break;
                    case 0x110: cblas_sgemm(CblasColMajor, CblasTrans,   CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_1); break;
                    case 0x111: cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_1); break;
                    default: goto _dot_execute_fallback;
                };
                #undef REAL
            }
            break;
            case PyArray_DOUBLE:
            {
                #define REAL double
                double a = %(a_init)s;
                double b = %(b_init)s;

                double* x = (double*)PyArray_DATA(_x);
                double* y = (double*)PyArray_DATA(_y);
                double* z = (double*)PyArray_DATA(_z);

                switch(unit)
                {
                    case 0x000: cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_0); break;
                    case 0x001: cblas_dgemm(CblasRowMajor, CblasTrans,   CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_0); break;
                    case 0x010: cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans,   Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_0); break;
                    case 0x011: cblas_dgemm(CblasRowMajor, CblasTrans,   CblasTrans,   Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_0); break;
                    case 0x100: cblas_dgemm(CblasColMajor, CblasTrans,   CblasTrans,   Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_1); break;
                    case 0x101: cblas_dgemm(CblasColMajor, CblasNoTrans, CblasTrans,   Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_1); break;
                    case 0x110: cblas_dgemm(CblasColMajor, CblasTrans,   CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_1); break;
                    case 0x111: cblas_dgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_1); break;
                    default: goto _dot_execute_fallback;
                };
                #undef REAL
            }
            break;
        }

        return 0; //success!

        _dot_execute_fallback:
            PyErr_SetString(PyExc_NotImplementedError,
                "dot->execute() fallback");
            return -1;

        _dot_execute_fail:
            /* always leave a Python exception set before returning -1 */
            if (error_string == NULL)
                error_string = "dot->execute() cant run on these inputs";
            PyErr_SetString(PyExc_ValueError, error_string);
            return -1;

        /* v 1 */
    """ % {'check_ab': check_ab, 'a_init': a_init, 'b_init': b_init, 'mod': mod}
# currently unused, preferring the fallback method (throwing
# NotImplementedError) for when gemm won't work.
#
# Plain triple-loop C++ implementation of z = b * z + a * dot(x, y) for
# arbitrarily strided operands: (xm, xn), (ym, yn) and (zm, zn) are the
# element strides of x, y and z along their two dimensions.
_templated_memaligned_gemm = """
template <typename Ta, typename Tx, typename Ty, typename Tb, typename Tz>
int general_gemm(int zM, int zN, int xN,
    Ta a,
    Tx * x, int xm, int xn,
    Ty * y, int ym, int yn,
    Tb b,
    Tz * z, int zm, int zn)
{
    for (int i = 0; i < zM; ++i)
    {
        for (int j = 0; j < zN; ++j)
        {
            Tz zij = 0.0;
            for (int k = 0; k < xN; ++k)
            {
                zij += x[i*xm+k*xn] * y[k*ym+j*yn];
            }
            z[i * zm + j * zn] *= b;
            z[i * zm + j * zn] += a * zij;
        }
    }
    return 0;
}
"""
from basic import _scal_elemwise #, _transpose_inplace from .basic import _scal_elemwise #, _transpose_inplace
from .. import scalar as scal from .. import scalar as scal
import elemwise import elemwise
from .. import printing from .. import printing
......
"""Tensor optimizations addressing the ops in basic.py
"""
# TODO: intelligent merge for mul/add # TODO: intelligent merge for mul/add
# TODO: 0*x -> 0 # TODO: 0*x -> 0
...@@ -30,29 +31,6 @@ def in2out(*local_opts, **kwargs): ...@@ -30,29 +31,6 @@ def in2out(*local_opts, **kwargs):
**kwargs) **kwargs)
# gemm: (d,a,b,c,s) -> d = d*s + a*dot(b,c)
# Transforms d -= a * dot(b, c) into gemm(d, -a, b, c, 1.0)
#
# The matched expression is sub(d, mul(A, dot(b, c))), where A is 'a'
# broadcast from 0-d to 2-d by an inplace DimShuffle((), ['x', 'x']).
# Only that DimShuffle sub-pattern is declared with
# allow_multiple_clients = True; the pattern as a whole uses False.
# NOTE(review): presumably this lets the broadcast scalar be shared with other
# graph nodes while the mul/dot intermediates must not be -- confirm against
# gof.PatternSub.
gemm_pattern_1 = gof.PatternSub((T.sub,
                                 'd',
                                 (T.mul,
                                  dict(pattern = (T.DimShuffle((), ['x', 'x'], inplace = True), 'a'),
                                       allow_multiple_clients = True),
                                  (T.dot, 'b', 'c'))),
                                (T.gemm, 'd', (T.neg, 'a'), 'b', 'c', T.constant(1.0)),
                                allow_multiple_clients = False)
# gemm: (d,a,b,c,s) -> d = d*s + a*dot(b,c)
# Transforms dot(a, b) into gemm(zeros(2)(hstack(shape(a)[:1], shape(b)[1:])), 1.0, a, b, 1.0)
# The construction of the 'gemm' node may fail if, for example, a and b are not both matrices.
#
# The destination passed to gemm is a fresh T.Zeros(2) output whose shape is
# stacked from element 0 of shape(a) (slice(0, 1)) and element 1 of shape(b)
# (slice(1, 2)).
dot_to_gemm = gof.PatternSub((T.dot, 'a', 'b'),
                             (T.gemm, (T.Zeros(2),
                                       (T.stack,
                                        (T.Subtensor([slice(0, 1)]), (T.shape, 'a')),
                                        (T.Subtensor([slice(1, 2)]), (T.shape, 'b')))),
                              T.constant(1.0), 'a', 'b', T.constant(1.0)),
                             allow_multiple_clients = False)
def _insert_inplace_optimizer(env): def _insert_inplace_optimizer(env):
""" """
...@@ -92,12 +70,6 @@ def _insert_inplace_optimizer(env): ...@@ -92,12 +70,6 @@ def _insert_inplace_optimizer(env):
break break
insert_inplace_optimizer = gof.optimizer(_insert_inplace_optimizer) insert_inplace_optimizer = gof.optimizer(_insert_inplace_optimizer)
# Inplace pass: a sequence that first applies gemm_pattern_1 (wrapped by
# out2in) and then insert_inplace_optimizer; failures inside the sequence are
# handed to gof.warn.
inplace_optimizer = gof.InplaceOptimizer(
    gof.SeqOptimizer(out2in(gemm_pattern_1),
                     insert_inplace_optimizer,
                     failure_callback = gof.warn))
# Registered in the optimizer database under the name 'inplace_opt' with the
# tags 'fast_run' and 'inplace'.
# NOTE(review): 99 looks like the ordering position inside optdb -- confirm.
compile.optdb.register('inplace_opt', inplace_optimizer, 99, 'fast_run', 'inplace')
def register_canonicalize(lopt, *tags, **kwargs): def register_canonicalize(lopt, *tags, **kwargs):
name = (kwargs and kwargs.pop('name')) or lopt.__name__ name = (kwargs and kwargs.pop('name')) or lopt.__name__
...@@ -625,6 +597,38 @@ def local_pow_specialize(node): ...@@ -625,6 +597,38 @@ def local_pow_specialize(node):
return False return False
register_specialize(local_pow_specialize) register_specialize(local_pow_specialize)
@gof.local_optimizer([T.mul])
def local_mul_specialize(node):
    """Remove constant factors of 1 and -1 from a `mul` node.

    Every input whose constant value (as reported by
    local_mul_canonizer.get_constant) is all 1.0 is dropped, every input that
    is all -1.0 is dropped and flips the sign of the result, and an input that
    is all 0.0 short-circuits the whole product to that input.

    Returns a one-element list holding the replacement variable, or False
    when `node` is not a `mul` or none of its inputs could be removed.
    """
    # here, we are past the point of canonicalization, so we don't want to
    # put in un-necessary fills.
    if node.op != T.mul:
        return False

    neg = False
    new_inputs = []
    for inp in node.inputs:
        y = local_mul_canonizer.get_constant(inp)
        if N.all(y == 1.0):
            continue
        elif N.all(y == -1.0):
            neg ^= True #toggles
        elif N.all(y == 0.0):
            # NOTE(review): the zero input is returned as-is; confirm that its
            # type/broadcast pattern is acceptable in place of the mul output.
            return [inp]
        else:
            new_inputs.append(inp)

    if len(new_inputs) == len(node.inputs):
        # nothing was stripped: report "no change" explicitly
        return False

    if len(new_inputs) == 0:
        # every factor was +/-1, so the product is the constant +/-1
        newval = -1 if neg else 1
        return [T.TensorConstant(T.Tensor(dtype=node.outputs[0].type.dtype,
                                          broadcastable = ()), N.asarray(newval))]
    if len(new_inputs) == 1:
        return [-new_inputs[0]] if neg else new_inputs
    return [-T.mul(*new_inputs)] if neg else [T.mul(*new_inputs)]
register_specialize(local_mul_specialize)
if 0: #TODO: replace this with a c version of any InplaceDimShuffle if 0: #TODO: replace this with a c version of any InplaceDimShuffle
class _TransposeInplace(T.Op): class _TransposeInplace(T.Op):
view_map = {0: [0]} view_map = {0: [0]}
...@@ -813,250 +817,10 @@ def constant_folding(node): ...@@ -813,250 +817,10 @@ def constant_folding(node):
register_canonicalize(constant_folding) register_canonicalize(constant_folding)
################# inplace_matrix_transpose = T.DimShuffle([False,False], [1,0], inplace=True)
# BLAS-related local_transposed_dot = gof.PatternSub((inplace_matrix_transpose, (T.dot, 'x', 'y')),
################# (T.dot, (inplace_matrix_transpose, 'y'), (inplace_matrix_transpose, 'x')))
import blas register_canonicalize(local_transposed_dot, name='local_transposed_dot')
class _Dot22(gof.Op):
"""Compute a matrix-matrix product.
This is a specialization of the more general Dot()
"""
def make_node(self, x, y):
assert x.type in T.float_matrix_types #makes sure x is a matrix
assert y.type == x.type #makes sure y is a matrix
bz = [x.type.broadcastable[0], y.type.broadcastable[1]]
outputs = [T.tensor(x.type.dtype, bz)]
return gof.Apply(self, [x,y], outputs)
def perform(self, node, (x, y), (z, )):
try:
z[0] = numpy.asarray(numpy.dot(x, y))
except ValueError, e:
# The error raised by numpy has no shape information, we mean to add that
e.args = e.args + (x.shape, y.shape)
raise
def __str__(self):
return "_dot22"
def c_support_code(self):
#return blas.cblas_header_text()
mod_str = """
#ifndef MOD
#define MOD %
#endif
"""
return blas.blas_proto() + mod_str
def c_headers(self):
return ['<iostream>']
def c_libraries(self):
return blas.ldflags()
def c_code(self, node, name, (_x, _y), (_z, ), sub):
return """
int unit = 0;
int type_num = %(_x)s->descr->type_num;
int type_size = %(_x)s->descr->elsize; // in bytes
npy_intp* Nx = %(_x)s->dimensions;
npy_intp* Ny = %(_y)s->dimensions;
npy_intp* Nz = 0; //%(_z)s->dimensions;
npy_intp* Sx = %(_x)s->strides;
npy_intp* Sy = %(_y)s->strides;
npy_intp* Sz = 0;//%(_z)s->strides;
//strides for x, y, z in dimensions 0, 1
int sx_0, sx_1, sy_0, sy_1, sz_0, sz_1;
if ((NULL == %(_z)s)
|| (%(_z)s->dimensions[0] != %(_x)s->dimensions[0])
|| (%(_z)s->dimensions[1] != %(_y)s->dimensions[1]))
{
if (NULL != %(_z)s) Py_XDECREF(%(_z)s);
npy_intp dims[2];
dims[0] = %(_x)s->dimensions[0];
dims[1] = %(_y)s->dimensions[1];
%(_z)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, type_num_%(_x)s);
if(!%(_z)s) {
PyErr_SetString(PyExc_MemoryError, "failed to alloc dot22 output");
%(fail)s
}
}
Nz = %(_z)s->dimensions;
Sz = %(_z)s->strides;
if (%(_x)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(x) != 2"); %(fail)s;}
if (%(_y)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(y) != 2"); %(fail)s;}
if (%(_z)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(z) != 2"); %(fail)s;}
if ((%(_x)s->descr->type_num != PyArray_DOUBLE)
&& (%(_x)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float"); %(fail)s;}
if ((%(_y)s->descr->type_num != PyArray_DOUBLE)
&& (%(_y)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float"); %(fail)s;}
if ((%(_z)s->descr->type_num != PyArray_DOUBLE)
&& (%(_z)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float"); %(fail)s;}
if ((%(_x)s->descr->type_num != %(_y)s->descr->type_num)
||(%(_x)s->descr->type_num != %(_z)s->descr->type_num))
{ PyErr_SetString(PyExc_NotImplementedError, "type(z), type(y), type(z) are not all the same"); %(fail)s; }
if ((Nx[0] != Nz[0]) || (Nx[1] != Ny[0]) || (Ny[1] != Nz[1]))
{
PyErr_SetString(PyExc_ValueError, "Input dimensions do not agree");
%(fail)s;
}
if ((Sx[0] < 1) || (Sx[1] < 1) || (Sx[0] MOD type_size) || (Sx[1] MOD type_size)
|| (Sy[0] < 1) || (Sy[1] < 1) || (Sy[0] MOD type_size) || (Sy[1] MOD type_size)
|| (Sz[0] < 1) || (Sz[1] < 1) || (Sz[0] MOD type_size) || (Sz[1] MOD type_size))
{
PyErr_SetString(PyExc_ValueError, "stride is not multiple of element size"); %(fail)s;
}
/*
encode the stride structure of _x,_y,_z into a single integer
*/
unit |= ((Sx[1] == type_size) ? 0x0 : (Sx[0] == type_size) ? 0x1 : 0x2) << 8;
unit |= ((Sy[1] == type_size) ? 0x0 : (Sy[0] == type_size) ? 0x1 : 0x2) << 4;
unit |= ((Sz[1] == type_size) ? 0x0 : (Sz[0] == type_size) ? 0x1 : 0x2) << 0;
/* create appropriate strides for malformed matrices that are row or column
* vectors
*/
sx_0 = (Nx[0] > 1) ? Sx[0]/type_size : Nx[1];
sx_1 = (Nx[1] > 1) ? Sx[1]/type_size : Nx[0];
sy_0 = (Ny[0] > 1) ? Sy[0]/type_size : Ny[1];
sy_1 = (Ny[1] > 1) ? Sy[1]/type_size : Ny[0];
sz_0 = (Nz[0] > 1) ? Sz[0]/type_size : Nz[1];
sz_1 = (Nz[1] > 1) ? Sz[1]/type_size : Nz[0];
switch (type_num)
{
case PyArray_FLOAT:
{
float a = 1.0;
float b = 0.0;
float* x = (float*)PyArray_DATA(%(_x)s);
float* y = (float*)PyArray_DATA(%(_y)s);
float* z = (float*)PyArray_DATA(%(_z)s);
char N = 'N';
char T = 'T';
int Nz0 = Nz[0], Nz1 = Nz[1], Nx1 = Nx[1];
//std::cerr << (unit/256) MOD 16 << (unit / 16) MOD 16 << unit MOD 16<< '\\n';
switch(unit)
{
case 0x000: sgemm_(&N, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_0, &b, z, &sz_0); break;
case 0x100: sgemm_(&N, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_1, &b, z, &sz_0); break;
case 0x010: sgemm_(&T, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_0, &b, z, &sz_0); break;
case 0x110: sgemm_(&T, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_1, &b, z, &sz_0); break;
case 0x001: sgemm_(&T, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_0, &b, z, &sz_1); break;
case 0x101: sgemm_(&N, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_0, &b, z, &sz_1); break;
case 0x011: sgemm_(&T, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_1, &b, z, &sz_1); break;
case 0x111: sgemm_(&N, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_1, &b, z, &sz_1); break;
default: PyErr_SetString(PyExc_ValueError, "some matrix has no unit stride"); %(fail)s;
};
#undef REAL
}
break;
case PyArray_DOUBLE:
{
double a = 1.0;
double b = 0.0;
double* x = (double*)PyArray_DATA(%(_x)s);
double* y = (double*)PyArray_DATA(%(_y)s);
double* z = (double*)PyArray_DATA(%(_z)s);
char N = 'N';
char T = 'T';
int Nz0 = Nz[0], Nz1 = Nz[1], Nx1 = Nx[1];
//std::cerr << (unit/256) MOD 16 << (unit / 16) MOD 16 << unit MOD 16<< '\\n';
switch(unit)
{
case 0x000: dgemm_(&N, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_0, &b, z, &sz_0); break;
case 0x100: dgemm_(&N, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_1, &b, z, &sz_0); break;
case 0x010: dgemm_(&T, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_0, &b, z, &sz_0); break;
case 0x110: dgemm_(&T, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_1, &b, z, &sz_0); break;
case 0x001: dgemm_(&T, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_0, &b, z, &sz_1); break;
case 0x101: dgemm_(&N, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_0, &b, z, &sz_1); break;
case 0x011: dgemm_(&T, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_1, &b, z, &sz_1); break;
case 0x111: dgemm_(&N, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_1, &b, z, &sz_1); break;
default: PyErr_SetString(PyExc_ValueError, "some matrix has no unit stride"); %(fail)s;
};
#undef REAL
}
break;
}
""" % dict(locals(), **sub)
_dot22 = _Dot22()
@gof.local_optimizer([T.dot])
def local_dot_to_dot22(node):
    """Replace dot(x, y) by _dot22(x, y) when _dot22 can accept the inputs.

    The rewrite is only attempted under the same conditions that
    _Dot22.make_node asserts (x is one of T.float_matrix_types and y has the
    same type), so an unsupported dot is left alone instead of raising an
    AssertionError from inside the optimizer.

    Returns a one-element list with the replacement, or False.
    """
    if node.op == T.dot:
        x, y = node.inputs
        if x.type in T.float_matrix_types and y.type == x.type:
            return [_dot22(x, y)]
    return False
register_specialize(local_dot_to_dot22)
@gof.local_optimizer([T.sub])
def local_sub_to_gemm(node):
    """This is a massive beast for recognizing all the ways that a subtraction could be
    replaced by a GEMM

    Recognised shapes of `subleft - subright` (c is a size-1 constant):

        subleft - dot(a, b)
        subleft - c * dot(a, b)
        subleft - dot(a, b) * c
        subleft - c * (dot(u, v) + dot(x, y).T)    (inplace transpose only)

    Returns a one-element list holding the gemm replacement, or False when
    nothing matches.  Only binary mul/add nodes are inspected; nodes with any
    other number of inputs are left untouched.
    """
    if node.op == T.sub:
        subleft, subright = node.inputs
        #EXPRESSION: subleft - subright
        if subright.owner and (subright.owner.op == _dot22):
            dotleft, dotright = subright.owner.inputs
            return [T.gemm(subleft, -1.0, dotleft, dotright, 1.0)]
        # mul is n-ary, so make sure there are exactly two factors before
        # unpacking them
        if subright.owner and (subright.owner.op == T.mul) \
                and len(subright.owner.inputs) == 2:
            mulleft, mulright = subright.owner.inputs
            #EXPRESSION: subleft - (mulleft * mulright)
            #TODO: we actually want to get any scalar here, not necessarily a constant
            mulleft_const = local_mul_canonizer.get_constant(mulleft)
            if mulleft_const is not None and mulleft_const.size == 1:
                mulleft_const = mulleft_const.flatten()[0]
                #EXPRESSION: subleft - (mulleft_const * ?)
                if mulright.owner and (mulright.owner.op == T.add) \
                        and len(mulright.owner.inputs) == 2:
                    #EXPRESSION: subleft - (mulleft_const * (? + ?))
                    addleft, addright = mulright.owner.inputs
                    if addright.owner and addright.owner.op == T.DimShuffle([False,False], [1,0]):
                        #EXPRESSION: subleft - (mulleft_const * (? + ?.T))
                        # the non-inplace transpose is deliberately unsupported
                        raise NotImplementedError()
                    if addright.owner and addright.owner.op == T.DimShuffle([False,False], [1,0], inplace=True):
                        #EXPRESSION: subleft - (mulleft_const * (? + ?.T))
                        transposed = addright.owner.inputs[0]
                        if transposed.owner and transposed.owner.op == _dot22:
                            x, y = transposed.owner.inputs
                            #EXPRESSION: subleft - (mulleft_const * (addleft + dot(x, y).T))
                            if addleft.owner and addleft.owner.op == _dot22:
                                u, v = addleft.owner.inputs
                                #EXPRESSION: subleft - (mulleft_const * (dot(u,v) + dot(x, y).T))
                                # dot(x, y).T == dot(y.T, x.T), hence the swapped
                                # operands in the inner gemm
                                return [T.gemm(
                                    T.gemm(subleft, -mulleft_const, y.T, x.T, 1.0),
                                    -mulleft_const, u, v, 1.0)]
                if mulright.owner and (mulright.owner.op == _dot22):
                    dotleft, dotright = mulright.owner.inputs
                    #EXPRESSION: subleft - (mulleft_const * dot(dotleft, dotright))
                    return [T.gemm(subleft, -mulleft_const, dotleft, dotright, 1.0)]
            mulright_const = local_mul_canonizer.get_constant(mulright)
            if mulright_const is not None and mulright_const.size == 1:
                mulright_const = mulright_const.flatten()[0]
                #EXPRESSION: subleft - (? * mulright_const)
                if mulleft.owner and (mulleft.owner.op == _dot22):
                    dotleft, dotright = mulleft.owner.inputs
                    #EXPRESSION: subleft - (dot(dotleft, dotright) * mulright_const)
                    return [T.gemm(subleft, -mulright_const, dotleft, dotright, 1.0)]
    return False
register_specialize(local_sub_to_gemm)
# def _math_optimizer(): # def _math_optimizer():
......
...@@ -1356,179 +1356,6 @@ class t_dot(unittest.TestCase): ...@@ -1356,179 +1356,6 @@ class t_dot(unittest.TestCase):
#verify_grad(self, dot, [self.rand(), self.rand(2)]) #verify_grad(self, dot, [self.rand(), self.rand(2)])
#verify_grad(self, dot, [self.rand(), self.rand(2,5)]) #verify_grad(self, dot, [self.rand(), self.rand(2,5)])
class t_gemm(unittest.TestCase):
    """Checks gemm against a numpy reference, including its in-place
    behaviour, its argument validation and its handling of strided and
    transposed operands."""
    def setUp(self):
        # Fixed seed so rand() yields the same matrices on every run; the
        # debug switches are reset in case an earlier test turned them on.
        numpy.random.seed(44)
        _approx_eq.debug = 0
        Gemm.debug = False

    @staticmethod
    def _gemm(z,a,x,y,b):
        """Reference result computed with numpy: b * z + a * dot(x, y).

        a and b must be 0-d arrays."""
        assert a.shape == ()
        assert b.shape == ()
        return b * z + a * numpy.dot(x,y)
    @staticmethod
    def rand(*args):
        """Random array of the given shape drawn with numpy.random.rand."""
        return numpy.random.rand(*args)

    def cmp(self, z, a, x, y, b):
        """Run gemm(z, a, x, y, b) under the 'c|py', 'c' and 'py' linkers
        (optimizer disabled) and compare each run with the numpy reference."""
        def cmp_linker(z, a, x, y, b, l):
            z,a,x,y,b = [numpy.asarray(p) for p in z,a,x,y,b]
            z_orig = z.copy()
            tz,ta,tx,ty,tb = [as_tensor(p).type() for p in z,a,x,y,b]

            f = function([tz,ta,tx,ty,tb], gemm(tz,ta,tx,ty,tb), mode=compile.Mode(optimizer = None, linker = l))
            new_z = f(z,a,x,y,b)
            z_after = self._gemm(z_orig, a, x, y, b)

            # gemm must hand back the very array it was given (in-place)
            self.failUnless(z is new_z)
            #print z_orig, z_after, z, type(z_orig), type(z_after), type(z)
            #_approx_eq.debug = 1
            self.failUnless(_approx_eq(z_after, z))
            if a == 0.0 and b == 1.0:
                # identity update: z is expected to be unchanged
                return
            else:
                self.failIf(numpy.all(z_orig == z))

        # each linker gets its own copy because z is overwritten
        cmp_linker(copy(z), a, x, y, b, 'c|py')
        cmp_linker(copy(z), a, x, y, b, 'c')
        cmp_linker(copy(z), a, x, y, b, 'py')

    # test0a, test0, test2: passing operands that are not matrices must raise
    # a ValueError whose first argument is Gemm.E_rank; anything else fails.
    def test0a(self):
        Gemm.debug = True
        try:
            g = gemm([1.], 1., [1.], [1.], 1.)
        except ValueError, e:
            if e[0] is Gemm.E_rank:
                return
        self.fail()

    def test0(self):
        try:
            self.cmp(1., 0., 1.0, 1.0, 1.0)
        except ValueError, e:
            if e[0] is Gemm.E_rank:
                return
        self.fail()

    def test2(self):
        try:
            self.cmp(2., 1.0, [3,2,1.], [[1],[2],[3.]], 1.0)
        except ValueError, e:
            self.failUnless(e[0] == Gemm.E_rank)
            return
        self.fail()
    # test4 .. test12: (3,4) += (3,5) x (5,4) for various values of the
    # second (a) and last (b) scalar arguments.
    def test4(self):
        self.cmp(self.rand(3,4), 1.0, self.rand(3,5), self.rand(5,4), 0.0)
    def test5(self): self.cmp(self.rand(3,4), 1.0,
            self.rand(3,5), self.rand(5,4), 1.0)
    def test6(self): self.cmp(self.rand(3,4), 1.0,
            self.rand(3,5), self.rand(5,4), -1.0)
    def test7(self): self.cmp(self.rand(3,4), 0.0,
            self.rand(3,5), self.rand(5,4), 0.0)
    def test8(self): self.cmp(self.rand(3,4), 0.0,
            self.rand(3,5), self.rand(5,4), 0.6)
    def test9(self): self.cmp(self.rand(3,4), 0.0,
            self.rand(3,5), self.rand(5,4), -1.0)
    def test10(self):
        _approx_eq.debug = 1
        self.cmp(self.rand(3,4), -1.0, self.rand(3,5), self.rand(5,4), 0.0)
    def test11(self): self.cmp(self.rand(3,4), -1.0,
            self.rand(3,5), self.rand(5,4), 1.0)
    def test12(self): self.cmp(self.rand(3,4), -1.0,
            self.rand(3,5), self.rand(5,4), -1.0)

    # test_destroy_map0 .. 3: using Z (or an inplace transpose of Z) both as
    # the destination and as a dot operand must raise ValueError with
    # Gemm.E_z_uniq as first argument.
    def test_destroy_map0(self):
        """test that only first input can be overwritten"""
        Z = as_tensor(self.rand(2,2))
        try:
            gemm(Z, 1.0, Z, Z, 1.0)
        except ValueError, e:
            if e[0] == Gemm.E_z_uniq:
                return
        self.fail()
    def test_destroy_map1(self):
        """test that only first input can be overwritten"""
        Z = as_tensor(self.rand(2,2))
        A = as_tensor(self.rand(2,2))
        try:
            gemm(Z, 1.0, A, inplace.transpose_inplace(Z), 1.0)
        except ValueError, e:
            if e[0] == Gemm.E_z_uniq:
                return
        self.fail()
    def test_destroy_map2(self):
        """test that only first input can be overwritten"""
        Z = as_tensor(self.rand(2,2))
        A = as_tensor(self.rand(2,2))
        try:
            gemm(Z, 1.0, inplace.transpose_inplace(Z), A, 1.0)
        except ValueError, e:
            if e[0] == Gemm.E_z_uniq:
                return
        self.fail()
    def test_destroy_map3(self):
        """test that only first input can be overwritten"""
        Z = as_tensor(self.rand(2,2))
        A = as_tensor(self.rand(2,2))
        try:
            gemm(Z, 1.0, Z, A, 1.0)
        except ValueError, e:
            if e[0] == Gemm.E_z_uniq:
                return
        self.fail()

    def test_destroy_map4(self):
        """test that dot args can be aliased"""
        Z = value(self.rand(2,2))
        A = value(self.rand(2,2))
        eval_outputs([gemm(Z, 1.0, A, A, 1.0)])
        eval_outputs([gemm(Z, 1.0, A, A.T, 1.0)])

    def test_transposes(self):
        # three square matrices which are not contiguous
        A = self.rand(4,5)[:,:4]
        B = self.rand(4,5)[:,:4]
        C = self.rand(4,5)[:,:4]

        def t(z,x,y,a=1.0, b=0.0,l='c|py',dt='float64'):
            # Compute the reference first (z is overwritten below), then run
            # gemm on (z, x, y) and again on the transposed problem
            # (z.T, y.T, x.T); both runs must leave the reference in z.
            z,a,x,y,b = [numpy.asarray(p,dtype=dt) for p in z,a,x,y,b]
            z_orig = z.copy()
            z_after = self._gemm(z, a, x, y, b)

            tz,ta,tx,ty,tb = [value(p) for p in z,a,x,y,b]

            f = function([tz,ta,tx,ty,tb], gemm(tz,ta,tx,ty,tb), mode = compile.Mode(optimizer = None, linker=l))
            f(z, a, x, y, b)
            self.failUnless(_approx_eq(z_after, z), (z_orig, z_after, z, z_after - z))
            f(z.T, a, y.T, x.T, b)
            self.failUnless(_approx_eq(z_after, z))

        t(C,A,B)
        t(C.T, A, B)
        t(C, A.T, B, dt='float32')
        t(C, A, B.T)
        t(C.T, A.T, B)
        t(C, A.T, B.T, dt='float32')
        t(C.T, A, B.T)
        t(C.T, A.T, B.T, dt='float32')

        t(C, A[:,:2], B[:2, :])
        t(C.T, A[:,:2], B[:2, :], dt='float32')
        t(C, A[:2,:].T, B[:2, :])
        t(C.T, A[:2,:].T, B[:2, :], dt='float32')
        t(C, A[:2,:].T, B[:, :2].T)
        t(C.T, A[:2,:].T, B[:, :2].T)

        # (2,4) x (2,4) cannot be multiplied: a ValueError whose message
        # contains 'aligned' is the only accepted outcome
        try:
            t(C.T, A[:2,:], B[:, :2].T)
        except ValueError, e:
            if e[0].find('aligned') >= 0:
                return
        self.fail()
class T_tensorfromscalar(unittest.TestCase): class T_tensorfromscalar(unittest.TestCase):
def test0(self): def test0(self):
s = scal.constant(56) s = scal.constant(56)
......
import theano.tensor as T
from ..gof import Env
import numpy
from theano.tensor.blas import *
from theano.tensor.blas import _as_scalar, _dot22
from unittest import TestCase
from copy import copy
from theano import In, Out
from .test_basic import (_approx_eq, as_tensor, function,
compile, value, constant, inplace, eval_outputs)
class t_gemm(TestCase):
    """This test suite is supposed to establish that gemm works as it is supposed to."""
    def setUp(self):
        # Fixed seed so rand() yields the same matrices on every run; the
        # debug switches are reset in case an earlier test turned them on.
        numpy.random.seed(44)
        _approx_eq.debug = 0
        Gemm.debug = False

    @staticmethod
    def _gemm(z,a,x,y,b):
        """Reference result computed with numpy: b * z + a * dot(x, y).

        a and b must be 0-d arrays."""
        assert a.shape == ()
        assert b.shape == ()
        return b * z + a * numpy.dot(x,y)
    @staticmethod
    def rand(*args):
        """Random array of the given shape drawn with numpy.random.rand."""
        return numpy.random.rand(*args)

    def cmp(self, z, a, x, y, b):
        """Run gemm(z, a, x, y, b) under the 'c|py', 'c' and 'py' linkers
        (optimizer disabled) and compare each run with the numpy reference."""
        def cmp_linker(z, a, x, y, b, l):
            z,a,x,y,b = [numpy.asarray(p) for p in z,a,x,y,b]
            z_orig = z.copy()
            tz,ta,tx,ty,tb = [as_tensor(p).type() for p in z,a,x,y,b]

            f = function([tz,ta,tx,ty,tb], gemm(tz,ta,tx,ty,tb), mode=compile.Mode(optimizer = None, linker = l))
            new_z = f(z,a,x,y,b)
            z_after = self._gemm(z_orig, a, x, y, b)

            # gemm must hand back the very array it was given (in-place)
            self.failUnless(z is new_z)
            #print z_orig, z_after, z, type(z_orig), type(z_after), type(z)
            #_approx_eq.debug = 1
            self.failUnless(_approx_eq(z_after, z))
            if a == 0.0 and b == 1.0:
                # identity update: z is expected to be unchanged
                return
            else:
                self.failIf(numpy.all(z_orig == z))

        # each linker gets its own copy because z is overwritten
        cmp_linker(copy(z), a, x, y, b, 'c|py')
        cmp_linker(copy(z), a, x, y, b, 'c')
        cmp_linker(copy(z), a, x, y, b, 'py')

    # test0a, test0, test2: passing operands that are not matrices must raise
    # a ValueError whose first argument is Gemm.E_rank; anything else fails.
    def test0a(self):
        Gemm.debug = True
        try:
            g = gemm([1.], 1., [1.], [1.], 1.)
        except ValueError, e:
            if e[0] is Gemm.E_rank:
                return
        self.fail()

    def test0(self):
        try:
            self.cmp(1., 0., 1.0, 1.0, 1.0)
        except ValueError, e:
            if e[0] is Gemm.E_rank:
                return
        self.fail()

    def test2(self):
        try:
            self.cmp(2., 1.0, [3,2,1.], [[1],[2],[3.]], 1.0)
        except ValueError, e:
            self.failUnless(e[0] == Gemm.E_rank)
            return
        self.fail()
    # test4 .. test12: (3,4) += (3,5) x (5,4) for various values of the
    # second (a) and last (b) scalar arguments.
    def test4(self):
        self.cmp(self.rand(3,4), 1.0, self.rand(3,5), self.rand(5,4), 0.0)
    def test5(self): self.cmp(self.rand(3,4), 1.0,
            self.rand(3,5), self.rand(5,4), 1.0)
    def test6(self): self.cmp(self.rand(3,4), 1.0,
            self.rand(3,5), self.rand(5,4), -1.0)
    def test7(self): self.cmp(self.rand(3,4), 0.0,
            self.rand(3,5), self.rand(5,4), 0.0)
    def test8(self): self.cmp(self.rand(3,4), 0.0,
            self.rand(3,5), self.rand(5,4), 0.6)
    def test9(self): self.cmp(self.rand(3,4), 0.0,
            self.rand(3,5), self.rand(5,4), -1.0)
    def test10(self):
        _approx_eq.debug = 1
        self.cmp(self.rand(3,4), -1.0, self.rand(3,5), self.rand(5,4), 0.0)
    def test11(self): self.cmp(self.rand(3,4), -1.0,
            self.rand(3,5), self.rand(5,4), 1.0)
    def test12(self): self.cmp(self.rand(3,4), -1.0,
            self.rand(3,5), self.rand(5,4), -1.0)

    # test_destroy_map0 .. 3: using Z (or an inplace transpose of Z) both as
    # the destination and as a dot operand must raise ValueError with
    # Gemm.E_z_uniq as first argument.
    def test_destroy_map0(self):
        """test that only first input can be overwritten"""
        Z = as_tensor(self.rand(2,2))
        try:
            gemm(Z, 1.0, Z, Z, 1.0)
        except ValueError, e:
            if e[0] == Gemm.E_z_uniq:
                return
        self.fail()
    def test_destroy_map1(self):
        """test that only first input can be overwritten"""
        Z = as_tensor(self.rand(2,2))
        A = as_tensor(self.rand(2,2))
        try:
            gemm(Z, 1.0, A, inplace.transpose_inplace(Z), 1.0)
        except ValueError, e:
            if e[0] == Gemm.E_z_uniq:
                return
        self.fail()
    def test_destroy_map2(self):
        """test that only first input can be overwritten"""
        Z = as_tensor(self.rand(2,2))
        A = as_tensor(self.rand(2,2))
        try:
            gemm(Z, 1.0, inplace.transpose_inplace(Z), A, 1.0)
        except ValueError, e:
            if e[0] == Gemm.E_z_uniq:
                return
        self.fail()
    def test_destroy_map3(self):
        """test that only first input can be overwritten"""
        Z = as_tensor(self.rand(2,2))
        A = as_tensor(self.rand(2,2))
        try:
            gemm(Z, 1.0, Z, A, 1.0)
        except ValueError, e:
            if e[0] == Gemm.E_z_uniq:
                return
        self.fail()

    def test_destroy_map4(self):
        """test that dot args can be aliased"""
        Z = value(self.rand(2,2))
        A = value(self.rand(2,2))
        eval_outputs([gemm(Z, 1.0, A, A, 1.0)])
        eval_outputs([gemm(Z, 1.0, A, A.T, 1.0)])

    def test_transposes(self):
        # three square matrices which are not contiguous
        A = self.rand(4,5)[:,:4]
        B = self.rand(4,5)[:,:4]
        C = self.rand(4,5)[:,:4]

        def t(z,x,y,a=1.0, b=0.0,l='c|py',dt='float64'):
            # Compute the reference first (z is overwritten below), then run
            # gemm on (z, x, y) and again on the transposed problem
            # (z.T, y.T, x.T); both runs must leave the reference in z.
            z,a,x,y,b = [numpy.asarray(p,dtype=dt) for p in z,a,x,y,b]
            z_orig = z.copy()
            z_after = self._gemm(z, a, x, y, b)

            tz,ta,tx,ty,tb = [value(p) for p in z,a,x,y,b]

            f = function([tz,ta,tx,ty,tb], gemm(tz,ta,tx,ty,tb), mode = compile.Mode(optimizer = None, linker=l))
            f(z, a, x, y, b)
            self.failUnless(_approx_eq(z_after, z), (z_orig, z_after, z, z_after - z))
            f(z.T, a, y.T, x.T, b)
            self.failUnless(_approx_eq(z_after, z))

        t(C,A,B)
        t(C.T, A, B)
        t(C, A.T, B, dt='float32')
        t(C, A, B.T)
        t(C.T, A.T, B)
        t(C, A.T, B.T, dt='float32')
        t(C.T, A, B.T)
        t(C.T, A.T, B.T, dt='float32')

        t(C, A[:,:2], B[:2, :])
        t(C.T, A[:,:2], B[:2, :], dt='float32')
        t(C, A[:2,:].T, B[:2, :])
        t(C.T, A[:2,:].T, B[:2, :], dt='float32')
        t(C, A[:2,:].T, B[:, :2].T)
        t(C.T, A[:2,:].T, B[:, :2].T)

        # (2,4) x (2,4) cannot be multiplied: a ValueError whose message
        # contains 'aligned' is the only accepted outcome
        try:
            t(C.T, A[:2,:], B[:, :2].T)
        except ValueError, e:
            if e[0].find('aligned') >= 0:
                return
        self.fail()
class t_as_scalar(TestCase):
    """Tests for _as_scalar: it is expected to see through DimShuffle nodes,
    to hand back scalar constants/variables, and to return None for anything
    that is not a scalar."""
    def test0(self):
        """Test that it works on scalar constants"""
        a = T.constant(2.5)
        b = T.constant(numpy.asarray([[[0.5]]]))
        d_a = T.DimShuffle([], [])(a)
        d_b = T.DimShuffle([True, True, True], [0,2,1])(b)
        d_a2 = T.DimShuffle([], ['x', 'x', 'x'])(a)

        self.failUnless(numpy.all(_as_scalar(a) == a))
        self.failUnless(numpy.all(_as_scalar(b) == b.data), (b, _as_scalar(b)))
        self.failUnless(numpy.all(_as_scalar(d_a) == a))
        self.failUnless(numpy.all(_as_scalar(d_b) == b.data))
        self.failUnless(numpy.all(_as_scalar(d_a2) == a))

    def test1(self):
        """Test that it fails on nonscalar constants"""
        a = T.constant(numpy.ones(5))
        # identity check: "no scalar found" is signalled by None itself
        self.failUnless(_as_scalar(a) is None)
        self.failUnless(_as_scalar(T.DimShuffle([False], [0,'x'])(a)) is None)

    def test2(self):
        """Test that it works on scalar variables"""
        a = T.dscalar()
        d_a = T.DimShuffle([], [])(a)
        d_a2 = T.DimShuffle([], ['x', 'x'])(a)

        self.failUnless(_as_scalar(a) is a)
        self.failUnless(_as_scalar(d_a) is a)
        self.failUnless(_as_scalar(d_a2) is a)

    def test3(self):
        """Test that it fails on nonscalar variables"""
        a = T.dmatrix()
        self.failUnless(_as_scalar(a) is None)
        self.failUnless(_as_scalar(T.DimShuffle([False, False], [0,'x', 1])(a)) is None)
class T_gemm_opt(TestCase):
"""This test suite ensures that Gemm is inserted where it belongs, and that the resulting
functions compute the same things as the originals."""
def XYZab(self):
return T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar()
def just_gemm(self, i, o, ishapes = [(4,3), (3,5), (4,5), (), ()]):
def on_fail():
for node in f.maker.env.toposort():
print 'GRAPH', node
self.fail()
f = function([In(ii, mutable=True) for ii in i],o, mode='FAST_RUN')
for node in f.maker.env.nodes:
if node.op == T.dot: on_fail()
if node.op == _dot22: on_fail()
g = function(i, o, mode='FAST_COMPILE')
for node in g.maker.env.nodes:
if node.op == gemm: on_fail()
rng = numpy.random.RandomState(234)
r0 = f(*[rng.randn(*sh) for sh in ishapes])
rng = numpy.random.RandomState(234)
r1 = g(*[rng.randn(*sh) for sh in ishapes])
if numpy.max(numpy.abs(r0[0] - r1[0])) > 1.0e-8:
self.fail()
def test0(self):
"""Many subgraphs whose dots can be eliminated"""
X,Y,Z,a,b = self.XYZab()
self.just_gemm([X,Y,Z,a,b], [T.dot(X,Y) * a + Z * b])
self.just_gemm([X,Y,Z,a,b], [a * T.dot(X,Y) + b * Z])
self.just_gemm([X,Y,Z,a,b], [b * Z + a * T.dot(X,Y)])
self.just_gemm([X,Y,Z,a,b], [T.dot(X,Y) * a - Z * b])
self.just_gemm([X,Y,Z,a,b], [a * T.dot(X,Y) - b * Z])
self.just_gemm([X,Y,Z,a,b], [b * Z - a * T.dot(X,Y)])
#with transposes (transposes should be pushed through dot in canonicalize)
self.just_gemm([X,Y,Z,a,b], [b * Z.T - a * T.dot(Y.T,X.T)])
self.just_gemm([X,Y,Z,a,b], [b * Z.T + a * b * T.dot(X,Y).T])
#with N multiplications instead of just one
self.just_gemm([X,Y,Z,a,b], [(b * b) * Z * a + (a * a) * T.dot(X,Y) * b])
self.just_gemm([X,Y,Z,a,b], [Z + T.dot(X,Y)])
self.just_gemm([X,Y,Z,a,b], [Z*b + T.dot(X,Y)])
self.just_gemm([X,Y,Z,a,b], [Z + a*b*a*T.dot(X,Y)])
self.just_gemm([X,Y,Z,a,b], [(b * b) * Z * a - (a * a) * T.dot(X,Y) * b])
self.just_gemm([X,Y,Z,a,b], [Z - T.dot(X,Y)])
self.just_gemm([X,Y,Z,a,b], [Z*b - T.dot(X,Y)])
self.just_gemm([X,Y,Z,a,b], [Z - a*b*a*T.dot(X,Y)])
# with > 2 terms in the overall addition
self.just_gemm([X,Y,Z,a,b], [Z + Z + T.dot(X,Y) + Z])
def test_double_gemm(self):
"""This is the pattern that shows up in the autoencoder"""
X,Y,Z,a,b = T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar()
R, S, c = T.dmatrix(), T.dmatrix(), T.dscalar()
self.just_gemm([X,Y,Z,a,b, R, S, c], [Z *c + a * T.dot(X,Y) + b * T.dot(R,S).T],
ishapes=[(4,3), (3,5), (4,5), (), (), (5,9), (9,4), ()])
def wishlist(self):
X,Y,Z,a,b = T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar()
#with >2 additions of the same T.dot(X,Y term
self.just_gemm([X,Y,Z,a,b], [Z + T.dot(X,Y) + T.dot(X,Y)])
self.just_gemm([X,Y,Z,a,b], [(b * b) * Z * a + (a * a) * T.dot(X,Y) + b * T.dot(X,Y)])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论