提交 32105ec9 authored 作者: James Bergstra's avatar James Bergstra

refactored gemm-related code. added gemm-related optimizations and tests. added…

refactored gemm-related code. added gemm-related optimizations and tests. added specializations for mul
上级 f28d5cb9
......@@ -783,148 +783,6 @@ class EquilibriumOptimizer(NavigatorOptimizer):
print >> sys.stderr, "WARNING: EquilibriumOptimizer max'ed out"
class _EquilibriumOptimizer(NavigatorOptimizer):
def __init__(self,
local_optimizers,
failure_callback = None,
max_depth = None,
max_use_ratio = None):
super(EquilibriumOptimizer, self).__init__(
None,
ignore_newtrees = False,
failure_callback = failure_callback)
self.local_optimizers = local_optimizers
self.max_depth = max_depth
self.max_use_ratio = max_use_ratio
self.tracks = defaultdict(list)
self.tracks0 = defaultdict(list)
max_depth = 0
for lopt in local_optimizers:
tracks = lopt.tracks()
for track in tracks:
max_depth = max(max_depth, len(track))
if self.max_depth is not None and max_depth > self.max_depth:
raise ValueError('One of the local optimizers exceeds the maximal depth.')
for i, op in enumerate(track):
if i == 0:
self.tracks0[op].append((track, i, lopt))
self.tracks[op].append((track, i, lopt))
def fetch_tracks(self, op):
return self.tracks[op] + self.tracks[None]
def fetch_tracks0(self, op):
return self.tracks0[op] + self.tracks0[None]
def backtrack(self, node, tasks):
candidates = self.fetch_tracks(node.op)
tracks = []
def filter(node, depth):
new_candidates = []
for candidate in candidates:
track, i, lopt = candidate
if i < depth:
pass
elif track[i-depth] in (None, node.op):
if i == depth:
tasks[node].append(lopt)
else:
tracks.append(candidate)
else:
new_candidates.append(candidate)
return new_candidates
depth = 0
nodes = [node]
while candidates:
for node in nodes:
candidates = filter(node, depth)
depth += 1
_nodes = nodes
nodes = reduce(list.__iadd__,
[reduce(list.__iadd__,
[[n for n, i in out.clients if not isinstance(n, str)] for out in node.outputs],
[]) for node in nodes],
[])
candidates = tracks
tracks = []
def apply(self, env):
tasks = defaultdict(list)
if self.max_use_ratio is not None:
max_uses = self.max_use_ratio * len(env.nodes)
runs = defaultdict(int)
else:
runs = None
def importer(node):
#print 'IMPORTING', node
self.backtrack(node, tasks)
def pruner(node):
try:
del tasks[node]
except KeyError:
pass
def chin(node, i, r, new_r):
if new_r.owner and not r.clients:
self.backtrack(new_r.owner, tasks)
# # == NOT IDEAL == #
# for node in env.nodes:
# importer(node)
for node in env.toposort():
tasks[node].extend(lopt for track, i, lopt in self.fetch_tracks0(node.op))
u = self.attach_updater(env, importer, pruner, chin)
print 'KEYS', map(hash, tasks.keys())
while tasks:
for node in tasks.iterkeys():
todo = tasks.pop(node)
break
for lopt in todo:
if runs is not None and runs[lopt] >= max_uses:
print >>sys.stderr, 'Warning: optimization exceeded its maximal use ratio: %s, %s' % (lopt, max_uses)
continue
success = self.process_node(env, node, lopt)
if success:
if runs is not None: runs[lopt] += 1
break
self.detach_updater(env, u)
# def match(self, node, candidates):
# candidates[:] = [candidate
# for candidate in candidates
# if candidate.current.op is None or candidate.current.op == node.op]
# for candidate in candidates:
# if candidate.current.inputs is not None:
# for in1, in2 in zip(candidate.current.inputs, node.inputs):
# if isinstance(in1, str):
# candidate.match[in1] = in2
# for client in node.clients:
# op = node.op
# patterns = self.pattern_base[(depth, op)].union(self.pattern_base[(depth, WILDCARD)])
# if not patterns:
# return patterns
# return self.match(node, depth + 1).intersection(patterns)
# def backtrack(self, node, q):
# for node2, i in node.clients:
# op2 = node2.op
def keep_going(exc, nav, repl_pairs):
"""WRITEME"""
pass
......@@ -1002,5 +860,3 @@ class PureThenInplaceOptimizer(Optimizer):
if 0:
class _EquilibriumOptimizer(NavigatorOptimizer):
def __init__(self,
local_optimizers,
failure_callback = None,
max_depth = None,
max_use_ratio = None):
super(EquilibriumOptimizer, self).__init__(
None,
ignore_newtrees = False,
failure_callback = failure_callback)
self.local_optimizers = local_optimizers
self.max_depth = max_depth
self.max_use_ratio = max_use_ratio
self.tracks = defaultdict(list)
self.tracks0 = defaultdict(list)
max_depth = 0
for lopt in local_optimizers:
tracks = lopt.tracks()
for track in tracks:
max_depth = max(max_depth, len(track))
if self.max_depth is not None and max_depth > self.max_depth:
raise ValueError('One of the local optimizers exceeds the maximal depth.')
for i, op in enumerate(track):
if i == 0:
self.tracks0[op].append((track, i, lopt))
self.tracks[op].append((track, i, lopt))
def fetch_tracks(self, op):
return self.tracks[op] + self.tracks[None]
def fetch_tracks0(self, op):
return self.tracks0[op] + self.tracks0[None]
def backtrack(self, node, tasks):
candidates = self.fetch_tracks(node.op)
tracks = []
def filter(node, depth):
new_candidates = []
for candidate in candidates:
track, i, lopt = candidate
if i < depth:
pass
elif track[i-depth] in (None, node.op):
if i == depth:
tasks[node].append(lopt)
else:
tracks.append(candidate)
else:
new_candidates.append(candidate)
return new_candidates
depth = 0
nodes = [node]
while candidates:
for node in nodes:
candidates = filter(node, depth)
depth += 1
_nodes = nodes
nodes = reduce(list.__iadd__,
[reduce(list.__iadd__,
[[n for n, i in out.clients if not isinstance(n, str)] for out in node.outputs],
[]) for node in nodes],
[])
candidates = tracks
tracks = []
def apply(self, env):
tasks = defaultdict(list)
if self.max_use_ratio is not None:
max_uses = self.max_use_ratio * len(env.nodes)
runs = defaultdict(int)
else:
runs = None
def importer(node):
#print 'IMPORTING', node
self.backtrack(node, tasks)
def pruner(node):
try:
del tasks[node]
except KeyError:
pass
def chin(node, i, r, new_r):
if new_r.owner and not r.clients:
self.backtrack(new_r.owner, tasks)
# # == NOT IDEAL == #
# for node in env.nodes:
# importer(node)
for node in env.toposort():
tasks[node].extend(lopt for track, i, lopt in self.fetch_tracks0(node.op))
u = self.attach_updater(env, importer, pruner, chin)
print 'KEYS', map(hash, tasks.keys())
while tasks:
for node in tasks.iterkeys():
todo = tasks.pop(node)
break
for lopt in todo:
if runs is not None and runs[lopt] >= max_uses:
print >>sys.stderr, 'Warning: optimization exceeded its maximal use ratio: %s, %s' % (lopt, max_uses)
continue
success = self.process_node(env, node, lopt)
if success:
if runs is not None: runs[lopt] += 1
break
self.detach_updater(env, u)
# def match(self, node, candidates):
# candidates[:] = [candidate
# for candidate in candidates
# if candidate.current.op is None or candidate.current.op == node.op]
# for candidate in candidates:
# if candidate.current.inputs is not None:
# for in1, in2 in zip(candidate.current.inputs, node.inputs):
# if isinstance(in1, str):
# candidate.match[in1] = in2
# for client in node.clients:
# op = node.op
# patterns = self.pattern_base[(depth, op)].union(self.pattern_base[(depth, WILDCARD)])
# if not patterns:
# return patterns
# return self.match(node, depth + 1).intersection(patterns)
# def backtrack(self, node, q):
# for node2, i in node.clients:
# op2 = node2.op
......@@ -2,6 +2,7 @@
from basic import *
import opt
import blas
import raw_random
from raw_random import \
......
......@@ -13,7 +13,6 @@ from copy import copy
from .. import gof
from ..gof import Result, Op, utils, AbstractFunctionError, Type, Constant, Apply, Value
import blas # for gemm, dot
from .. import gradient
import elemwise
......@@ -399,6 +398,8 @@ scalars, fscalars, dscalars, iscalars, lscalars = _multi(scalar, fscalar, dscala
int_types = bscalar, wscalar, iscalar, lscalar
float_types = fscalar, dscalar
int_scalar_types = int_types
float_scalar_types = float_types
fvector = Tensor('float32', (False, ))
dvector = Tensor('float64', (False, ))
......@@ -1700,6 +1701,10 @@ pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, MakeVector),
#########################
# Linalg : Dot
#########################
#
# For BLAS-related ops see blas.py
#
# TODO: Dotinv should go here, Eigs, Svd, etc.
class Dot(Op):
"""Compute matrix-matrix, matrix-vector products and vector inner-products.
......@@ -1793,249 +1798,6 @@ class Outer(Op):
return "outer"
outer = Outer()
class Gemm(Op):
"""In-place version of matrix-matrix multiplication (with accumulation):
When a and b are scalars and x, y, and z are matrices, then
gemm(z,a,x,y,b)
is similar to
b*z + a*dot(x,y)
The difference between the two is that the top form is destructive on z,
whereas the bottom form is not. Gemm works in-place on the storage
associated with z, and the L{Result} returned by Gemm has a storage that
will be aliased to the storage of the z argument. Because of this in-place
computation, an L{Apply} of this op will destroy the L{Result} z on
which it operates. (See L{DestructiveOps} for an explanation of what
destroying means in the context of theano graphs. See L{BlasLapackSupport} for
more optimized linear algebra operations.)
"""
E_rank = 'gemm only works for rank 2'
E_scalar = 'gemm requires scalar argument'
E_z_uniq = 'argument z aliased to x or y'
destroy_map = {0: [0]}
def make_node(self, *inputs):
inputs = map(as_tensor, inputs)
if len(inputs) != 5:
raise TypeError("Wrong number of inputs for %s (expected 5, got %s)" % (self, len(inputs)))
z, a, x, y, b = inputs
zr, xr, yr = [set(gof.view_roots(i)) for i in z,x,y]
if zr.intersection(xr):
raise ValueError(Gemm.E_z_uniq, (z, x))
if zr.intersection(yr):
raise ValueError(Gemm.E_z_uniq, (z, y))
bz, ba, bx, by, bb = [r.type.broadcastable for r in inputs]
if len(bz) != 2: raise ValueError(Gemm.E_rank, len(bz))
if len(bx) != 2: raise ValueError(Gemm.E_rank, len(bx))
if len(by) != 2: raise ValueError(Gemm.E_rank, len(by))
if len(ba): raise ValueError(Gemm.E_scalar, ba)
if len(bb): raise ValueError(Gemm.E_scalar, bb)
output = z.type()
return Apply(self, inputs, [output])
def perform(self, node, (z, a, x, y, b), (zout, )):
assert a.shape == ()
assert b.shape == ()
if z.shape == ():
z.itemset(z*a + b*numpy.dot(x,y))
zout[0] = z
else:
if b == 0.0:
if a == 1.0:
z[:] = numpy.dot(x,y)
elif a == -1.0:
z[:] = -numpy.dot(x,y)
else:
z[:] = a * numpy.dot(x,y)
elif b == 1.0:
if a == 1.0:
z += numpy.dot(x,y)
elif a == -1.0:
z -= numpy.dot(x,y)
else:
z += a * numpy.dot(x,y)
else:
z *= b
z += a * numpy.dot(x,y)
zout[0] = z
def grad(self, (z, a, x, y, b), (gz,)):
raise NotImplementedError()
def c_support_code(self):
#return blas.cblas_header_text()
mod_str = """
#ifndef MOD
#define MOD %
#endif
"""
return blas.blas_proto() + mod_str
def c_headers(self):
return ['<iostream>']
def c_libraries(self):
return blas.ldflags()
def c_code(self, node, name, (_z, _a, _x, _y, _b), (_zout, ), sub):
return """
int unit = 0;
int type_num = %(_x)s->descr->type_num;
int type_size = %(_x)s->descr->elsize; // in bytes
npy_intp* Nx = %(_x)s->dimensions;
npy_intp* Ny = %(_y)s->dimensions;
npy_intp* Nz = %(_z)s->dimensions;
npy_intp* Sx = %(_x)s->strides;
npy_intp* Sy = %(_y)s->strides;
npy_intp* Sz = %(_z)s->strides;
//strides for x, y, z in dimensions 0, 1
int sx_0, sx_1, sy_0, sy_1, sz_0, sz_1;
if (%(_zout)s != %(_z)s)
{
if (%(_zout)s)
{
Py_DECREF(%(_zout)s);
}
%(_zout)s = %(_z)s;
Py_INCREF(%(_zout)s);
}
if (%(_x)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(x) != 2"); %(fail)s;}
if (%(_y)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(y) != 2"); %(fail)s;}
if (%(_z)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(z) != 2"); %(fail)s;}
if ((%(_a)s->descr->type_num != PyArray_DOUBLE)
&& (%(_a)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(a) is not double or float"); %(fail)s;}
if ((%(_b)s->descr->type_num != PyArray_DOUBLE)
&& (%(_b)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(b) is not double or float"); %(fail)s;}
if ((%(_x)s->descr->type_num != PyArray_DOUBLE)
&& (%(_x)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float"); %(fail)s;}
if ((%(_y)s->descr->type_num != PyArray_DOUBLE)
&& (%(_y)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float"); %(fail)s;}
if ((%(_z)s->descr->type_num != PyArray_DOUBLE)
&& (%(_z)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float"); %(fail)s;}
if ((%(_x)s->descr->type_num != %(_y)s->descr->type_num)
||(%(_x)s->descr->type_num != %(_z)s->descr->type_num))
{ PyErr_SetString(PyExc_NotImplementedError, "type(z), type(y), type(z) are not all the same"); %(fail)s; }
if ((Nx[0] != Nz[0]) || (Nx[1] != Ny[0]) || (Ny[1] != Nz[1]))
{
PyErr_SetString(PyExc_ValueError, "Input dimensions do not agree");
%(fail)s;
}
if ((Sx[0] < 1) || (Sx[1] < 1) || (Sx[0] MOD type_size) || (Sx[1] MOD type_size)
|| (Sy[0] < 1) || (Sy[1] < 1) || (Sy[0] MOD type_size) || (Sy[1] MOD type_size)
|| (Sz[0] < 1) || (Sz[1] < 1) || (Sz[0] MOD type_size) || (Sz[1] MOD type_size))
{
PyErr_SetString(PyExc_ValueError, "stride is not multiple of element size"); %(fail)s;
}
/*
encode the stride structure of _x,_y,_z into a single integer
*/
unit |= ((Sx[1] == type_size) ? 0x0 : (Sx[0] == type_size) ? 0x1 : 0x2) << 8;
unit |= ((Sy[1] == type_size) ? 0x0 : (Sy[0] == type_size) ? 0x1 : 0x2) << 4;
unit |= ((Sz[1] == type_size) ? 0x0 : (Sz[0] == type_size) ? 0x1 : 0x2) << 0;
/* create appropriate strides for malformed matrices that are row or column
* vectors
*/
sx_0 = (Nx[0] > 1) ? Sx[0]/type_size : Nx[1];
sx_1 = (Nx[1] > 1) ? Sx[1]/type_size : Nx[0];
sy_0 = (Ny[0] > 1) ? Sy[0]/type_size : Ny[1];
sy_1 = (Ny[1] > 1) ? Sy[1]/type_size : Ny[0];
sz_0 = (Nz[0] > 1) ? Sz[0]/type_size : Nz[1];
sz_1 = (Nz[1] > 1) ? Sz[1]/type_size : Nz[0];
switch (type_num)
{
case PyArray_FLOAT:
{
#define REAL float
float a = (%(_a)s->descr->type_num == PyArray_FLOAT)
? (REAL)(((float*)%(_a)s->data)[0])
: (REAL)(((double*)%(_a)s->data)[0]);
float b = (%(_b)s->descr->type_num == PyArray_FLOAT) ?
(REAL)(((float*)%(_b)s->data)[0])
: (REAL)(((double*)%(_b)s->data)[0]);
float* x = (float*)PyArray_DATA(%(_x)s);
float* y = (float*)PyArray_DATA(%(_y)s);
float* z = (float*)PyArray_DATA(%(_z)s);
char N = 'N';
char T = 'T';
int Nz0 = Nz[0], Nz1 = Nz[1], Nx1 = Nx[1];
//std::cerr << (unit/256) MOD 16 << (unit / 16) MOD 16 << unit MOD 16<< '\\n';
switch(unit)
{
case 0x000: sgemm_(&N, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_0, &b, z, &sz_0); break;
case 0x100: sgemm_(&N, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_1, &b, z, &sz_0); break;
case 0x010: sgemm_(&T, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_0, &b, z, &sz_0); break;
case 0x110: sgemm_(&T, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_1, &b, z, &sz_0); break;
case 0x001: sgemm_(&T, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_0, &b, z, &sz_1); break;
case 0x101: sgemm_(&N, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_0, &b, z, &sz_1); break;
case 0x011: sgemm_(&T, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_1, &b, z, &sz_1); break;
case 0x111: sgemm_(&N, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_1, &b, z, &sz_1); break;
default: PyErr_SetString(PyExc_ValueError, "some matrix has no unit stride"); %(fail)s;
};
#undef REAL
}
break;
case PyArray_DOUBLE:
{
#define REAL double
double a = (%(_a)s->descr->type_num == PyArray_FLOAT)
? (REAL)(((float*)%(_a)s->data)[0])
: (REAL)(((double*)%(_a)s->data)[0]);
double b = (%(_b)s->descr->type_num == PyArray_FLOAT) ?
(REAL)(((float*)%(_b)s->data)[0])
: (REAL)(((double*)%(_b)s->data)[0]);
double* x = (double*)PyArray_DATA(%(_x)s);
double* y = (double*)PyArray_DATA(%(_y)s);
double* z = (double*)PyArray_DATA(%(_z)s);
char N = 'N';
char T = 'T';
int Nz0 = Nz[0], Nz1 = Nz[1], Nx1 = Nx[1];
//std::cerr << (unit/256) MOD 16 << (unit / 16) MOD 16 << unit MOD 16<< '\\n';
switch(unit)
{
case 0x000: dgemm_(&N, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_0, &b, z, &sz_0); break;
case 0x100: dgemm_(&N, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_1, &b, z, &sz_0); break;
case 0x010: dgemm_(&T, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_0, &b, z, &sz_0); break;
case 0x110: dgemm_(&T, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_1, &b, z, &sz_0); break;
case 0x001: dgemm_(&T, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_0, &b, z, &sz_1); break;
case 0x101: dgemm_(&N, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_0, &b, z, &sz_1); break;
case 0x011: dgemm_(&T, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_1, &b, z, &sz_1); break;
case 0x111: dgemm_(&N, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_1, &b, z, &sz_1); break;
default: PyErr_SetString(PyExc_ValueError, "some matrix has no unit stride"); %(fail)s;
};
#undef REAL
}
break;
}
""" % dict(locals(), **sub)
gemm = Gemm()
pprint.assign(gemm, printing.FunctionPrinter('gemm'))
#########################
# Gradient
#########################
......
"""Ops and optimizations for using BLAS function calls to evaluate linear algebra expressions"""
import os, sys
import scipy.weave as weave
from ..gof import utils
"""
File: omega/blas.py
This file is in omega's core because it consists mostly of optimizations of the
graphs that can be constructed from omega/core.py. The optimizations provided
by this file are aimed at the goal of inserting gemm Ops in place of more
fine-grained motifs of iadd, isub, scale, and dot.
"""
def cblas_header_text():
"""C header for the cblas interface"""
return """
//#include <stddef.h>
#undef __BEGIN_DECLS
#undef __END_DECLS
#ifdef __cplusplus
#define __BEGIN_DECLS extern "C" {
#define __END_DECLS }
#else
#define __BEGIN_DECLS /* empty */
#define __END_DECLS /* empty */
#endif
__BEGIN_DECLS
#define MOD %
/*
* Enumerated and derived types
*/
#define CBLAS_INDEX size_t /* this may vary between platforms */
enum CBLAS_ORDER {CblasRowMajor=101, CblasColMajor=102};
enum CBLAS_TRANSPOSE {CblasNoTrans=111, CblasTrans=112, CblasConjTrans=113};
enum CBLAS_UPLO {CblasUpper=121, CblasLower=122};
enum CBLAS_DIAG {CblasNonUnit=131, CblasUnit=132};
enum CBLAS_SIDE {CblasLeft=141, CblasRight=142};
float cblas_sdsdot(const int N, const float alpha, const float *X,
const int incX, const float *Y, const int incY);
double cblas_dsdot(const int N, const float *X, const int incX, const float *Y,
const int incY);
float cblas_sdot(const int N, const float *X, const int incX,
const float *Y, const int incY);
double cblas_ddot(const int N, const double *X, const int incX,
const double *Y, const int incY);
/*
* Functions having prefixes Z and C only
*/
void cblas_cdotu_sub(const int N, const void *X, const int incX,
const void *Y, const int incY, void *dotu);
void cblas_cdotc_sub(const int N, const void *X, const int incX,
const void *Y, const int incY, void *dotc);
void cblas_zdotu_sub(const int N, const void *X, const int incX,
const void *Y, const int incY, void *dotu);
void cblas_zdotc_sub(const int N, const void *X, const int incX,
const void *Y, const int incY, void *dotc);
/*
* Functions having prefixes S D SC DZ
*/
float cblas_snrm2(const int N, const float *X, const int incX);
float cblas_sasum(const int N, const float *X, const int incX);
double cblas_dnrm2(const int N, const double *X, const int incX);
double cblas_dasum(const int N, const double *X, const int incX);
float cblas_scnrm2(const int N, const void *X, const int incX);
float cblas_scasum(const int N, const void *X, const int incX);
double cblas_dznrm2(const int N, const void *X, const int incX);
double cblas_dzasum(const int N, const void *X, const int incX);
/*
* Functions having standard 4 prefixes (S D C Z)
*/
CBLAS_INDEX cblas_isamax(const int N, const float *X, const int incX);
CBLAS_INDEX cblas_idamax(const int N, const double *X, const int incX);
CBLAS_INDEX cblas_icamax(const int N, const void *X, const int incX);
CBLAS_INDEX cblas_izamax(const int N, const void *X, const int incX);
/*
* ===========================================================================
* Prototypes for level 1 BLAS routines
* ===========================================================================
*/
/*
* Routines with standard 4 prefixes (s, d, c, z)
*/
void cblas_sswap(const int N, float *X, const int incX,
float *Y, const int incY);
void cblas_scopy(const int N, const float *X, const int incX,
float *Y, const int incY);
void cblas_saxpy(const int N, const float alpha, const float *X,
const int incX, float *Y, const int incY);
void cblas_dswap(const int N, double *X, const int incX,
double *Y, const int incY);
void cblas_dcopy(const int N, const double *X, const int incX,
double *Y, const int incY);
void cblas_daxpy(const int N, const double alpha, const double *X,
const int incX, double *Y, const int incY);
void cblas_cswap(const int N, void *X, const int incX,
void *Y, const int incY);
void cblas_ccopy(const int N, const void *X, const int incX,
void *Y, const int incY);
void cblas_caxpy(const int N, const void *alpha, const void *X,
const int incX, void *Y, const int incY);
void cblas_zswap(const int N, void *X, const int incX,
void *Y, const int incY);
void cblas_zcopy(const int N, const void *X, const int incX,
void *Y, const int incY);
void cblas_zaxpy(const int N, const void *alpha, const void *X,
const int incX, void *Y, const int incY);
/*
* Routines with S and D prefix only
*/
void cblas_srotg(float *a, float *b, float *c, float *s);
void cblas_srotmg(float *d1, float *d2, float *b1, const float b2, float *P);
void cblas_srot(const int N, float *X, const int incX,
float *Y, const int incY, const float c, const float s);
void cblas_srotm(const int N, float *X, const int incX,
float *Y, const int incY, const float *P);
void cblas_drotg(double *a, double *b, double *c, double *s);
void cblas_drotmg(double *d1, double *d2, double *b1, const double b2, double *P);
void cblas_drot(const int N, double *X, const int incX,
double *Y, const int incY, const double c, const double s);
void cblas_drotm(const int N, double *X, const int incX,
double *Y, const int incY, const double *P);
/*
* Routines with S D C Z CS and ZD prefixes
*/
void cblas_sscal(const int N, const float alpha, float *X, const int incX);
void cblas_dscal(const int N, const double alpha, double *X, const int incX);
void cblas_cscal(const int N, const void *alpha, void *X, const int incX);
void cblas_zscal(const int N, const void *alpha, void *X, const int incX);
void cblas_csscal(const int N, const float alpha, void *X, const int incX);
void cblas_zdscal(const int N, const double alpha, void *X, const int incX);
/*
* ===========================================================================
* Prototypes for level 2 BLAS
* ===========================================================================
*/
/*
* Routines with standard 4 prefixes (S, D, C, Z)
*/
void cblas_sgemv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const float alpha, const float *A, const int lda,
const float *X, const int incX, const float beta,
float *Y, const int incY);
void cblas_sgbmv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const int KL, const int KU, const float alpha,
const float *A, const int lda, const float *X,
const int incX, const float beta, float *Y, const int incY);
void cblas_strmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const float *A, const int lda,
float *X, const int incX);
void cblas_stbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const float *A, const int lda,
float *X, const int incX);
void cblas_stpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const float *Ap, float *X, const int incX);
void cblas_strsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const float *A, const int lda, float *X,
const int incX);
void cblas_stbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const float *A, const int lda,
float *X, const int incX);
void cblas_stpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const float *Ap, float *X, const int incX);
void cblas_dgemv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const double alpha, const double *A, const int lda,
const double *X, const int incX, const double beta,
double *Y, const int incY);
void cblas_dgbmv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const int KL, const int KU, const double alpha,
const double *A, const int lda, const double *X,
const int incX, const double beta, double *Y, const int incY);
void cblas_dtrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const double *A, const int lda,
double *X, const int incX);
void cblas_dtbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const double *A, const int lda,
double *X, const int incX);
void cblas_dtpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const double *Ap, double *X, const int incX);
void cblas_dtrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const double *A, const int lda, double *X,
const int incX);
void cblas_dtbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const double *A, const int lda,
double *X, const int incX);
void cblas_dtpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const double *Ap, double *X, const int incX);
void cblas_cgemv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *X, const int incX, const void *beta,
void *Y, const int incY);
void cblas_cgbmv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const int KL, const int KU, const void *alpha,
const void *A, const int lda, const void *X,
const int incX, const void *beta, void *Y, const int incY);
void cblas_ctrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *A, const int lda,
void *X, const int incX);
void cblas_ctbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const void *A, const int lda,
void *X, const int incX);
void cblas_ctpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *Ap, void *X, const int incX);
void cblas_ctrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *A, const int lda, void *X,
const int incX);
void cblas_ctbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const void *A, const int lda,
void *X, const int incX);
void cblas_ctpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *Ap, void *X, const int incX);
void cblas_zgemv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *X, const int incX, const void *beta,
void *Y, const int incY);
void cblas_zgbmv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const int KL, const int KU, const void *alpha,
const void *A, const int lda, const void *X,
const int incX, const void *beta, void *Y, const int incY);
void cblas_ztrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *A, const int lda,
void *X, const int incX);
void cblas_ztbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const void *A, const int lda,
void *X, const int incX);
void cblas_ztpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *Ap, void *X, const int incX);
void cblas_ztrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *A, const int lda, void *X,
const int incX);
void cblas_ztbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const void *A, const int lda,
void *X, const int incX);
void cblas_ztpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *Ap, void *X, const int incX);
/*
* Routines with S and D prefixes only
*/
void cblas_ssymv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *A,
const int lda, const float *X, const int incX,
const float beta, float *Y, const int incY);
void cblas_ssbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const int K, const float alpha, const float *A,
const int lda, const float *X, const int incX,
const float beta, float *Y, const int incY);
void cblas_sspmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *Ap,
const float *X, const int incX,
const float beta, float *Y, const int incY);
void cblas_sger(const enum CBLAS_ORDER order, const int M, const int N,
const float alpha, const float *X, const int incX,
const float *Y, const int incY, float *A, const int lda);
void cblas_ssyr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *X,
const int incX, float *A, const int lda);
void cblas_sspr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *X,
const int incX, float *Ap);
void cblas_ssyr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *X,
const int incX, const float *Y, const int incY, float *A,
const int lda);
void cblas_sspr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *X,
const int incX, const float *Y, const int incY, float *A);
void cblas_dsymv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *A,
const int lda, const double *X, const int incX,
const double beta, double *Y, const int incY);
void cblas_dsbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const int K, const double alpha, const double *A,
const int lda, const double *X, const int incX,
const double beta, double *Y, const int incY);
void cblas_dspmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *Ap,
const double *X, const int incX,
const double beta, double *Y, const int incY);
void cblas_dger(const enum CBLAS_ORDER order, const int M, const int N,
const double alpha, const double *X, const int incX,
const double *Y, const int incY, double *A, const int lda);
void cblas_dsyr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *X,
const int incX, double *A, const int lda);
void cblas_dspr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *X,
const int incX, double *Ap);
void cblas_dsyr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *X,
const int incX, const double *Y, const int incY, double *A,
const int lda);
void cblas_dspr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *X,
const int incX, const double *Y, const int incY, double *A);
/*
* Routines with C and Z prefixes only
*/
void cblas_chemv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const void *alpha, const void *A,
const int lda, const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_chbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const int K, const void *alpha, const void *A,
const int lda, const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_chpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const void *alpha, const void *Ap,
const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_cgeru(const enum CBLAS_ORDER order, const int M, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_cgerc(const enum CBLAS_ORDER order, const int M, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_cher(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const void *X, const int incX,
void *A, const int lda);
void cblas_chpr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const void *X,
const int incX, void *A);
void cblas_cher2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_chpr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *Ap);
void cblas_zhemv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const void *alpha, const void *A,
const int lda, const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_zhbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const int K, const void *alpha, const void *A,
const int lda, const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_zhpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const void *alpha, const void *Ap,
const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_zgeru(const enum CBLAS_ORDER order, const int M, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_zgerc(const enum CBLAS_ORDER order, const int M, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_zher(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const void *X, const int incX,
void *A, const int lda);
void cblas_zhpr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const void *X,
const int incX, void *A);
void cblas_zher2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_zhpr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *Ap);
/*
* ===========================================================================
* Prototypes for level 3 BLAS
* ===========================================================================
*/
/*
* Routines with standard 4 prefixes (S, D, C, Z)
*/
void cblas_sgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const float alpha, const float *A,
const int lda, const float *B, const int ldb,
const float beta, float *C, const int ldc);
void cblas_ssymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const float alpha, const float *A, const int lda,
const float *B, const int ldb, const float beta,
float *C, const int ldc);
void cblas_ssyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const float alpha, const float *A, const int lda,
const float beta, float *C, const int ldc);
void cblas_ssyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const float alpha, const float *A, const int lda,
const float *B, const int ldb, const float beta,
float *C, const int ldc);
void cblas_strmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const float alpha, const float *A, const int lda,
float *B, const int ldb);
void cblas_strsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const float alpha, const float *A, const int lda,
float *B, const int ldb);
void cblas_dgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const double alpha, const double *A,
const int lda, const double *B, const int ldb,
const double beta, double *C, const int ldc);
void cblas_dsymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const double alpha, const double *A, const int lda,
const double *B, const int ldb, const double beta,
double *C, const int ldc);
void cblas_dsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const double alpha, const double *A, const int lda,
const double beta, double *C, const int ldc);
void cblas_dsyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const double alpha, const double *A, const int lda,
const double *B, const int ldb, const double beta,
double *C, const int ldc);
void cblas_dtrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const double alpha, const double *A, const int lda,
double *B, const int ldb);
void cblas_dtrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const double alpha, const double *A, const int lda,
double *B, const int ldb);
void cblas_cgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const void *alpha, const void *A,
const int lda, const void *B, const int ldb,
const void *beta, void *C, const int ldc);
void cblas_csymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_csyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *beta, void *C, const int ldc);
void cblas_csyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_ctrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const void *alpha, const void *A, const int lda,
void *B, const int ldb);
void cblas_ctrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const void *alpha, const void *A, const int lda,
void *B, const int ldb);
void cblas_zgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const void *alpha, const void *A,
const int lda, const void *B, const int ldb,
const void *beta, void *C, const int ldc);
void cblas_zsymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_zsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *beta, void *C, const int ldc);
void cblas_zsyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_ztrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const void *alpha, const void *A, const int lda,
void *B, const int ldb);
void cblas_ztrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const void *alpha, const void *A, const int lda,
void *B, const int ldb);
/*
* Routines with prefixes C and Z only
*/
void cblas_chemm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_cherk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const float alpha, const void *A, const int lda,
const float beta, void *C, const int ldc);
void cblas_cher2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const float beta,
void *C, const int ldc);
void cblas_zhemm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_zherk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const double alpha, const void *A, const int lda,
const double beta, void *C, const int ldc);
void cblas_zher2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const double beta,
void *C, const int ldc);
void cblas_xerbla(int p, const char *rout, const char *form, ...);
__END_DECLS
"""
import numpy
from ..gof import (utils, Op, Apply, view_roots, PatternSub,
InplaceOptimizer, SeqOptimizer, warn, local_optimizer)
from ..printing import pprint, FunctionPrinter
from .opt import register_specialize, out2in, insert_inplace_optimizer
import basic as T
from ..tensor import as_tensor
#NB: this clobbers the builtin 'compile' symbol
from .. import compile #to register the optimizer built by this file
from .blas_headers import cblas_header_text, blas_header_text
def blas_proto():
"""C header for the fortran blas interface"""
return """
extern "C"
{
void xerbla_(char*, void *);
/***********/
/* Level 1 */
/***********/
/* Single Precision */
void srot_(const int*, float *, const int*, float *, const int*, const float *, const float *);
void srotg_(float *,float *,float *,float *);
void srotm_( const int*, float *, const int*, float *, const int*, const float *);
void srotmg_(float *,float *,float *,const float *, float *);
void sswap_( const int*, float *, const int*, float *, const int*);
void scopy_( const int*, const float *, const int*, float *, const int*);
void saxpy_( const int*, const float *, const float *, const int*, float *, const int*);
void sdot_sub_(const int*, const float *, const int*, const float *, const int*, float *);
void sdsdot_sub_( const int*, const float *, const float *, const int*, const float *, const int*, float *);
void sscal_( const int*, const float *, float *, const int*);
void snrm2_sub_( const int*, const float *, const int*, float *);
void sasum_sub_( const int*, const float *, const int*, float *);
void isamax_sub_( const int*, const float * , const int*, const int*);
/* Double Precision */
void drot_(const int*, double *, const int*, double *, const int*, const double *, const double *);
void drotg_(double *,double *,double *,double *);
void drotm_( const int*, double *, const int*, double *, const int*, const double *);
void drotmg_(double *,double *,double *,const double *, double *);
void dswap_( const int*, double *, const int*, double *, const int*);
void dcopy_( const int*, const double *, const int*, double *, const int*);
void daxpy_( const int*, const double *, const double *, const int*, double *, const int*);
void dswap_( const int*, double *, const int*, double *, const int*);
void dsdot_sub_(const int*, const float *, const int*, const float *, const int*, double *);
void ddot_sub_( const int*, const double *, const int*, const double *, const int*, double *);
void dscal_( const int*, const double *, double *, const int*);
void dnrm2_sub_( const int*, const double *, const int*, double *);
void dasum_sub_( const int*, const double *, const int*, double *);
void idamax_sub_( const int*, const double * , const int*, const int*);
/* Single Complex Precision */
void cswap_( const int*, void *, const int*, void *, const int*);
void ccopy_( const int*, const void *, const int*, void *, const int*);
void caxpy_( const int*, const void *, const void *, const int*, void *, const int*);
void cswap_( const int*, void *, const int*, void *, const int*);
void cdotc_sub_( const int*, const void *, const int*, const void *, const int*, void *);
void cdotu_sub_( const int*, const void *, const int*, const void *, const int*, void *);
void cscal_( const int*, const void *, void *, const int*);
void icamax_sub_( const int*, const void *, const int*, const int*);
void csscal_( const int*, const float *, void *, const int*);
void scnrm2_sub_( const int*, const void *, const int*, float *);
void scasum_sub_( const int*, const void *, const int*, float *);
/* Double Complex Precision */
void zswap_( const int*, void *, const int*, void *, const int*);
void zcopy_( const int*, const void *, const int*, void *, const int*);
void zaxpy_( const int*, const void *, const void *, const int*, void *, const int*);
void zswap_( const int*, void *, const int*, void *, const int*);
void zdotc_sub_( const int*, const void *, const int*, const void *, const int*, void *);
void zdotu_sub_( const int*, const void *, const int*, const void *, const int*, void *);
void zdscal_( const int*, const double *, void *, const int*);
void zscal_( const int*, const void *, void *, const int*);
void dznrm2_sub_( const int*, const void *, const int*, double *);
void dzasum_sub_( const int*, const void *, const int*, double *);
void izamax_sub_( const int*, const void *, const int*, const int*);
/***********/
/* Level 2 */
/***********/
/* Single Precision */
void sgemv_(char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void sgbmv_(char*, const int*, const int*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ssymv_(char*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ssbmv_(char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void sspmv_(char*, const int*, const float *, const float *, const float *, const int*, const float *, float *, const int*);
void strmv_( char*, char*, char*, const int*, const float *, const int*, float *, const int*);
void stbmv_( char*, char*, char*, const int*, const int*, const float *, const int*, float *, const int*);
void strsv_( char*, char*, char*, const int*, const float *, const int*, float *, const int*);
void stbsv_( char*, char*, char*, const int*, const int*, const float *, const int*, float *, const int*);
void stpmv_( char*, char*, char*, const int*, const float *, float *, const int*);
void stpsv_( char*, char*, char*, const int*, const float *, float *, const int*);
void sger_( const int*, const int*, const float *, const float *, const int*, const float *, const int*, float *, const int*);
void ssyr_(char*, const int*, const float *, const float *, const int*, float *, const int*);
void sspr_(char*, const int*, const float *, const float *, const int*, float *);
void sspr2_(char*, const int*, const float *, const float *, const int*, const float *, const int*, float *);
void ssyr2_(char*, const int*, const float *, const float *, const int*, const float *, const int*, float *, const int*);
/* Double Precision */
void dgemv_(char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dgbmv_(char*, const int*, const int*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dsymv_(char*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dsbmv_(char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dspmv_(char*, const int*, const double *, const double *, const double *, const int*, const double *, double *, const int*);
void dtrmv_( char*, char*, char*, const int*, const double *, const int*, double *, const int*);
void dtbmv_( char*, char*, char*, const int*, const int*, const double *, const int*, double *, const int*);
void dtrsv_( char*, char*, char*, const int*, const double *, const int*, double *, const int*);
void dtbsv_( char*, char*, char*, const int*, const int*, const double *, const int*, double *, const int*);
void dtpmv_( char*, char*, char*, const int*, const double *, double *, const int*);
void dtpsv_( char*, char*, char*, const int*, const double *, double *, const int*);
void dger_( const int*, const int*, const double *, const double *, const int*, const double *, const int*, double *, const int*);
void dsyr_(char*, const int*, const double *, const double *, const int*, double *, const int*);
void dspr_(char*, const int*, const double *, const double *, const int*, double *);
void dspr2_(char*, const int*, const double *, const double *, const int*, const double *, const int*, double *);
void dsyr2_(char*, const int*, const double *, const double *, const int*, const double *, const int*, double *, const int*);
/* Single Complex Precision */
void cgemv_(char*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void cgbmv_(char*, const int*, const int*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void chemv_(char*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void chbmv_(char*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void chpmv_(char*, const int*, const void *, const void *, const void *, const int*, const void *, void *, const int*);
void ctrmv_( char*, char*, char*, const int*, const void *, const int*, void *, const int*);
void ctbmv_( char*, char*, char*, const int*, const int*, const void *, const int*, void *, const int*);
void ctpmv_( char*, char*, char*, const int*, const void *, void *, const int*);
void ctrsv_( char*, char*, char*, const int*, const void *, const int*, void *, const int*);
void ctbsv_( char*, char*, char*, const int*, const int*, const void *, const int*, void *, const int*);
void ctpsv_( char*, char*, char*, const int*, const void *, void *,const int*);
void cgerc_( const int*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*);
void cgeru_( const int*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*);
void cher_(char*, const int*, const float *, const void *, const int*, void *, const int*);
void cher2_(char*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*);
void chpr_(char*, const int*, const float *, const void *, const int*, void *);
void chpr2_(char*, const int*, const float *, const void *, const int*, const void *, const int*, void *);
/* Double Complex Precision */
void zgemv_(char*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void zgbmv_(char*, const int*, const int*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void zhemv_(char*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void zhbmv_(char*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void zhpmv_(char*, const int*, const void *, const void *, const void *, const int*, const void *, void *, const int*);
void ztrmv_( char*, char*, char*, const int*, const void *, const int*, void *, const int*);
void ztbmv_( char*, char*, char*, const int*, const int*, const void *, const int*, void *, const int*);
void ztpmv_( char*, char*, char*, const int*, const void *, void *, const int*);
void ztrsv_( char*, char*, char*, const int*, const void *, const int*, void *, const int*);
void ztbsv_( char*, char*, char*, const int*, const int*, const void *, const int*, void *, const int*);
void ztpsv_( char*, char*, char*, const int*, const void *, void *,const int*);
void zgerc_( const int*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*);
void zgeru_( const int*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*);
void zher_(char*, const int*, const double *, const void *, const int*, void *, const int*);
void zher2_(char*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*);
void zhpr_(char*, const int*, const double *, const void *, const int*, void *);
void zhpr2_(char*, const int*, const double *, const void *, const int*, const void *, const int*, void *);
/***********/
/* Level 3 */
/***********/
/* Single Precision */
void sgemm_(char*, char*, const int*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ssymm_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ssyrk_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, float *, const int*);
void ssyr2k_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void strmm_(char*, char*, char*, char*, const int*, const int*, const float *, const float *, const int*, float *, const int*);
void strsm_(char*, char*, char*, char*, const int*, const int*, const float *, const float *, const int*, float *, const int*);
/* Double Precision */
void dgemm_(char*, char*, const int*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dsymm_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dsyrk_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, double *, const int*);
void dsyr2k_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dtrmm_(char*, char*, char*, char*, const int*, const int*, const double *, const double *, const int*, double *, const int*);
void dtrsm_(char*, char*, char*, char*, const int*, const int*, const double *, const double *, const int*, double *, const int*);
/* Single Complex Precision */
void cgemm_(char*, char*, const int*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void csymm_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void chemm_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void csyrk_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, float *, const int*);
void cherk_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, float *, const int*);
void csyr2k_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void cher2k_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ctrmm_(char*, char*, char*, char*, const int*, const int*, const float *, const float *, const int*, float *, const int*);
void ctrsm_(char*, char*, char*, char*, const int*, const int*, const float *, const float *, const int*, float *, const int*);
/* Double Complex Precision */
void zgemm_(char*, char*, const int*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void zsymm_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void zhemm_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void zsyrk_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, double *, const int*);
void zherk_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, double *, const int*);
void zsyr2k_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void zher2k_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void ztrmm_(char*, char*, char*, char*, const int*, const int*, const double *, const double *, const int*, double *, const int*);
void ztrsm_(char*, char*, char*, char*, const int*, const int*, const double *, const double *, const int*, double *, const int*);
}
"""
@utils.memoize
def ldflags():
......@@ -819,68 +41,105 @@ def ldflags():
#print "blas linking against", rval
return rval
def gemm_code(check_ab, a_init, b_init):
mod = '%'
return """
const char * error_string = NULL;
int type_num = _x->descr->type_num;
int type_size = _x->descr->elsize; // in bytes
npy_intp* Nx = _x->dimensions;
npy_intp* Ny = _y->dimensions;
npy_intp* Nz = _z->dimensions;
npy_intp* Sx = _x->strides;
npy_intp* Sy = _y->strides;
npy_intp* Sz = _z->strides;
size_t sx_0, sx_1, sy_0, sy_1, sz_0, sz_1;
class GemmRelated(Op):
"""Base class for Gemm and Dot22
This class provides a kind of templated gemm Op.
"""
def c_support_code(self):
#return cblas_header_text()
mod_str = """
#ifndef MOD
#define MOD %
#endif
"""
return blas_header_text() + mod_str
def c_headers(self):
# std.cout doesn't require the '%' symbol to print stuff...
# so it works much better with python's string-substitution stuff.
return ['<iostream>']
def c_libraries(self):
return ldflags()
declare_NS = """
int unit = 0;
if (_x->nd != 2) goto _dot_execute_fallback;
if (_y->nd != 2) goto _dot_execute_fallback;
if (_z->nd != 2) goto _dot_execute_fallback;
%(check_ab)s
if ((_x->descr->type_num != PyArray_DOUBLE)
&& (_x->descr->type_num != PyArray_FLOAT))
goto _dot_execute_fallback;
if ((_y->descr->type_num != PyArray_DOUBLE)
&& (_y->descr->type_num != PyArray_FLOAT))
goto _dot_execute_fallback;
if ((_y->descr->type_num != PyArray_DOUBLE)
&& (_y->descr->type_num != PyArray_FLOAT))
goto _dot_execute_fallback;
if ((_x->descr->type_num != _y->descr->type_num)
||(_x->descr->type_num != _z->descr->type_num))
goto _dot_execute_fallback;
int type_num = %(_x)s->descr->type_num;
int type_size = %(_x)s->descr->elsize; // in bytes
npy_intp* Nx = %(_x)s->dimensions;
npy_intp* Ny = %(_y)s->dimensions;
npy_intp* Nz = 0; //%(_z)s->dimensions;
npy_intp* Sx = %(_x)s->strides;
npy_intp* Sy = %(_y)s->strides;
npy_intp* Sz = 0; //%(_z)s->strides;
//strides for x, y, z in dimensions 0, 1
int sx_0, sx_1, sy_0, sy_1, sz_0, sz_1;
"""
#setup_z_Nz_Sz = None
check_xyz_rank2 = """
if (%(_x)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(x) != 2"); %(fail)s;}
if (%(_y)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(y) != 2"); %(fail)s;}
if (%(_z)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(z) != 2"); %(fail)s;}
"""
check_xyz_double_or_float = """
if ((%(_x)s->descr->type_num != PyArray_DOUBLE)
&& (%(_x)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float"); %(fail)s;}
if ((%(_y)s->descr->type_num != PyArray_DOUBLE)
&& (%(_y)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float"); %(fail)s;}
if ((%(_z)s->descr->type_num != PyArray_DOUBLE)
&& (%(_z)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float"); %(fail)s;}
if ((%(_x)s->descr->type_num != %(_y)s->descr->type_num)
||(%(_x)s->descr->type_num != %(_z)s->descr->type_num))
{ PyErr_SetString(PyExc_NotImplementedError, "type(z), type(y), type(z) are not all the same"); %(fail)s; }
"""
#it is not necessary that a or b have the same type as x,y,z
check_ab_double_or_float = """
if ((%(_a)s->descr->type_num != PyArray_DOUBLE)
&& (%(_a)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(a) is not double or float"); %(fail)s;}
if ((%(_b)s->descr->type_num != PyArray_DOUBLE)
&& (%(_b)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(b) is not double or float"); %(fail)s;}
"""
check_dims_strides = """
if ((Nx[0] != Nz[0]) || (Nx[1] != Ny[0]) || (Ny[1] != Nz[1]))
{
error_string = "Input dimensions do not agree";
goto _dot_execute_fail;
PyErr_SetString(PyExc_ValueError, "Input dimensions do not agree");
%(fail)s;
}
if ((Sx[0] < 1) || (Sx[1] < 1) || (Sx[0] %(mod)s type_size) || (Sx[1] %(mod)s type_size)
|| (Sy[0] < 1) || (Sy[1] < 1) || (Sy[0] %(mod)s type_size) || (Sy[1] %(mod)s type_size)
|| (Sz[0] < 1) || (Sz[1] < 1) || (Sz[0] %(mod)s type_size) || (Sz[1] %(mod)s type_size))
if ((Sx[0] < 1) || (Sx[1] < 1) || (Sx[0] MOD type_size) || (Sx[1] MOD type_size)
|| (Sy[0] < 1) || (Sy[1] < 1) || (Sy[0] MOD type_size) || (Sy[1] MOD type_size)
|| (Sz[0] < 1) || (Sz[1] < 1) || (Sz[0] MOD type_size) || (Sz[1] MOD type_size))
{
goto _dot_execute_fallback;
PyErr_SetString(PyExc_ValueError, "stride is not multiple of element size"); %(fail)s;
}
"""
encode_strides_in_unit = """
/*
encode the stride structure of _x,_y,_z into a single integer
*/
unit |= ((Sx[1] == type_size) ? 0x0 : (Sx[0] == type_size) ? 0x1 : 0x2) << 0;
unit |= ((Sx[1] == type_size) ? 0x0 : (Sx[0] == type_size) ? 0x1 : 0x2) << 8;
unit |= ((Sy[1] == type_size) ? 0x0 : (Sy[0] == type_size) ? 0x1 : 0x2) << 4;
unit |= ((Sz[1] == type_size) ? 0x0 : (Sz[0] == type_size) ? 0x1 : 0x2) << 8;
unit |= ((Sz[1] == type_size) ? 0x0 : (Sz[0] == type_size) ? 0x1 : 0x2) << 0;
"""
compute_strides = """
/* create appropriate strides for malformed matrices that are row or column
* vectors
*/
......@@ -890,100 +149,410 @@ def gemm_code(check_ab, a_init, b_init):
sy_1 = (Ny[1] > 1) ? Sy[1]/type_size : Ny[0];
sz_0 = (Nz[0] > 1) ? Sz[0]/type_size : Nz[1];
sz_1 = (Nz[1] > 1) ? Sz[1]/type_size : Nz[0];
"""
begin_switch_typenum = """
switch (type_num)
{
"""
case_float = """
case PyArray_FLOAT:
{
#define REAL float
float a = %(a_init)s;
float b = %(b_init)s;
float* x = (float*)PyArray_DATA(_x);
float* y = (float*)PyArray_DATA(_y);
float* z = (float*)PyArray_DATA(_z);
"""
#case_float_ab_constants = None
case_float_gemm = """
float* x = (float*)PyArray_DATA(%(_x)s);
float* y = (float*)PyArray_DATA(%(_y)s);
float* z = (float*)PyArray_DATA(%(_z)s);
char N = 'N';
char T = 'T';
int Nz0 = Nz[0], Nz1 = Nz[1], Nx1 = Nx[1];
//std::cerr << (unit/256) MOD 16 << (unit / 16) MOD 16 << unit MOD 16<< '\\n';
switch(unit)
{
case 0x000: cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_0); break;
case 0x001: cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_0); break;
case 0x010: cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_0); break;
case 0x011: cblas_sgemm(CblasRowMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_0); break;
case 0x100: cblas_sgemm(CblasColMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_1); break;
case 0x101: cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_1); break;
case 0x110: cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_1); break;
case 0x111: cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_1); break;
default: goto _dot_execute_fallback;
case 0x000: sgemm_(&N, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_0, &b, z, &sz_0); break;
case 0x100: sgemm_(&N, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_1, &b, z, &sz_0); break;
case 0x010: sgemm_(&T, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_0, &b, z, &sz_0); break;
case 0x110: sgemm_(&T, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_1, &b, z, &sz_0); break;
case 0x001: sgemm_(&T, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_0, &b, z, &sz_1); break;
case 0x101: sgemm_(&N, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_0, &b, z, &sz_1); break;
case 0x011: sgemm_(&T, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_1, &b, z, &sz_1); break;
case 0x111: sgemm_(&N, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_1, &b, z, &sz_1); break;
default: PyErr_SetString(PyExc_ValueError, "some matrix has no unit stride"); %(fail)s;
};
#undef REAL
"""
case_double = """
}
break;
case PyArray_DOUBLE:
{
#define REAL double
double a = %(a_init)s;
double b = %(b_init)s;
double* x = (double*)PyArray_DATA(_x);
double* y = (double*)PyArray_DATA(_y);
double* z = (double*)PyArray_DATA(_z);
"""
#case_double_ab_constants = None
case_double_gemm = """
double* x = (double*)PyArray_DATA(%(_x)s);
double* y = (double*)PyArray_DATA(%(_y)s);
double* z = (double*)PyArray_DATA(%(_z)s);
char N = 'N';
char T = 'T';
int Nz0 = Nz[0], Nz1 = Nz[1], Nx1 = Nx[1];
//std::cerr << (unit/256) MOD 16 << (unit / 16) MOD 16 << unit MOD 16<< '\\n';
switch(unit)
{
case 0x000: cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_0); break;
case 0x001: cblas_dgemm(CblasRowMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_0); break;
case 0x010: cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_0); break;
case 0x011: cblas_dgemm(CblasRowMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_0); break;
case 0x100: cblas_dgemm(CblasColMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_1); break;
case 0x101: cblas_dgemm(CblasColMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_1); break;
case 0x110: cblas_dgemm(CblasColMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_1); break;
case 0x111: cblas_dgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_1); break;
default: goto _dot_execute_fallback;
case 0x000: dgemm_(&N, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_0, &b, z, &sz_0); break;
case 0x100: dgemm_(&N, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_1, &b, z, &sz_0); break;
case 0x010: dgemm_(&T, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_0, &b, z, &sz_0); break;
case 0x110: dgemm_(&T, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_1, &b, z, &sz_0); break;
case 0x001: dgemm_(&T, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_0, &b, z, &sz_1); break;
case 0x101: dgemm_(&N, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_0, &b, z, &sz_1); break;
case 0x011: dgemm_(&T, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_1, &b, z, &sz_1); break;
case 0x111: dgemm_(&N, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_1, &b, z, &sz_1); break;
default: PyErr_SetString(PyExc_ValueError, "some matrix has no unit stride"); %(fail)s;
};
#undef REAL
"""
end_switch_typenum = """
}
break;
}
"""
def build_gemm_call(self):
return reduce(str.__add__, (
self.declare_NS,
self.setup_z_Nz_Sz,
self.check_xyz_rank2,
self.check_xyz_double_or_float,
self.check_ab_double_or_float,
self.check_dims_strides,
self.encode_strides_in_unit,
self.compute_strides,
self.begin_switch_typenum,
self.case_float,
self.case_float_ab_constants,
self.case_float_gemm,
self.case_double,
self.case_double_ab_constants,
self.case_double_gemm,
self.end_switch_typenum), '')
class Gemm(GemmRelated):
"""In-place version of matrix-matrix multiplication (with accumulation):
When a and b are scalars and x, y, and z are matrices, then
gemm(z,a,x,y,b)
is similar to
return 0; //success!
_dot_execute_fallback:
PyErr_SetString(PyExc_NotImplementedError,
"dot->execute() fallback");
return -1;
_dot_execute_fail:
if (error_string == NULL)
PyErr_SetString(PyExc_ValueError,
"dot->execute() cant run on these inputs");
return -1;
/* v 1 */
""" % locals()
# currently unused, preferring the fallback method (throwing
# NotImplementedError) for when gemm won't work.
_templated_memaligned_gemm = """
template <typename Ta, typename Tx, typename Ty, typename Tb, typename Tz>
int general_gemm(int zM, int zN, int xN,.
Ta a,
Tx * x, int xm, int xn,
Tx * y, int ym, int yn,
Tb b,
Tz * z, int zm, int zn)
{
for (int i = 0; i < zM; ++i)
{
for (int j = 0; j < zN; ++j)
b*z + a*dot(x,y)
The difference between the two is that the top form is destructive on z,
whereas the bottom form is not. Gemm works in-place on the storage
associated with z, and the L{Result} returned by Gemm has a storage that
will be aliased to the storage of the z argument. Because of this in-place
computation, an L{Apply} of this op will destroy the L{Result} z on
which it operates. (See L{DestructiveOps} for an explanation of what
destroying means in the context of theano graphs. See L{BlasLapackSupport} for
more optimized linear algebra operations.)
"""
E_rank = 'gemm only works for rank 2'
E_scalar = 'gemm requires scalar argument'
E_z_uniq = 'argument z aliased to x or y'
destroy_map = {0: [0]}
def make_node(self, *inputs):
inputs = map(as_tensor, inputs)
if len(inputs) != 5:
raise TypeError("Wrong number of inputs for %s (expected 5, got %s)" % (self, len(inputs)))
z, a, x, y, b = inputs
zr, xr, yr = [set(view_roots(i)) for i in z,x,y]
if zr.intersection(xr):
raise ValueError(Gemm.E_z_uniq, (z, x))
if zr.intersection(yr):
raise ValueError(Gemm.E_z_uniq, (z, y))
bz, ba, bx, by, bb = [r.type.broadcastable for r in inputs]
if len(bz) != 2: raise ValueError(Gemm.E_rank, len(bz))
if len(bx) != 2: raise ValueError(Gemm.E_rank, len(bx))
if len(by) != 2: raise ValueError(Gemm.E_rank, len(by))
if len(ba): raise ValueError(Gemm.E_scalar, ba)
if len(bb): raise ValueError(Gemm.E_scalar, bb)
output = z.type()
return Apply(self, inputs, [output])
def perform(self, node, (z, a, x, y, b), (zout, )):
assert a.shape == ()
assert b.shape == ()
if z.shape == ():
z.itemset(z*a + b*numpy.dot(x,y))
zout[0] = z
else:
if b == 0.0:
if a == 1.0:
z[:] = numpy.dot(x,y)
elif a == -1.0:
z[:] = -numpy.dot(x,y)
else:
z[:] = a * numpy.dot(x,y)
elif b == 1.0:
if a == 1.0:
z += numpy.dot(x,y)
elif a == -1.0:
z -= numpy.dot(x,y)
else:
z += a * numpy.dot(x,y)
else:
z *= b
z += a * numpy.dot(x,y)
zout[0] = z
setup_z_Nz_Sz = """
if (%(_zout)s != %(_z)s)
{
Tz zij = 0.0;
for (int k = 0; k < xN; ++k)
if (%(_zout)s)
{
zij += x[i*xm+k*xn] * y[k*ym+j*yn];
Py_DECREF(%(_zout)s);
}
z[i * zm + j * zn] *= b;
z[i * zm + j * zn] += a * zij;
%(_zout)s = %(_z)s;
Py_INCREF(%(_zout)s);
}
}
}
"""
Nz = %(_z)s->dimensions;
Sz = %(_z)s->strides;
"""
case_float_ab_constants = """
#define REAL float
float a = (%(_a)s->descr->type_num == PyArray_FLOAT)
? (REAL)(((float*)%(_a)s->data)[0])
: (REAL)(((double*)%(_a)s->data)[0]);
float b = (%(_b)s->descr->type_num == PyArray_FLOAT) ?
(REAL)(((float*)%(_b)s->data)[0])
: (REAL)(((double*)%(_b)s->data)[0]);
#undef REAL
"""
case_double_ab_constants = """
#define REAL double
double a = (%(_a)s->descr->type_num == PyArray_FLOAT)
? (REAL)(((float*)%(_a)s->data)[0])
: (REAL)(((double*)%(_a)s->data)[0]);
double b = (%(_b)s->descr->type_num == PyArray_FLOAT) ?
(REAL)(((float*)%(_b)s->data)[0])
: (REAL)(((double*)%(_b)s->data)[0]);
#undef REAL
"""
def c_code(self, node, name, (_z, _a, _x, _y, _b), (_zout, ), sub):
full_code = self.build_gemm_call() % dict(locals(), **sub)
return full_code
gemm = Gemm()
pprint.assign(gemm, FunctionPrinter('gemm'))
class Dot22(GemmRelated):
"""Compute a matrix-matrix product.
This is a specialization of the more general Dot()
"""
def make_node(self, x, y):
assert x.type in T.float_matrix_types #makes sure x is a matrix
assert y.type == x.type #makes sure y is a matrix
bz = [x.type.broadcastable[0], y.type.broadcastable[1]]
outputs = [T.tensor(x.type.dtype, bz)]
return Apply(self, [x,y], outputs)
def perform(self, node, (x, y), (z, )):
try:
z[0] = numpy.asarray(numpy.dot(x, y))
except ValueError, e:
# The error raised by numpy has no shape information, we mean to add that
e.args = e.args + (x.shape, y.shape)
raise
def __str__(self):
return "_dot22"
setup_z_Nz_Sz = """
if ((NULL == %(_z)s)
|| (%(_z)s->dimensions[0] != %(_x)s->dimensions[0])
|| (%(_z)s->dimensions[1] != %(_y)s->dimensions[1]))
{
if (NULL != %(_z)s) Py_XDECREF(%(_z)s);
npy_intp dims[2];
dims[0] = %(_x)s->dimensions[0];
dims[1] = %(_y)s->dimensions[1];
%(_z)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, type_num_%(_x)s);
if(!%(_z)s) {
PyErr_SetString(PyExc_MemoryError, "failed to alloc dot22 output");
%(fail)s
}
}
Nz = %(_z)s->dimensions;
Sz = %(_z)s->strides;
"""
check_ab_double_or_float = ""
case_float_ab_constants = """
float a = 1.0;
float b = 0.0;
"""
case_double_ab_constants = """
double a = 1.0;
double b = 0.0;
"""
def c_code(self, node, name, (_x, _y), (_z, ), sub):
full_code = self.build_gemm_call() % dict(locals(), **sub)
return full_code
_dot22 = Dot22()
@local_optimizer([T.dot])
def local_dot_to_dot22(node):
if node.op == T.dot:
return [_dot22(*node.inputs)]
else:
return False
register_specialize(local_dot_to_dot22)
def _is_a(node, op, maxclients=None):
return node.owner \
and node.owner.op == op \
and len(node.clients) <= maxclients if maxclients is not None else True
def _as_scalar(res):
"""Return None or a TensorResult whose type is in T.float_scalar_types"""
if res.owner and isinstance(res.owner.op, T.DimShuffle):
return _as_scalar(res.owner.inputs[0])
elif res.type in T.float_scalar_types:
return res
elif isinstance(res, T.Constant) and res.data.size == 1:
return res.data.flatten()[0]
else:
return None
def _as_isolated_scalar_times_matrix(res):
if _is_a(res, T.mul, 1):
if len(res.owner.inputs) == 2:
L, R = res.owner.inputs
sL = _as_scalar(L)
sR = _as_scalar(R)
if sL is not None and R.type in T.float_matrix_types:
return (sL, R)
if sR is not None and L.type in T.float_matrix_types:
return (sR, L)
else:
scalars = []
matrices = []
for input in res.owner.inputs:
scalar_input = _as_scalar(input)
if scalar_input is not None:
scalars.append(scalar_input)
elif input.type in T.float_matrix_types:
matrices.append(input)
else:
return None
if len(matrices) == 1:
rval = (T.mul(*scalars), matrices[0])
return rval
def beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True):
#print 'BETA L + ALPHA M', beta, L, alpha, M, recurse_flip
#EXPRESSION: (beta * L) + (alpha * M)
if _is_a(M, _dot22, 1):
Ml, Mr = M.owner.inputs
rval = [gemm(L, alpha, Ml, Mr, beta)]
return rval
if _is_a(M, gemm, 1):
#EXPRESSION: (beta * L) + (alpha * (gemm(G, a, u, v, b)))
#EXPRESSION: (beta * L) + alpha * (b * G) + alpha * a * dot(u, v)
G, a, u, v, b = M.owner.inputs
#print 'GEMM', G, L
if _is_a(G, _dot22, 1):
#EXPRESSION: (beta * L) + (alpha * (gemm(dot(x,y), a, u, v, b)))
x, y = G.owner.inputs
#EXPRESSION: (beta * L) + (alpha * ((b*dot(x,y) + (a * dot(u, v)))))
#EXPRESSION: (beta * L) + (alpha*b*dot(x,y)) + (alpha * a * dot(u, v))
#print 'GEMM 1', G, L
rval = [gemm(gemm(L, alpha * b, x, y, beta), alpha * a, u, v, 1.0)]
return rval
elif G is L:
#EXPRESSION: (beta * L) + (alpha*b*L) + (alpha * a * dot(u, v))
rval = [gemm(L, alpha*a, u, v, alpha * b + beta)]
#print 'GEMM 2', rval
return rval
elif 1.0 != alpha:
#at the very least, move the alpha inside the gemm
rval = [beta * L + gemm(G, alpha * a, u, v, alpha * b)]
#print 'GEMM 3', G, L
return rval
if recurse_flip:
return beta_L_plus_alpha_M(alpha, M, beta, L, recurse_flip = False)
else:
return False
@local_optimizer([T.sub])
def local_sub_to_gemm(node):
if node.op == T.sub:
L, R = node.inputs
if L.type not in T.float_matrix_types:
return False
if R.type not in T.float_matrix_types:
return False
tmp = _as_isolated_scalar_times_matrix(L)
try:
sL, mL = tmp
except:
sL, mL = 1.0, L
tmp = _as_isolated_scalar_times_matrix(R)
try:
sR, mR = tmp
except:
sR, mR = 1.0, R
rval = beta_L_plus_alpha_M(sL, mL, -sR, mR)
return rval
return False
register_specialize(local_sub_to_gemm)
@local_optimizer([T.add])
def local_add_to_gemm(node):
"""This is a massive beast for recognizing all the ways that a subtraction could be
replaced by a GEMM
It depends on `local_transposed_dot` to canonicalize the graph a bit by swapping
dot(a,b).T -> dot(b.T, a.T)
"""
if node.op == T.add:
sM_list = []
for input in node.inputs:
tmp = _as_isolated_scalar_times_matrix(input)
sM_list.append(tmp if tmp is not None else (1.0,input))
#print sM_list
if len(node.inputs) == 2:
sL, mL = sM_list[0]
sR, mR = sM_list[1]
return beta_L_plus_alpha_M(sL, mL, sR, mR)
else:
for i in xrange(len(sM_list) - 1):
for j in xrange(i+1, len(sM_list)):
sL, mL = sM_list[i]
sR, mR = sM_list[j]
rval = beta_L_plus_alpha_M(sL, mL, sR, mR)
if rval:
assert len(rval) == 1
inputs_without_ij = \
[input for k, input in enumerate(node.inputs) if k not in (i,j)]
return [T.add( *(inputs_without_ij + rval))]
return rval
return False
register_specialize(local_add_to_gemm)
""" Header text for the C and Fortran BLAS interfaces.
There is no standard name or location for this header, so we just insert it ourselves into the C code
"""
def cblas_header_text():
"""C header for the cblas interface."""
return """
//#include <stddef.h>
#undef __BEGIN_DECLS
#undef __END_DECLS
#ifdef __cplusplus
#define __BEGIN_DECLS extern "C" {
#define __END_DECLS }
#else
#define __BEGIN_DECLS /* empty */
#define __END_DECLS /* empty */
#endif
__BEGIN_DECLS
#define MOD %
/*
* Enumerated and derived types
*/
#define CBLAS_INDEX size_t /* this may vary between platforms */
enum CBLAS_ORDER {CblasRowMajor=101, CblasColMajor=102};
enum CBLAS_TRANSPOSE {CblasNoTrans=111, CblasTrans=112, CblasConjTrans=113};
enum CBLAS_UPLO {CblasUpper=121, CblasLower=122};
enum CBLAS_DIAG {CblasNonUnit=131, CblasUnit=132};
enum CBLAS_SIDE {CblasLeft=141, CblasRight=142};
float cblas_sdsdot(const int N, const float alpha, const float *X,
const int incX, const float *Y, const int incY);
double cblas_dsdot(const int N, const float *X, const int incX, const float *Y,
const int incY);
float cblas_sdot(const int N, const float *X, const int incX,
const float *Y, const int incY);
double cblas_ddot(const int N, const double *X, const int incX,
const double *Y, const int incY);
/*
* Functions having prefixes Z and C only
*/
void cblas_cdotu_sub(const int N, const void *X, const int incX,
const void *Y, const int incY, void *dotu);
void cblas_cdotc_sub(const int N, const void *X, const int incX,
const void *Y, const int incY, void *dotc);
void cblas_zdotu_sub(const int N, const void *X, const int incX,
const void *Y, const int incY, void *dotu);
void cblas_zdotc_sub(const int N, const void *X, const int incX,
const void *Y, const int incY, void *dotc);
/*
* Functions having prefixes S D SC DZ
*/
float cblas_snrm2(const int N, const float *X, const int incX);
float cblas_sasum(const int N, const float *X, const int incX);
double cblas_dnrm2(const int N, const double *X, const int incX);
double cblas_dasum(const int N, const double *X, const int incX);
float cblas_scnrm2(const int N, const void *X, const int incX);
float cblas_scasum(const int N, const void *X, const int incX);
double cblas_dznrm2(const int N, const void *X, const int incX);
double cblas_dzasum(const int N, const void *X, const int incX);
/*
* Functions having standard 4 prefixes (S D C Z)
*/
CBLAS_INDEX cblas_isamax(const int N, const float *X, const int incX);
CBLAS_INDEX cblas_idamax(const int N, const double *X, const int incX);
CBLAS_INDEX cblas_icamax(const int N, const void *X, const int incX);
CBLAS_INDEX cblas_izamax(const int N, const void *X, const int incX);
/*
* ===========================================================================
* Prototypes for level 1 BLAS routines
* ===========================================================================
*/
/*
* Routines with standard 4 prefixes (s, d, c, z)
*/
void cblas_sswap(const int N, float *X, const int incX,
float *Y, const int incY);
void cblas_scopy(const int N, const float *X, const int incX,
float *Y, const int incY);
void cblas_saxpy(const int N, const float alpha, const float *X,
const int incX, float *Y, const int incY);
void cblas_dswap(const int N, double *X, const int incX,
double *Y, const int incY);
void cblas_dcopy(const int N, const double *X, const int incX,
double *Y, const int incY);
void cblas_daxpy(const int N, const double alpha, const double *X,
const int incX, double *Y, const int incY);
void cblas_cswap(const int N, void *X, const int incX,
void *Y, const int incY);
void cblas_ccopy(const int N, const void *X, const int incX,
void *Y, const int incY);
void cblas_caxpy(const int N, const void *alpha, const void *X,
const int incX, void *Y, const int incY);
void cblas_zswap(const int N, void *X, const int incX,
void *Y, const int incY);
void cblas_zcopy(const int N, const void *X, const int incX,
void *Y, const int incY);
void cblas_zaxpy(const int N, const void *alpha, const void *X,
const int incX, void *Y, const int incY);
/*
* Routines with S and D prefix only
*/
void cblas_srotg(float *a, float *b, float *c, float *s);
void cblas_srotmg(float *d1, float *d2, float *b1, const float b2, float *P);
void cblas_srot(const int N, float *X, const int incX,
float *Y, const int incY, const float c, const float s);
void cblas_srotm(const int N, float *X, const int incX,
float *Y, const int incY, const float *P);
void cblas_drotg(double *a, double *b, double *c, double *s);
void cblas_drotmg(double *d1, double *d2, double *b1, const double b2, double *P);
void cblas_drot(const int N, double *X, const int incX,
double *Y, const int incY, const double c, const double s);
void cblas_drotm(const int N, double *X, const int incX,
double *Y, const int incY, const double *P);
/*
* Routines with S D C Z CS and ZD prefixes
*/
void cblas_sscal(const int N, const float alpha, float *X, const int incX);
void cblas_dscal(const int N, const double alpha, double *X, const int incX);
void cblas_cscal(const int N, const void *alpha, void *X, const int incX);
void cblas_zscal(const int N, const void *alpha, void *X, const int incX);
void cblas_csscal(const int N, const float alpha, void *X, const int incX);
void cblas_zdscal(const int N, const double alpha, void *X, const int incX);
/*
* ===========================================================================
* Prototypes for level 2 BLAS
* ===========================================================================
*/
/*
* Routines with standard 4 prefixes (S, D, C, Z)
*/
void cblas_sgemv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const float alpha, const float *A, const int lda,
const float *X, const int incX, const float beta,
float *Y, const int incY);
void cblas_sgbmv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const int KL, const int KU, const float alpha,
const float *A, const int lda, const float *X,
const int incX, const float beta, float *Y, const int incY);
void cblas_strmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const float *A, const int lda,
float *X, const int incX);
void cblas_stbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const float *A, const int lda,
float *X, const int incX);
void cblas_stpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const float *Ap, float *X, const int incX);
void cblas_strsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const float *A, const int lda, float *X,
const int incX);
void cblas_stbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const float *A, const int lda,
float *X, const int incX);
void cblas_stpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const float *Ap, float *X, const int incX);
void cblas_dgemv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const double alpha, const double *A, const int lda,
const double *X, const int incX, const double beta,
double *Y, const int incY);
void cblas_dgbmv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const int KL, const int KU, const double alpha,
const double *A, const int lda, const double *X,
const int incX, const double beta, double *Y, const int incY);
void cblas_dtrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const double *A, const int lda,
double *X, const int incX);
void cblas_dtbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const double *A, const int lda,
double *X, const int incX);
void cblas_dtpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const double *Ap, double *X, const int incX);
void cblas_dtrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const double *A, const int lda, double *X,
const int incX);
void cblas_dtbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const double *A, const int lda,
double *X, const int incX);
void cblas_dtpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const double *Ap, double *X, const int incX);
void cblas_cgemv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *X, const int incX, const void *beta,
void *Y, const int incY);
void cblas_cgbmv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const int KL, const int KU, const void *alpha,
const void *A, const int lda, const void *X,
const int incX, const void *beta, void *Y, const int incY);
void cblas_ctrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *A, const int lda,
void *X, const int incX);
void cblas_ctbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const void *A, const int lda,
void *X, const int incX);
void cblas_ctpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *Ap, void *X, const int incX);
void cblas_ctrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *A, const int lda, void *X,
const int incX);
void cblas_ctbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const void *A, const int lda,
void *X, const int incX);
void cblas_ctpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *Ap, void *X, const int incX);
void cblas_zgemv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *X, const int incX, const void *beta,
void *Y, const int incY);
void cblas_zgbmv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const int KL, const int KU, const void *alpha,
const void *A, const int lda, const void *X,
const int incX, const void *beta, void *Y, const int incY);
void cblas_ztrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *A, const int lda,
void *X, const int incX);
void cblas_ztbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const void *A, const int lda,
void *X, const int incX);
void cblas_ztpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *Ap, void *X, const int incX);
void cblas_ztrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *A, const int lda, void *X,
const int incX);
void cblas_ztbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const void *A, const int lda,
void *X, const int incX);
void cblas_ztpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *Ap, void *X, const int incX);
/*
* Routines with S and D prefixes only
*/
void cblas_ssymv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *A,
const int lda, const float *X, const int incX,
const float beta, float *Y, const int incY);
void cblas_ssbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const int K, const float alpha, const float *A,
const int lda, const float *X, const int incX,
const float beta, float *Y, const int incY);
void cblas_sspmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *Ap,
const float *X, const int incX,
const float beta, float *Y, const int incY);
void cblas_sger(const enum CBLAS_ORDER order, const int M, const int N,
const float alpha, const float *X, const int incX,
const float *Y, const int incY, float *A, const int lda);
void cblas_ssyr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *X,
const int incX, float *A, const int lda);
void cblas_sspr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *X,
const int incX, float *Ap);
void cblas_ssyr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *X,
const int incX, const float *Y, const int incY, float *A,
const int lda);
void cblas_sspr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *X,
const int incX, const float *Y, const int incY, float *A);
void cblas_dsymv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *A,
const int lda, const double *X, const int incX,
const double beta, double *Y, const int incY);
void cblas_dsbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const int K, const double alpha, const double *A,
const int lda, const double *X, const int incX,
const double beta, double *Y, const int incY);
void cblas_dspmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *Ap,
const double *X, const int incX,
const double beta, double *Y, const int incY);
void cblas_dger(const enum CBLAS_ORDER order, const int M, const int N,
const double alpha, const double *X, const int incX,
const double *Y, const int incY, double *A, const int lda);
void cblas_dsyr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *X,
const int incX, double *A, const int lda);
void cblas_dspr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *X,
const int incX, double *Ap);
void cblas_dsyr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *X,
const int incX, const double *Y, const int incY, double *A,
const int lda);
void cblas_dspr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *X,
const int incX, const double *Y, const int incY, double *A);
/*
* Routines with C and Z prefixes only
*/
void cblas_chemv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const void *alpha, const void *A,
const int lda, const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_chbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const int K, const void *alpha, const void *A,
const int lda, const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_chpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const void *alpha, const void *Ap,
const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_cgeru(const enum CBLAS_ORDER order, const int M, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_cgerc(const enum CBLAS_ORDER order, const int M, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_cher(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const void *X, const int incX,
void *A, const int lda);
void cblas_chpr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const void *X,
const int incX, void *A);
void cblas_cher2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_chpr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *Ap);
void cblas_zhemv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const void *alpha, const void *A,
const int lda, const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_zhbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const int K, const void *alpha, const void *A,
const int lda, const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_zhpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const void *alpha, const void *Ap,
const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_zgeru(const enum CBLAS_ORDER order, const int M, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_zgerc(const enum CBLAS_ORDER order, const int M, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_zher(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const void *X, const int incX,
void *A, const int lda);
void cblas_zhpr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const void *X,
const int incX, void *A);
void cblas_zher2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_zhpr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *Ap);
/*
* ===========================================================================
* Prototypes for level 3 BLAS
* ===========================================================================
*/
/*
* Routines with standard 4 prefixes (S, D, C, Z)
*/
void cblas_sgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const float alpha, const float *A,
const int lda, const float *B, const int ldb,
const float beta, float *C, const int ldc);
void cblas_ssymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const float alpha, const float *A, const int lda,
const float *B, const int ldb, const float beta,
float *C, const int ldc);
void cblas_ssyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const float alpha, const float *A, const int lda,
const float beta, float *C, const int ldc);
void cblas_ssyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const float alpha, const float *A, const int lda,
const float *B, const int ldb, const float beta,
float *C, const int ldc);
void cblas_strmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const float alpha, const float *A, const int lda,
float *B, const int ldb);
void cblas_strsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const float alpha, const float *A, const int lda,
float *B, const int ldb);
void cblas_dgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const double alpha, const double *A,
const int lda, const double *B, const int ldb,
const double beta, double *C, const int ldc);
void cblas_dsymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const double alpha, const double *A, const int lda,
const double *B, const int ldb, const double beta,
double *C, const int ldc);
void cblas_dsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const double alpha, const double *A, const int lda,
const double beta, double *C, const int ldc);
void cblas_dsyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const double alpha, const double *A, const int lda,
const double *B, const int ldb, const double beta,
double *C, const int ldc);
void cblas_dtrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const double alpha, const double *A, const int lda,
double *B, const int ldb);
void cblas_dtrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const double alpha, const double *A, const int lda,
double *B, const int ldb);
void cblas_cgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const void *alpha, const void *A,
const int lda, const void *B, const int ldb,
const void *beta, void *C, const int ldc);
void cblas_csymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_csyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *beta, void *C, const int ldc);
void cblas_csyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_ctrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const void *alpha, const void *A, const int lda,
void *B, const int ldb);
void cblas_ctrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const void *alpha, const void *A, const int lda,
void *B, const int ldb);
void cblas_zgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const void *alpha, const void *A,
const int lda, const void *B, const int ldb,
const void *beta, void *C, const int ldc);
void cblas_zsymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_zsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *beta, void *C, const int ldc);
void cblas_zsyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_ztrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const void *alpha, const void *A, const int lda,
void *B, const int ldb);
void cblas_ztrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const void *alpha, const void *A, const int lda,
void *B, const int ldb);
/*
* Routines with prefixes C and Z only
*/
void cblas_chemm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_cherk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const float alpha, const void *A, const int lda,
const float beta, void *C, const int ldc);
void cblas_cher2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const float beta,
void *C, const int ldc);
void cblas_zhemm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_zherk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const double alpha, const void *A, const int lda,
const double beta, void *C, const int ldc);
void cblas_zher2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const double beta,
void *C, const int ldc);
void cblas_xerbla(int p, const char *rout, const char *form, ...);
__END_DECLS
"""
def blas_header_text():
"""C header for the fortran blas interface"""
return """
extern "C"
{
void xerbla_(char*, void *);
/***********/
/* Level 1 */
/***********/
/* Single Precision */
void srot_(const int*, float *, const int*, float *, const int*, const float *, const float *);
void srotg_(float *,float *,float *,float *);
void srotm_( const int*, float *, const int*, float *, const int*, const float *);
void srotmg_(float *,float *,float *,const float *, float *);
void sswap_( const int*, float *, const int*, float *, const int*);
void scopy_( const int*, const float *, const int*, float *, const int*);
void saxpy_( const int*, const float *, const float *, const int*, float *, const int*);
void sdot_sub_(const int*, const float *, const int*, const float *, const int*, float *);
void sdsdot_sub_( const int*, const float *, const float *, const int*, const float *, const int*, float *);
void sscal_( const int*, const float *, float *, const int*);
void snrm2_sub_( const int*, const float *, const int*, float *);
void sasum_sub_( const int*, const float *, const int*, float *);
void isamax_sub_( const int*, const float * , const int*, const int*);
/* Double Precision */
void drot_(const int*, double *, const int*, double *, const int*, const double *, const double *);
void drotg_(double *,double *,double *,double *);
void drotm_( const int*, double *, const int*, double *, const int*, const double *);
void drotmg_(double *,double *,double *,const double *, double *);
void dswap_( const int*, double *, const int*, double *, const int*);
void dcopy_( const int*, const double *, const int*, double *, const int*);
void daxpy_( const int*, const double *, const double *, const int*, double *, const int*);
void dswap_( const int*, double *, const int*, double *, const int*);
void dsdot_sub_(const int*, const float *, const int*, const float *, const int*, double *);
void ddot_sub_( const int*, const double *, const int*, const double *, const int*, double *);
void dscal_( const int*, const double *, double *, const int*);
void dnrm2_sub_( const int*, const double *, const int*, double *);
void dasum_sub_( const int*, const double *, const int*, double *);
void idamax_sub_( const int*, const double * , const int*, const int*);
/* Single Complex Precision */
void cswap_( const int*, void *, const int*, void *, const int*);
void ccopy_( const int*, const void *, const int*, void *, const int*);
void caxpy_( const int*, const void *, const void *, const int*, void *, const int*);
void cswap_( const int*, void *, const int*, void *, const int*);
void cdotc_sub_( const int*, const void *, const int*, const void *, const int*, void *);
void cdotu_sub_( const int*, const void *, const int*, const void *, const int*, void *);
void cscal_( const int*, const void *, void *, const int*);
void icamax_sub_( const int*, const void *, const int*, const int*);
void csscal_( const int*, const float *, void *, const int*);
void scnrm2_sub_( const int*, const void *, const int*, float *);
void scasum_sub_( const int*, const void *, const int*, float *);
/* Double Complex Precision */
void zswap_( const int*, void *, const int*, void *, const int*);
void zcopy_( const int*, const void *, const int*, void *, const int*);
void zaxpy_( const int*, const void *, const void *, const int*, void *, const int*);
void zswap_( const int*, void *, const int*, void *, const int*);
void zdotc_sub_( const int*, const void *, const int*, const void *, const int*, void *);
void zdotu_sub_( const int*, const void *, const int*, const void *, const int*, void *);
void zdscal_( const int*, const double *, void *, const int*);
void zscal_( const int*, const void *, void *, const int*);
void dznrm2_sub_( const int*, const void *, const int*, double *);
void dzasum_sub_( const int*, const void *, const int*, double *);
void izamax_sub_( const int*, const void *, const int*, const int*);
/***********/
/* Level 2 */
/***********/
/* Single Precision */
void sgemv_(char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void sgbmv_(char*, const int*, const int*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ssymv_(char*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ssbmv_(char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void sspmv_(char*, const int*, const float *, const float *, const float *, const int*, const float *, float *, const int*);
void strmv_( char*, char*, char*, const int*, const float *, const int*, float *, const int*);
void stbmv_( char*, char*, char*, const int*, const int*, const float *, const int*, float *, const int*);
void strsv_( char*, char*, char*, const int*, const float *, const int*, float *, const int*);
void stbsv_( char*, char*, char*, const int*, const int*, const float *, const int*, float *, const int*);
void stpmv_( char*, char*, char*, const int*, const float *, float *, const int*);
void stpsv_( char*, char*, char*, const int*, const float *, float *, const int*);
void sger_( const int*, const int*, const float *, const float *, const int*, const float *, const int*, float *, const int*);
void ssyr_(char*, const int*, const float *, const float *, const int*, float *, const int*);
void sspr_(char*, const int*, const float *, const float *, const int*, float *);
void sspr2_(char*, const int*, const float *, const float *, const int*, const float *, const int*, float *);
void ssyr2_(char*, const int*, const float *, const float *, const int*, const float *, const int*, float *, const int*);
/* Double Precision */
void dgemv_(char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dgbmv_(char*, const int*, const int*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dsymv_(char*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dsbmv_(char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dspmv_(char*, const int*, const double *, const double *, const double *, const int*, const double *, double *, const int*);
void dtrmv_( char*, char*, char*, const int*, const double *, const int*, double *, const int*);
void dtbmv_( char*, char*, char*, const int*, const int*, const double *, const int*, double *, const int*);
void dtrsv_( char*, char*, char*, const int*, const double *, const int*, double *, const int*);
void dtbsv_( char*, char*, char*, const int*, const int*, const double *, const int*, double *, const int*);
void dtpmv_( char*, char*, char*, const int*, const double *, double *, const int*);
void dtpsv_( char*, char*, char*, const int*, const double *, double *, const int*);
void dger_( const int*, const int*, const double *, const double *, const int*, const double *, const int*, double *, const int*);
void dsyr_(char*, const int*, const double *, const double *, const int*, double *, const int*);
void dspr_(char*, const int*, const double *, const double *, const int*, double *);
void dspr2_(char*, const int*, const double *, const double *, const int*, const double *, const int*, double *);
void dsyr2_(char*, const int*, const double *, const double *, const int*, const double *, const int*, double *, const int*);
/* Single Complex Precision */
void cgemv_(char*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void cgbmv_(char*, const int*, const int*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void chemv_(char*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void chbmv_(char*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void chpmv_(char*, const int*, const void *, const void *, const void *, const int*, const void *, void *, const int*);
void ctrmv_( char*, char*, char*, const int*, const void *, const int*, void *, const int*);
void ctbmv_( char*, char*, char*, const int*, const int*, const void *, const int*, void *, const int*);
void ctpmv_( char*, char*, char*, const int*, const void *, void *, const int*);
void ctrsv_( char*, char*, char*, const int*, const void *, const int*, void *, const int*);
void ctbsv_( char*, char*, char*, const int*, const int*, const void *, const int*, void *, const int*);
void ctpsv_( char*, char*, char*, const int*, const void *, void *,const int*);
void cgerc_( const int*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*);
void cgeru_( const int*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*);
void cher_(char*, const int*, const float *, const void *, const int*, void *, const int*);
void cher2_(char*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*);
void chpr_(char*, const int*, const float *, const void *, const int*, void *);
void chpr2_(char*, const int*, const float *, const void *, const int*, const void *, const int*, void *);
/* Double Complex Precision */
void zgemv_(char*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void zgbmv_(char*, const int*, const int*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void zhemv_(char*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void zhbmv_(char*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void zhpmv_(char*, const int*, const void *, const void *, const void *, const int*, const void *, void *, const int*);
void ztrmv_( char*, char*, char*, const int*, const void *, const int*, void *, const int*);
void ztbmv_( char*, char*, char*, const int*, const int*, const void *, const int*, void *, const int*);
void ztpmv_( char*, char*, char*, const int*, const void *, void *, const int*);
void ztrsv_( char*, char*, char*, const int*, const void *, const int*, void *, const int*);
void ztbsv_( char*, char*, char*, const int*, const int*, const void *, const int*, void *, const int*);
void ztpsv_( char*, char*, char*, const int*, const void *, void *,const int*);
void zgerc_( const int*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*);
void zgeru_( const int*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*);
void zher_(char*, const int*, const double *, const void *, const int*, void *, const int*);
void zher2_(char*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*);
void zhpr_(char*, const int*, const double *, const void *, const int*, void *);
void zhpr2_(char*, const int*, const double *, const void *, const int*, const void *, const int*, void *);
/***********/
/* Level 3 */
/***********/
/* Single Precision */
void sgemm_(char*, char*, const int*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ssymm_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ssyrk_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, float *, const int*);
void ssyr2k_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void strmm_(char*, char*, char*, char*, const int*, const int*, const float *, const float *, const int*, float *, const int*);
void strsm_(char*, char*, char*, char*, const int*, const int*, const float *, const float *, const int*, float *, const int*);
/* Double Precision */
void dgemm_(char*, char*, const int*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dsymm_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dsyrk_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, double *, const int*);
void dsyr2k_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dtrmm_(char*, char*, char*, char*, const int*, const int*, const double *, const double *, const int*, double *, const int*);
void dtrsm_(char*, char*, char*, char*, const int*, const int*, const double *, const double *, const int*, double *, const int*);
/* Single Complex Precision */
void cgemm_(char*, char*, const int*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void csymm_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void chemm_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void csyrk_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, float *, const int*);
void cherk_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, float *, const int*);
void csyr2k_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void cher2k_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ctrmm_(char*, char*, char*, char*, const int*, const int*, const float *, const float *, const int*, float *, const int*);
void ctrsm_(char*, char*, char*, char*, const int*, const int*, const float *, const float *, const int*, float *, const int*);
/* Double Complex Precision */
void zgemm_(char*, char*, const int*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void zsymm_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void zhemm_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void zsyrk_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, double *, const int*);
void zherk_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, double *, const int*);
void zsyr2k_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void zher2k_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void ztrmm_(char*, char*, char*, char*, const int*, const int*, const double *, const double *, const int*, double *, const int*);
void ztrsm_(char*, char*, char*, char*, const int*, const int*, const double *, const double *, const int*, double *, const int*);
}
"""
def ____gemm_code(check_ab, a_init, b_init):
mod = '%'
return """
const char * error_string = NULL;
int type_num = _x->descr->type_num;
int type_size = _x->descr->elsize; // in bytes
npy_intp* Nx = _x->dimensions;
npy_intp* Ny = _y->dimensions;
npy_intp* Nz = _z->dimensions;
npy_intp* Sx = _x->strides;
npy_intp* Sy = _y->strides;
npy_intp* Sz = _z->strides;
size_t sx_0, sx_1, sy_0, sy_1, sz_0, sz_1;
int unit = 0;
if (_x->nd != 2) goto _dot_execute_fallback;
if (_y->nd != 2) goto _dot_execute_fallback;
if (_z->nd != 2) goto _dot_execute_fallback;
%(check_ab)s
if ((_x->descr->type_num != PyArray_DOUBLE)
&& (_x->descr->type_num != PyArray_FLOAT))
goto _dot_execute_fallback;
if ((_y->descr->type_num != PyArray_DOUBLE)
&& (_y->descr->type_num != PyArray_FLOAT))
goto _dot_execute_fallback;
if ((_y->descr->type_num != PyArray_DOUBLE)
&& (_y->descr->type_num != PyArray_FLOAT))
goto _dot_execute_fallback;
if ((_x->descr->type_num != _y->descr->type_num)
||(_x->descr->type_num != _z->descr->type_num))
goto _dot_execute_fallback;
if ((Nx[0] != Nz[0]) || (Nx[1] != Ny[0]) || (Ny[1] != Nz[1]))
{
error_string = "Input dimensions do not agree";
goto _dot_execute_fail;
}
if ((Sx[0] < 1) || (Sx[1] < 1) || (Sx[0] %(mod)s type_size) || (Sx[1] %(mod)s type_size)
|| (Sy[0] < 1) || (Sy[1] < 1) || (Sy[0] %(mod)s type_size) || (Sy[1] %(mod)s type_size)
|| (Sz[0] < 1) || (Sz[1] < 1) || (Sz[0] %(mod)s type_size) || (Sz[1] %(mod)s type_size))
{
goto _dot_execute_fallback;
}
/*
encode the stride structure of _x,_y,_z into a single integer
*/
unit |= ((Sx[1] == type_size) ? 0x0 : (Sx[0] == type_size) ? 0x1 : 0x2) << 0;
unit |= ((Sy[1] == type_size) ? 0x0 : (Sy[0] == type_size) ? 0x1 : 0x2) << 4;
unit |= ((Sz[1] == type_size) ? 0x0 : (Sz[0] == type_size) ? 0x1 : 0x2) << 8;
/* create appropriate strides for malformed matrices that are row or column
* vectors
*/
sx_0 = (Nx[0] > 1) ? Sx[0]/type_size : Nx[1];
sx_1 = (Nx[1] > 1) ? Sx[1]/type_size : Nx[0];
sy_0 = (Ny[0] > 1) ? Sy[0]/type_size : Ny[1];
sy_1 = (Ny[1] > 1) ? Sy[1]/type_size : Ny[0];
sz_0 = (Nz[0] > 1) ? Sz[0]/type_size : Nz[1];
sz_1 = (Nz[1] > 1) ? Sz[1]/type_size : Nz[0];
switch (type_num)
{
case PyArray_FLOAT:
{
#define REAL float
float a = %(a_init)s;
float b = %(b_init)s;
float* x = (float*)PyArray_DATA(_x);
float* y = (float*)PyArray_DATA(_y);
float* z = (float*)PyArray_DATA(_z);
switch(unit)
{
case 0x000: cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_0); break;
case 0x001: cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_0); break;
case 0x010: cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_0); break;
case 0x011: cblas_sgemm(CblasRowMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_0); break;
case 0x100: cblas_sgemm(CblasColMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_1); break;
case 0x101: cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_1); break;
case 0x110: cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_1); break;
case 0x111: cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_1); break;
default: goto _dot_execute_fallback;
};
#undef REAL
}
break;
case PyArray_DOUBLE:
{
#define REAL double
double a = %(a_init)s;
double b = %(b_init)s;
double* x = (double*)PyArray_DATA(_x);
double* y = (double*)PyArray_DATA(_y);
double* z = (double*)PyArray_DATA(_z);
switch(unit)
{
case 0x000: cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_0); break;
case 0x001: cblas_dgemm(CblasRowMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_0); break;
case 0x010: cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_0); break;
case 0x011: cblas_dgemm(CblasRowMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_0); break;
case 0x100: cblas_dgemm(CblasColMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_1); break;
case 0x101: cblas_dgemm(CblasColMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_1); break;
case 0x110: cblas_dgemm(CblasColMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_1); break;
case 0x111: cblas_dgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_1); break;
default: goto _dot_execute_fallback;
};
#undef REAL
}
break;
}
return 0; //success!
_dot_execute_fallback:
PyErr_SetString(PyExc_NotImplementedError,
"dot->execute() fallback");
return -1;
_dot_execute_fail:
if (error_string == NULL)
PyErr_SetString(PyExc_ValueError,
"dot->execute() cant run on these inputs");
return -1;
/* v 1 */
""" % locals()
# currently unused, preferring the fallback method (throwing
# NotImplementedError) for when gemm won't work.
_templated_memaligned_gemm = """
template <typename Ta, typename Tx, typename Ty, typename Tb, typename Tz>
int general_gemm(int zM, int zN, int xN,.
Ta a,
Tx * x, int xm, int xn,
Tx * y, int ym, int yn,
Tb b,
Tz * z, int zm, int zn)
{
for (int i = 0; i < zM; ++i)
{
for (int j = 0; j < zN; ++j)
{
Tz zij = 0.0;
for (int k = 0; k < xN; ++k)
{
zij += x[i*xm+k*xn] * y[k*ym+j*yn];
}
z[i * zm + j * zn] *= b;
z[i * zm + j * zn] += a * zij;
}
}
}
"""
from basic import _scal_elemwise #, _transpose_inplace
from .basic import _scal_elemwise #, _transpose_inplace
from .. import scalar as scal
import elemwise
from .. import printing
......
"""Tensor optimizations addressing the ops in basic.py
"""
# TODO: intelligent merge for mul/add
# TODO: 0*x -> 0
......@@ -30,29 +31,6 @@ def in2out(*local_opts, **kwargs):
**kwargs)
# gemm: (d,a,b,c,s) -> d = d*s + a*dot(b,c)
# Transforms d -= a * dot(b, c) into gemm(d, -a, b, c, 1.0)
gemm_pattern_1 = gof.PatternSub((T.sub,
'd',
(T.mul,
dict(pattern = (T.DimShuffle((), ['x', 'x'], inplace = True), 'a'),
allow_multiple_clients = True),
(T.dot, 'b', 'c'))),
(T.gemm, 'd', (T.neg, 'a'), 'b', 'c', T.constant(1.0)),
allow_multiple_clients = False)
# gemm: (d,a,b,c,s) -> d = d*s + a*dot(b,c)
# Transforms dot(a, b) into gemm(zeros(2)(hstack(shape(a)[:1], shape(b)[1:])), 1.0, a, b, 1.0)
# The construction of the 'gemm' node may fail if, for example, a and b are not both matrices.
dot_to_gemm = gof.PatternSub((T.dot, 'a', 'b'),
(T.gemm, (T.Zeros(2),
(T.stack,
(T.Subtensor([slice(0, 1)]), (T.shape, 'a')),
(T.Subtensor([slice(1, 2)]), (T.shape, 'b')))),
T.constant(1.0), 'a', 'b', T.constant(1.0)),
allow_multiple_clients = False)
def _insert_inplace_optimizer(env):
"""
......@@ -92,12 +70,6 @@ def _insert_inplace_optimizer(env):
break
insert_inplace_optimizer = gof.optimizer(_insert_inplace_optimizer)
inplace_optimizer = gof.InplaceOptimizer(
gof.SeqOptimizer(out2in(gemm_pattern_1),
insert_inplace_optimizer,
failure_callback = gof.warn))
compile.optdb.register('inplace_opt', inplace_optimizer, 99, 'fast_run', 'inplace')
def register_canonicalize(lopt, *tags, **kwargs):
name = (kwargs and kwargs.pop('name')) or lopt.__name__
......@@ -625,6 +597,38 @@ def local_pow_specialize(node):
return False
register_specialize(local_pow_specialize)
@gof.local_optimizer([T.mul])
def local_mul_specialize(node):
#here, we are past the point of canonicalization, so we don't want to put in un-necessary fills.
if node.op == T.mul:
#the idea here is that we have pow(x, y)
neg = False
new_inputs = []
for input in node.inputs:
y = local_mul_canonizer.get_constant(input)
if N.all(y == 1.0):
continue
elif N.all(y == -1.0):
neg ^= True #toggles
elif N.all(y == 0.0):
return [input]
else:
new_inputs.append(input)
if len(new_inputs) < len(node.inputs):
if len(new_inputs) == 0:
newval = -1 if neg else 1
return [T.TensorConstant(T.Tensor(dtype=node.outputs[0].type.dtype,
broadcastable = ()), N.asarray(newval))]
if len(new_inputs) == 1:
return [-new_inputs[0]] if neg else new_inputs
else:
return [-T.mul(*new_inputs)] if neg else \
[T.mul(*new_inputs)]
else:
return False
register_specialize(local_mul_specialize)
if 0: #TODO: replace this with a c version of any InplaceDimShuffle
class _TransposeInplace(T.Op):
view_map = {0: [0]}
......@@ -813,250 +817,10 @@ def constant_folding(node):
register_canonicalize(constant_folding)
#################
# BLAS-related
#################
import blas
class _Dot22(gof.Op):
"""Compute a matrix-matrix product.
This is a specialization of the more general Dot()
"""
def make_node(self, x, y):
assert x.type in T.float_matrix_types #makes sure x is a matrix
assert y.type == x.type #makes sure y is a matrix
bz = [x.type.broadcastable[0], y.type.broadcastable[1]]
outputs = [T.tensor(x.type.dtype, bz)]
return gof.Apply(self, [x,y], outputs)
def perform(self, node, (x, y), (z, )):
try:
z[0] = numpy.asarray(numpy.dot(x, y))
except ValueError, e:
# The error raised by numpy has no shape information, we mean to add that
e.args = e.args + (x.shape, y.shape)
raise
def __str__(self):
return "_dot22"
def c_support_code(self):
#return blas.cblas_header_text()
mod_str = """
#ifndef MOD
#define MOD %
#endif
"""
return blas.blas_proto() + mod_str
def c_headers(self):
return ['<iostream>']
def c_libraries(self):
return blas.ldflags()
def c_code(self, node, name, (_x, _y), (_z, ), sub):
return """
int unit = 0;
int type_num = %(_x)s->descr->type_num;
int type_size = %(_x)s->descr->elsize; // in bytes
npy_intp* Nx = %(_x)s->dimensions;
npy_intp* Ny = %(_y)s->dimensions;
npy_intp* Nz = 0; //%(_z)s->dimensions;
npy_intp* Sx = %(_x)s->strides;
npy_intp* Sy = %(_y)s->strides;
npy_intp* Sz = 0;//%(_z)s->strides;
//strides for x, y, z in dimensions 0, 1
int sx_0, sx_1, sy_0, sy_1, sz_0, sz_1;
if ((NULL == %(_z)s)
|| (%(_z)s->dimensions[0] != %(_x)s->dimensions[0])
|| (%(_z)s->dimensions[1] != %(_y)s->dimensions[1]))
{
if (NULL != %(_z)s) Py_XDECREF(%(_z)s);
npy_intp dims[2];
dims[0] = %(_x)s->dimensions[0];
dims[1] = %(_y)s->dimensions[1];
%(_z)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, type_num_%(_x)s);
if(!%(_z)s) {
PyErr_SetString(PyExc_MemoryError, "failed to alloc dot22 output");
%(fail)s
}
}
Nz = %(_z)s->dimensions;
Sz = %(_z)s->strides;
if (%(_x)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(x) != 2"); %(fail)s;}
if (%(_y)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(y) != 2"); %(fail)s;}
if (%(_z)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(z) != 2"); %(fail)s;}
if ((%(_x)s->descr->type_num != PyArray_DOUBLE)
&& (%(_x)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float"); %(fail)s;}
if ((%(_y)s->descr->type_num != PyArray_DOUBLE)
&& (%(_y)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float"); %(fail)s;}
if ((%(_z)s->descr->type_num != PyArray_DOUBLE)
&& (%(_z)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float"); %(fail)s;}
if ((%(_x)s->descr->type_num != %(_y)s->descr->type_num)
||(%(_x)s->descr->type_num != %(_z)s->descr->type_num))
{ PyErr_SetString(PyExc_NotImplementedError, "type(z), type(y), type(z) are not all the same"); %(fail)s; }
if ((Nx[0] != Nz[0]) || (Nx[1] != Ny[0]) || (Ny[1] != Nz[1]))
{
PyErr_SetString(PyExc_ValueError, "Input dimensions do not agree");
%(fail)s;
}
if ((Sx[0] < 1) || (Sx[1] < 1) || (Sx[0] MOD type_size) || (Sx[1] MOD type_size)
|| (Sy[0] < 1) || (Sy[1] < 1) || (Sy[0] MOD type_size) || (Sy[1] MOD type_size)
|| (Sz[0] < 1) || (Sz[1] < 1) || (Sz[0] MOD type_size) || (Sz[1] MOD type_size))
{
PyErr_SetString(PyExc_ValueError, "stride is not multiple of element size"); %(fail)s;
}
/*
encode the stride structure of _x,_y,_z into a single integer
*/
unit |= ((Sx[1] == type_size) ? 0x0 : (Sx[0] == type_size) ? 0x1 : 0x2) << 8;
unit |= ((Sy[1] == type_size) ? 0x0 : (Sy[0] == type_size) ? 0x1 : 0x2) << 4;
unit |= ((Sz[1] == type_size) ? 0x0 : (Sz[0] == type_size) ? 0x1 : 0x2) << 0;
/* create appropriate strides for malformed matrices that are row or column
* vectors
*/
sx_0 = (Nx[0] > 1) ? Sx[0]/type_size : Nx[1];
sx_1 = (Nx[1] > 1) ? Sx[1]/type_size : Nx[0];
sy_0 = (Ny[0] > 1) ? Sy[0]/type_size : Ny[1];
sy_1 = (Ny[1] > 1) ? Sy[1]/type_size : Ny[0];
sz_0 = (Nz[0] > 1) ? Sz[0]/type_size : Nz[1];
sz_1 = (Nz[1] > 1) ? Sz[1]/type_size : Nz[0];
switch (type_num)
{
case PyArray_FLOAT:
{
float a = 1.0;
float b = 0.0;
float* x = (float*)PyArray_DATA(%(_x)s);
float* y = (float*)PyArray_DATA(%(_y)s);
float* z = (float*)PyArray_DATA(%(_z)s);
char N = 'N';
char T = 'T';
int Nz0 = Nz[0], Nz1 = Nz[1], Nx1 = Nx[1];
//std::cerr << (unit/256) MOD 16 << (unit / 16) MOD 16 << unit MOD 16<< '\\n';
switch(unit)
{
case 0x000: sgemm_(&N, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_0, &b, z, &sz_0); break;
case 0x100: sgemm_(&N, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_1, &b, z, &sz_0); break;
case 0x010: sgemm_(&T, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_0, &b, z, &sz_0); break;
case 0x110: sgemm_(&T, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_1, &b, z, &sz_0); break;
case 0x001: sgemm_(&T, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_0, &b, z, &sz_1); break;
case 0x101: sgemm_(&N, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_0, &b, z, &sz_1); break;
case 0x011: sgemm_(&T, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_1, &b, z, &sz_1); break;
case 0x111: sgemm_(&N, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_1, &b, z, &sz_1); break;
default: PyErr_SetString(PyExc_ValueError, "some matrix has no unit stride"); %(fail)s;
};
#undef REAL
}
break;
case PyArray_DOUBLE:
{
double a = 1.0;
double b = 0.0;
double* x = (double*)PyArray_DATA(%(_x)s);
double* y = (double*)PyArray_DATA(%(_y)s);
double* z = (double*)PyArray_DATA(%(_z)s);
char N = 'N';
char T = 'T';
int Nz0 = Nz[0], Nz1 = Nz[1], Nx1 = Nx[1];
//std::cerr << (unit/256) MOD 16 << (unit / 16) MOD 16 << unit MOD 16<< '\\n';
switch(unit)
{
case 0x000: dgemm_(&N, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_0, &b, z, &sz_0); break;
case 0x100: dgemm_(&N, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_1, &b, z, &sz_0); break;
case 0x010: dgemm_(&T, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_0, &b, z, &sz_0); break;
case 0x110: dgemm_(&T, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_1, &b, z, &sz_0); break;
case 0x001: dgemm_(&T, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_0, &b, z, &sz_1); break;
case 0x101: dgemm_(&N, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_0, &b, z, &sz_1); break;
case 0x011: dgemm_(&T, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_1, &b, z, &sz_1); break;
case 0x111: dgemm_(&N, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_1, &b, z, &sz_1); break;
default: PyErr_SetString(PyExc_ValueError, "some matrix has no unit stride"); %(fail)s;
};
#undef REAL
}
break;
}
""" % dict(locals(), **sub)
_dot22 = _Dot22()
@gof.local_optimizer([T.dot])
def local_dot_to_dot22(node):
if node.op == T.dot:
return [_dot22(*node.inputs)]
else:
return False
register_specialize(local_dot_to_dot22)
@gof.local_optimizer([T.sub])
def local_sub_to_gemm(node):
"""This is a massive beast for recognizing all the ways that a subtraction could be
replaced by a GEMM
"""
if node.op == T.sub:
subleft, subright = node.inputs
#EXPRESSION: subleft - subright
if subright.owner and (subright.owner.op == _dot22):
dotleft, dotright = subright.owner.inputs
return [T.gemm(subleft, -1.0, dotleft, dotright, 1.0)]
if subright.owner and (subright.owner.op == T.mul):
mulleft, mulright = subright.owner.inputs
#EXPRESSION: subleft - (mulleft * mulright)
#TODO: we actually want to get any scalar here, not necessrily a constant
mulleft_const = local_mul_canonizer.get_constant(mulleft)
if mulleft_const is not None and mulleft_const.size == 1:
mulleft_const = mulleft_const.flatten()[0]
#EXPRESSION: subleft - (mulleft_const * ?)
if mulright.owner and (mulright.owner.op == T.add):
#EXPRESSION: subleft - (mulleft_const * (? + ?))
addleft, addright = mulright.owner.inputs
if addright.owner and addright.owner.op == T.DimShuffle([False,False], [1,0]):
#EXPRESSION: subleft - (mulleft_const * (? + ?.T))
raise NotImplementedError()
if addright.owner and addright.owner.op == T.DimShuffle([False,False], [1,0], inplace=True):
#EXPRESSION: subleft - (mulleft_const * (? + ?.T))
transposed = addright.owner.inputs[0]
if transposed.owner and transposed.owner.op == _dot22:
x, y = transposed.owner.inputs
#EXPRESSION: subleft - (mulleft_const * (addleft + dot(x, y).T))
if addleft.owner and addleft.owner.op == _dot22:
u, v = addleft.owner.inputs
#EXPRESSION: subleft - (mulleft_const * (dot(u,v) + dot(x, y).T))
return [T.gemm(
T.gemm(subleft, -mulleft_const, y.T, x.T, 1.0),
-mulleft_const, u, v, 1.0)]
if mulright.owner and (mulright.owner.op == _dot22):
dotleft, dotright = mulright.owner.inputs
#EXPRESSION: subleft - (mulleft_const * dot(dotleft, dotright))
return [T.gemm(subleft, -mulleft_const, dotleft, dotright, 1.0)]
mulright_const = local_mul_canonizer.get_constant(mulright)
if mulright_const is not None and mulright_const.size == 1:
mulright_const = mulright_const.flatten()[0]
#EXPRESSION: subleft - (? * mulright_const)
if mulleft.owner and (mulleft.owner.op == _dot22):
dotleft, dotright = mulleft.owner.inputs
#EXPRESSION: subleft - (dot(dotleft, dotright) * mulright_const)
return [T.gemm(subleft, -mulright_const, dotleft, dotright, 1.0)]
return False
register_specialize(local_sub_to_gemm)
inplace_matrix_transpose = T.DimShuffle([False,False], [1,0], inplace=True)
local_transposed_dot = gof.PatternSub((inplace_matrix_transpose, (T.dot, 'x', 'y')),
(T.dot, (inplace_matrix_transpose, 'y'), (inplace_matrix_transpose, 'x')))
register_canonicalize(local_transposed_dot, name='local_transposed_dot')
# def _math_optimizer():
......
......@@ -1356,179 +1356,6 @@ class t_dot(unittest.TestCase):
#verify_grad(self, dot, [self.rand(), self.rand(2)])
#verify_grad(self, dot, [self.rand(), self.rand(2,5)])
class t_gemm(unittest.TestCase):
def setUp(self):
numpy.random.seed(44)
_approx_eq.debug = 0
Gemm.debug = False
@staticmethod
def _gemm(z,a,x,y,b):
assert a.shape == ()
assert b.shape == ()
return b * z + a * numpy.dot(x,y)
@staticmethod
def rand(*args):
return numpy.random.rand(*args)
def cmp(self, z, a, x, y, b):
def cmp_linker(z, a, x, y, b, l):
z,a,x,y,b = [numpy.asarray(p) for p in z,a,x,y,b]
z_orig = z.copy()
tz,ta,tx,ty,tb = [as_tensor(p).type() for p in z,a,x,y,b]
f = function([tz,ta,tx,ty,tb], gemm(tz,ta,tx,ty,tb), mode=compile.Mode(optimizer = None, linker = l))
new_z = f(z,a,x,y,b)
z_after = self._gemm(z_orig, a, x, y, b)
self.failUnless(z is new_z)
#print z_orig, z_after, z, type(z_orig), type(z_after), type(z)
#_approx_eq.debug = 1
self.failUnless(_approx_eq(z_after, z))
if a == 0.0 and b == 1.0:
return
else:
self.failIf(numpy.all(z_orig == z))
cmp_linker(copy(z), a, x, y, b, 'c|py')
cmp_linker(copy(z), a, x, y, b, 'c')
cmp_linker(copy(z), a, x, y, b, 'py')
def test0a(self):
Gemm.debug = True
try:
g = gemm([1.], 1., [1.], [1.], 1.)
except ValueError, e:
if e[0] is Gemm.E_rank:
return
self.fail()
def test0(self):
try:
self.cmp(1., 0., 1.0, 1.0, 1.0)
except ValueError, e:
if e[0] is Gemm.E_rank:
return
self.fail()
def test2(self):
try:
self.cmp(2., 1.0, [3,2,1.], [[1],[2],[3.]], 1.0)
except ValueError, e:
self.failUnless(e[0] == Gemm.E_rank)
return
self.fail()
def test4(self):
self.cmp(self.rand(3,4), 1.0, self.rand(3,5), self.rand(5,4), 0.0)
def test5(self): self.cmp(self.rand(3,4), 1.0,
self.rand(3,5), self.rand(5,4), 1.0)
def test6(self): self.cmp(self.rand(3,4), 1.0,
self.rand(3,5), self.rand(5,4), -1.0)
def test7(self): self.cmp(self.rand(3,4), 0.0,
self.rand(3,5), self.rand(5,4), 0.0)
def test8(self): self.cmp(self.rand(3,4), 0.0,
self.rand(3,5), self.rand(5,4), 0.6)
def test9(self): self.cmp(self.rand(3,4), 0.0,
self.rand(3,5), self.rand(5,4), -1.0)
def test10(self):
_approx_eq.debug = 1
self.cmp(self.rand(3,4), -1.0, self.rand(3,5), self.rand(5,4), 0.0)
def test11(self): self.cmp(self.rand(3,4), -1.0,
self.rand(3,5), self.rand(5,4), 1.0)
def test12(self): self.cmp(self.rand(3,4), -1.0,
self.rand(3,5), self.rand(5,4), -1.0)
def test_destroy_map0(self):
"""test that only first input can be overwritten"""
Z = as_tensor(self.rand(2,2))
try:
gemm(Z, 1.0, Z, Z, 1.0)
except ValueError, e:
if e[0] == Gemm.E_z_uniq:
return
self.fail()
def test_destroy_map1(self):
"""test that only first input can be overwritten"""
Z = as_tensor(self.rand(2,2))
A = as_tensor(self.rand(2,2))
try:
gemm(Z, 1.0, A, inplace.transpose_inplace(Z), 1.0)
except ValueError, e:
if e[0] == Gemm.E_z_uniq:
return
self.fail()
def test_destroy_map2(self):
"""test that only first input can be overwritten"""
Z = as_tensor(self.rand(2,2))
A = as_tensor(self.rand(2,2))
try:
gemm(Z, 1.0, inplace.transpose_inplace(Z), A, 1.0)
except ValueError, e:
if e[0] == Gemm.E_z_uniq:
return
self.fail()
def test_destroy_map3(self):
"""test that only first input can be overwritten"""
Z = as_tensor(self.rand(2,2))
A = as_tensor(self.rand(2,2))
try:
gemm(Z, 1.0, Z, A, 1.0)
except ValueError, e:
if e[0] == Gemm.E_z_uniq:
return
self.fail()
def test_destroy_map4(self):
"""test that dot args can be aliased"""
Z = value(self.rand(2,2))
A = value(self.rand(2,2))
eval_outputs([gemm(Z, 1.0, A, A, 1.0)])
eval_outputs([gemm(Z, 1.0, A, A.T, 1.0)])
def test_transposes(self):
# three square matrices which are not contiguous
A = self.rand(4,5)[:,:4]
B = self.rand(4,5)[:,:4]
C = self.rand(4,5)[:,:4]
def t(z,x,y,a=1.0, b=0.0,l='c|py',dt='float64'):
z,a,x,y,b = [numpy.asarray(p,dtype=dt) for p in z,a,x,y,b]
z_orig = z.copy()
z_after = self._gemm(z, a, x, y, b)
tz,ta,tx,ty,tb = [value(p) for p in z,a,x,y,b]
f = function([tz,ta,tx,ty,tb], gemm(tz,ta,tx,ty,tb), mode = compile.Mode(optimizer = None, linker=l))
f(z, a, x, y, b)
self.failUnless(_approx_eq(z_after, z), (z_orig, z_after, z, z_after - z))
f(z.T, a, y.T, x.T, b)
self.failUnless(_approx_eq(z_after, z))
t(C,A,B)
t(C.T, A, B)
t(C, A.T, B, dt='float32')
t(C, A, B.T)
t(C.T, A.T, B)
t(C, A.T, B.T, dt='float32')
t(C.T, A, B.T)
t(C.T, A.T, B.T, dt='float32')
t(C, A[:,:2], B[:2, :])
t(C.T, A[:,:2], B[:2, :], dt='float32')
t(C, A[:2,:].T, B[:2, :])
t(C.T, A[:2,:].T, B[:2, :], dt='float32')
t(C, A[:2,:].T, B[:, :2].T)
t(C.T, A[:2,:].T, B[:, :2].T)
try:
t(C.T, A[:2,:], B[:, :2].T)
except ValueError, e:
if e[0].find('aligned') >= 0:
return
self.fail()
class T_tensorfromscalar(unittest.TestCase):
def test0(self):
s = scal.constant(56)
......
import theano.tensor as T
from ..gof import Env
import numpy
from theano.tensor.blas import *
from theano.tensor.blas import _as_scalar, _dot22
from unittest import TestCase
from copy import copy
from theano import In, Out
from .test_basic import (_approx_eq, as_tensor, function,
compile, value, constant, inplace, eval_outputs)
class t_gemm(TestCase):
"""This test suite is supposed to establish that gemm works as it is supposed to."""
def setUp(self):
numpy.random.seed(44)
_approx_eq.debug = 0
Gemm.debug = False
@staticmethod
def _gemm(z,a,x,y,b):
assert a.shape == ()
assert b.shape == ()
return b * z + a * numpy.dot(x,y)
@staticmethod
def rand(*args):
return numpy.random.rand(*args)
def cmp(self, z, a, x, y, b):
def cmp_linker(z, a, x, y, b, l):
z,a,x,y,b = [numpy.asarray(p) for p in z,a,x,y,b]
z_orig = z.copy()
tz,ta,tx,ty,tb = [as_tensor(p).type() for p in z,a,x,y,b]
f = function([tz,ta,tx,ty,tb], gemm(tz,ta,tx,ty,tb), mode=compile.Mode(optimizer = None, linker = l))
new_z = f(z,a,x,y,b)
z_after = self._gemm(z_orig, a, x, y, b)
self.failUnless(z is new_z)
#print z_orig, z_after, z, type(z_orig), type(z_after), type(z)
#_approx_eq.debug = 1
self.failUnless(_approx_eq(z_after, z))
if a == 0.0 and b == 1.0:
return
else:
self.failIf(numpy.all(z_orig == z))
cmp_linker(copy(z), a, x, y, b, 'c|py')
cmp_linker(copy(z), a, x, y, b, 'c')
cmp_linker(copy(z), a, x, y, b, 'py')
def test0a(self):
Gemm.debug = True
try:
g = gemm([1.], 1., [1.], [1.], 1.)
except ValueError, e:
if e[0] is Gemm.E_rank:
return
self.fail()
def test0(self):
try:
self.cmp(1., 0., 1.0, 1.0, 1.0)
except ValueError, e:
if e[0] is Gemm.E_rank:
return
self.fail()
def test2(self):
try:
self.cmp(2., 1.0, [3,2,1.], [[1],[2],[3.]], 1.0)
except ValueError, e:
self.failUnless(e[0] == Gemm.E_rank)
return
self.fail()
def test4(self):
self.cmp(self.rand(3,4), 1.0, self.rand(3,5), self.rand(5,4), 0.0)
def test5(self): self.cmp(self.rand(3,4), 1.0,
self.rand(3,5), self.rand(5,4), 1.0)
def test6(self): self.cmp(self.rand(3,4), 1.0,
self.rand(3,5), self.rand(5,4), -1.0)
def test7(self): self.cmp(self.rand(3,4), 0.0,
self.rand(3,5), self.rand(5,4), 0.0)
def test8(self): self.cmp(self.rand(3,4), 0.0,
self.rand(3,5), self.rand(5,4), 0.6)
def test9(self): self.cmp(self.rand(3,4), 0.0,
self.rand(3,5), self.rand(5,4), -1.0)
def test10(self):
_approx_eq.debug = 1
self.cmp(self.rand(3,4), -1.0, self.rand(3,5), self.rand(5,4), 0.0)
def test11(self): self.cmp(self.rand(3,4), -1.0,
self.rand(3,5), self.rand(5,4), 1.0)
def test12(self): self.cmp(self.rand(3,4), -1.0,
self.rand(3,5), self.rand(5,4), -1.0)
def test_destroy_map0(self):
"""test that only first input can be overwritten"""
Z = as_tensor(self.rand(2,2))
try:
gemm(Z, 1.0, Z, Z, 1.0)
except ValueError, e:
if e[0] == Gemm.E_z_uniq:
return
self.fail()
def test_destroy_map1(self):
"""test that only first input can be overwritten"""
Z = as_tensor(self.rand(2,2))
A = as_tensor(self.rand(2,2))
try:
gemm(Z, 1.0, A, inplace.transpose_inplace(Z), 1.0)
except ValueError, e:
if e[0] == Gemm.E_z_uniq:
return
self.fail()
def test_destroy_map2(self):
"""test that only first input can be overwritten"""
Z = as_tensor(self.rand(2,2))
A = as_tensor(self.rand(2,2))
try:
gemm(Z, 1.0, inplace.transpose_inplace(Z), A, 1.0)
except ValueError, e:
if e[0] == Gemm.E_z_uniq:
return
self.fail()
def test_destroy_map3(self):
"""test that only first input can be overwritten"""
Z = as_tensor(self.rand(2,2))
A = as_tensor(self.rand(2,2))
try:
gemm(Z, 1.0, Z, A, 1.0)
except ValueError, e:
if e[0] == Gemm.E_z_uniq:
return
self.fail()
def test_destroy_map4(self):
"""test that dot args can be aliased"""
Z = value(self.rand(2,2))
A = value(self.rand(2,2))
eval_outputs([gemm(Z, 1.0, A, A, 1.0)])
eval_outputs([gemm(Z, 1.0, A, A.T, 1.0)])
def test_transposes(self):
# three square matrices which are not contiguous
A = self.rand(4,5)[:,:4]
B = self.rand(4,5)[:,:4]
C = self.rand(4,5)[:,:4]
def t(z,x,y,a=1.0, b=0.0,l='c|py',dt='float64'):
z,a,x,y,b = [numpy.asarray(p,dtype=dt) for p in z,a,x,y,b]
z_orig = z.copy()
z_after = self._gemm(z, a, x, y, b)
tz,ta,tx,ty,tb = [value(p) for p in z,a,x,y,b]
f = function([tz,ta,tx,ty,tb], gemm(tz,ta,tx,ty,tb), mode = compile.Mode(optimizer = None, linker=l))
f(z, a, x, y, b)
self.failUnless(_approx_eq(z_after, z), (z_orig, z_after, z, z_after - z))
f(z.T, a, y.T, x.T, b)
self.failUnless(_approx_eq(z_after, z))
t(C,A,B)
t(C.T, A, B)
t(C, A.T, B, dt='float32')
t(C, A, B.T)
t(C.T, A.T, B)
t(C, A.T, B.T, dt='float32')
t(C.T, A, B.T)
t(C.T, A.T, B.T, dt='float32')
t(C, A[:,:2], B[:2, :])
t(C.T, A[:,:2], B[:2, :], dt='float32')
t(C, A[:2,:].T, B[:2, :])
t(C.T, A[:2,:].T, B[:2, :], dt='float32')
t(C, A[:2,:].T, B[:, :2].T)
t(C.T, A[:2,:].T, B[:, :2].T)
try:
t(C.T, A[:2,:], B[:, :2].T)
except ValueError, e:
if e[0].find('aligned') >= 0:
return
self.fail()
class t_as_scalar(TestCase):
def test0(self):
"""Test that it works on scalar constants"""
a = T.constant(2.5)
b = T.constant(numpy.asarray([[[0.5]]]))
d_a = T.DimShuffle([], [])(a)
d_b = T.DimShuffle([True, True, True], [0,2,1])(b)
d_a2 = T.DimShuffle([], ['x', 'x', 'x'])(a)
self.failUnless(numpy.all(_as_scalar(a) == a))
self.failUnless(numpy.all(_as_scalar(b) == b.data), (b, _as_scalar(b)))
self.failUnless(numpy.all(_as_scalar(d_a) == a))
self.failUnless(numpy.all(_as_scalar(d_b) == b.data))
self.failUnless(numpy.all(_as_scalar(d_a2) == a))
def test1(self):
"""Test that it fails on nonscalar constants"""
a = T.constant(numpy.ones(5))
self.failUnless(None == _as_scalar(a))
self.failUnless(None == _as_scalar(T.DimShuffle([False], [0,'x'])(a)))
def test2(self):
"""Test that it works on scalar variables"""
a = T.dscalar()
d_a = T.DimShuffle([], [])(a)
d_a2 = T.DimShuffle([], ['x', 'x'])(a)
self.failUnless(_as_scalar(a) is a)
self.failUnless(_as_scalar(d_a) is a)
self.failUnless(_as_scalar(d_a2) is a)
def test3(self):
"""Test that it fails on nonscalar variables"""
a = T.dmatrix()
self.failUnless(None == _as_scalar(a))
self.failUnless(None == _as_scalar(T.DimShuffle([False, False], [0,'x', 1])(a)))
class T_gemm_opt(TestCase):
"""This test suite ensures that Gemm is inserted where it belongs, and that the resulting
functions compute the same things as the originals."""
def XYZab(self):
return T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar()
def just_gemm(self, i, o, ishapes = [(4,3), (3,5), (4,5), (), ()]):
def on_fail():
for node in f.maker.env.toposort():
print 'GRAPH', node
self.fail()
f = function([In(ii, mutable=True) for ii in i],o, mode='FAST_RUN')
for node in f.maker.env.nodes:
if node.op == T.dot: on_fail()
if node.op == _dot22: on_fail()
g = function(i, o, mode='FAST_COMPILE')
for node in g.maker.env.nodes:
if node.op == gemm: on_fail()
rng = numpy.random.RandomState(234)
r0 = f(*[rng.randn(*sh) for sh in ishapes])
rng = numpy.random.RandomState(234)
r1 = g(*[rng.randn(*sh) for sh in ishapes])
if numpy.max(numpy.abs(r0[0] - r1[0])) > 1.0e-8:
self.fail()
def test0(self):
"""Many subgraphs whose dots can be eliminated"""
X,Y,Z,a,b = self.XYZab()
self.just_gemm([X,Y,Z,a,b], [T.dot(X,Y) * a + Z * b])
self.just_gemm([X,Y,Z,a,b], [a * T.dot(X,Y) + b * Z])
self.just_gemm([X,Y,Z,a,b], [b * Z + a * T.dot(X,Y)])
self.just_gemm([X,Y,Z,a,b], [T.dot(X,Y) * a - Z * b])
self.just_gemm([X,Y,Z,a,b], [a * T.dot(X,Y) - b * Z])
self.just_gemm([X,Y,Z,a,b], [b * Z - a * T.dot(X,Y)])
#with transposes (transposes should be pushed through dot in canonicalize)
self.just_gemm([X,Y,Z,a,b], [b * Z.T - a * T.dot(Y.T,X.T)])
self.just_gemm([X,Y,Z,a,b], [b * Z.T + a * b * T.dot(X,Y).T])
#with N multiplications instead of just one
self.just_gemm([X,Y,Z,a,b], [(b * b) * Z * a + (a * a) * T.dot(X,Y) * b])
self.just_gemm([X,Y,Z,a,b], [Z + T.dot(X,Y)])
self.just_gemm([X,Y,Z,a,b], [Z*b + T.dot(X,Y)])
self.just_gemm([X,Y,Z,a,b], [Z + a*b*a*T.dot(X,Y)])
self.just_gemm([X,Y,Z,a,b], [(b * b) * Z * a - (a * a) * T.dot(X,Y) * b])
self.just_gemm([X,Y,Z,a,b], [Z - T.dot(X,Y)])
self.just_gemm([X,Y,Z,a,b], [Z*b - T.dot(X,Y)])
self.just_gemm([X,Y,Z,a,b], [Z - a*b*a*T.dot(X,Y)])
# with > 2 terms in the overall addition
self.just_gemm([X,Y,Z,a,b], [Z + Z + T.dot(X,Y) + Z])
def test_double_gemm(self):
"""This is the pattern that shows up in the autoencoder"""
X,Y,Z,a,b = T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar()
R, S, c = T.dmatrix(), T.dmatrix(), T.dscalar()
self.just_gemm([X,Y,Z,a,b, R, S, c], [Z *c + a * T.dot(X,Y) + b * T.dot(R,S).T],
ishapes=[(4,3), (3,5), (4,5), (), (), (5,9), (9,4), ()])
def wishlist(self):
X,Y,Z,a,b = T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar()
#with >2 additions of the same T.dot(X,Y term
self.just_gemm([X,Y,Z,a,b], [Z + T.dot(X,Y) + T.dot(X,Y)])
self.just_gemm([X,Y,Z,a,b], [(b * b) * Z * a + (a * a) * T.dot(X,Y) + b * T.dot(X,Y)])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论