Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
215743e7
提交
215743e7
authored
2月 04, 2008
作者:
bergstrj@iro.umontreal.ca
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
moved blas stuff back to blas.py, moved cblas.h into blas.py, removed need for OMEGA_CBLAS_H
上级
0c572c25
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
768 行增加
和
344 行删除
+768
-344
blas.py
blas.py
+755
-128
core.py
core.py
+13
-216
没有找到文件。
blas.py
浏览文件 @
215743e7
import
numpy
import
core
import
os
,
sys
from
gof
import
PatternOptimizer
as
pattern_opt
,
OpSubOptimizer
as
op_sub
import
scipy.weave
as
weave
...
...
@@ -12,141 +11,769 @@ by this file are aimed at the goal of inserting gemm Ops in place of more
fine-grained motifs of iadd, isub, scale, and dot.
"""
_general_gemm_code
=
"""
def
cblas_header_text
():
return
"""
//#include <stddef.h>
static int mat_gemm_general(double a, const mat_t &A, const mat_t &B, double b, mat_t &C)
{
fprintf(stderr, "INFO: running mat_gemm_general
\\
n");
for (size_t i = 0; i < C.M; ++i)
#undef __BEGIN_DECLS
#undef __END_DECLS
#ifdef __cplusplus
#define __BEGIN_DECLS extern "C" {
#define __END_DECLS }
#else
#define __BEGIN_DECLS /* empty */
#define __END_DECLS /* empty */
#endif
__BEGIN_DECLS
/*
* Enumerated and derived types
*/
#define CBLAS_INDEX size_t /* this may vary between platforms */
enum CBLAS_ORDER {CblasRowMajor=101, CblasColMajor=102};
enum CBLAS_TRANSPOSE {CblasNoTrans=111, CblasTrans=112, CblasConjTrans=113};
enum CBLAS_UPLO {CblasUpper=121, CblasLower=122};
enum CBLAS_DIAG {CblasNonUnit=131, CblasUnit=132};
enum CBLAS_SIDE {CblasLeft=141, CblasRight=142};
float cblas_sdsdot(const int N, const float alpha, const float *X,
const int incX, const float *Y, const int incY);
double cblas_dsdot(const int N, const float *X, const int incX, const float *Y,
const int incY);
float cblas_sdot(const int N, const float *X, const int incX,
const float *Y, const int incY);
double cblas_ddot(const int N, const double *X, const int incX,
const double *Y, const int incY);
/*
* Functions having prefixes Z and C only
*/
void cblas_cdotu_sub(const int N, const void *X, const int incX,
const void *Y, const int incY, void *dotu);
void cblas_cdotc_sub(const int N, const void *X, const int incX,
const void *Y, const int incY, void *dotc);
void cblas_zdotu_sub(const int N, const void *X, const int incX,
const void *Y, const int incY, void *dotu);
void cblas_zdotc_sub(const int N, const void *X, const int incX,
const void *Y, const int incY, void *dotc);
/*
* Functions having prefixes S D SC DZ
*/
float cblas_snrm2(const int N, const float *X, const int incX);
float cblas_sasum(const int N, const float *X, const int incX);
double cblas_dnrm2(const int N, const double *X, const int incX);
double cblas_dasum(const int N, const double *X, const int incX);
float cblas_scnrm2(const int N, const void *X, const int incX);
float cblas_scasum(const int N, const void *X, const int incX);
double cblas_dznrm2(const int N, const void *X, const int incX);
double cblas_dzasum(const int N, const void *X, const int incX);
/*
* Functions having standard 4 prefixes (S D C Z)
*/
CBLAS_INDEX cblas_isamax(const int N, const float *X, const int incX);
CBLAS_INDEX cblas_idamax(const int N, const double *X, const int incX);
CBLAS_INDEX cblas_icamax(const int N, const void *X, const int incX);
CBLAS_INDEX cblas_izamax(const int N, const void *X, const int incX);
/*
* ===========================================================================
* Prototypes for level 1 BLAS routines
* ===========================================================================
*/
/*
* Routines with standard 4 prefixes (s, d, c, z)
*/
void cblas_sswap(const int N, float *X, const int incX,
float *Y, const int incY);
void cblas_scopy(const int N, const float *X, const int incX,
float *Y, const int incY);
void cblas_saxpy(const int N, const float alpha, const float *X,
const int incX, float *Y, const int incY);
void cblas_dswap(const int N, double *X, const int incX,
double *Y, const int incY);
void cblas_dcopy(const int N, const double *X, const int incX,
double *Y, const int incY);
void cblas_daxpy(const int N, const double alpha, const double *X,
const int incX, double *Y, const int incY);
void cblas_cswap(const int N, void *X, const int incX,
void *Y, const int incY);
void cblas_ccopy(const int N, const void *X, const int incX,
void *Y, const int incY);
void cblas_caxpy(const int N, const void *alpha, const void *X,
const int incX, void *Y, const int incY);
void cblas_zswap(const int N, void *X, const int incX,
void *Y, const int incY);
void cblas_zcopy(const int N, const void *X, const int incX,
void *Y, const int incY);
void cblas_zaxpy(const int N, const void *alpha, const void *X,
const int incX, void *Y, const int incY);
/*
* Routines with S and D prefix only
*/
void cblas_srotg(float *a, float *b, float *c, float *s);
void cblas_srotmg(float *d1, float *d2, float *b1, const float b2, float *P);
void cblas_srot(const int N, float *X, const int incX,
float *Y, const int incY, const float c, const float s);
void cblas_srotm(const int N, float *X, const int incX,
float *Y, const int incY, const float *P);
void cblas_drotg(double *a, double *b, double *c, double *s);
void cblas_drotmg(double *d1, double *d2, double *b1, const double b2, double *P);
void cblas_drot(const int N, double *X, const int incX,
double *Y, const int incY, const double c, const double s);
void cblas_drotm(const int N, double *X, const int incX,
double *Y, const int incY, const double *P);
/*
* Routines with S D C Z CS and ZD prefixes
*/
void cblas_sscal(const int N, const float alpha, float *X, const int incX);
void cblas_dscal(const int N, const double alpha, double *X, const int incX);
void cblas_cscal(const int N, const void *alpha, void *X, const int incX);
void cblas_zscal(const int N, const void *alpha, void *X, const int incX);
void cblas_csscal(const int N, const float alpha, void *X, const int incX);
void cblas_zdscal(const int N, const double alpha, void *X, const int incX);
/*
* ===========================================================================
* Prototypes for level 2 BLAS
* ===========================================================================
*/
/*
* Routines with standard 4 prefixes (S, D, C, Z)
*/
void cblas_sgemv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const float alpha, const float *A, const int lda,
const float *X, const int incX, const float beta,
float *Y, const int incY);
void cblas_sgbmv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const int KL, const int KU, const float alpha,
const float *A, const int lda, const float *X,
const int incX, const float beta, float *Y, const int incY);
void cblas_strmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const float *A, const int lda,
float *X, const int incX);
void cblas_stbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const float *A, const int lda,
float *X, const int incX);
void cblas_stpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const float *Ap, float *X, const int incX);
void cblas_strsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const float *A, const int lda, float *X,
const int incX);
void cblas_stbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const float *A, const int lda,
float *X, const int incX);
void cblas_stpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const float *Ap, float *X, const int incX);
void cblas_dgemv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const double alpha, const double *A, const int lda,
const double *X, const int incX, const double beta,
double *Y, const int incY);
void cblas_dgbmv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const int KL, const int KU, const double alpha,
const double *A, const int lda, const double *X,
const int incX, const double beta, double *Y, const int incY);
void cblas_dtrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const double *A, const int lda,
double *X, const int incX);
void cblas_dtbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const double *A, const int lda,
double *X, const int incX);
void cblas_dtpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const double *Ap, double *X, const int incX);
void cblas_dtrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const double *A, const int lda, double *X,
const int incX);
void cblas_dtbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const double *A, const int lda,
double *X, const int incX);
void cblas_dtpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const double *Ap, double *X, const int incX);
void cblas_cgemv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *X, const int incX, const void *beta,
void *Y, const int incY);
void cblas_cgbmv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const int KL, const int KU, const void *alpha,
const void *A, const int lda, const void *X,
const int incX, const void *beta, void *Y, const int incY);
void cblas_ctrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *A, const int lda,
void *X, const int incX);
void cblas_ctbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const void *A, const int lda,
void *X, const int incX);
void cblas_ctpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *Ap, void *X, const int incX);
void cblas_ctrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *A, const int lda, void *X,
const int incX);
void cblas_ctbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const void *A, const int lda,
void *X, const int incX);
void cblas_ctpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *Ap, void *X, const int incX);
void cblas_zgemv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *X, const int incX, const void *beta,
void *Y, const int incY);
void cblas_zgbmv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const int KL, const int KU, const void *alpha,
const void *A, const int lda, const void *X,
const int incX, const void *beta, void *Y, const int incY);
void cblas_ztrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *A, const int lda,
void *X, const int incX);
void cblas_ztbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const void *A, const int lda,
void *X, const int incX);
void cblas_ztpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *Ap, void *X, const int incX);
void cblas_ztrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *A, const int lda, void *X,
const int incX);
void cblas_ztbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const void *A, const int lda,
void *X, const int incX);
void cblas_ztpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *Ap, void *X, const int incX);
/*
* Routines with S and D prefixes only
*/
void cblas_ssymv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *A,
const int lda, const float *X, const int incX,
const float beta, float *Y, const int incY);
void cblas_ssbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const int K, const float alpha, const float *A,
const int lda, const float *X, const int incX,
const float beta, float *Y, const int incY);
void cblas_sspmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *Ap,
const float *X, const int incX,
const float beta, float *Y, const int incY);
void cblas_sger(const enum CBLAS_ORDER order, const int M, const int N,
const float alpha, const float *X, const int incX,
const float *Y, const int incY, float *A, const int lda);
void cblas_ssyr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *X,
const int incX, float *A, const int lda);
void cblas_sspr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *X,
const int incX, float *Ap);
void cblas_ssyr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *X,
const int incX, const float *Y, const int incY, float *A,
const int lda);
void cblas_sspr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *X,
const int incX, const float *Y, const int incY, float *A);
void cblas_dsymv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *A,
const int lda, const double *X, const int incX,
const double beta, double *Y, const int incY);
void cblas_dsbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const int K, const double alpha, const double *A,
const int lda, const double *X, const int incX,
const double beta, double *Y, const int incY);
void cblas_dspmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *Ap,
const double *X, const int incX,
const double beta, double *Y, const int incY);
void cblas_dger(const enum CBLAS_ORDER order, const int M, const int N,
const double alpha, const double *X, const int incX,
const double *Y, const int incY, double *A, const int lda);
void cblas_dsyr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *X,
const int incX, double *A, const int lda);
void cblas_dspr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *X,
const int incX, double *Ap);
void cblas_dsyr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *X,
const int incX, const double *Y, const int incY, double *A,
const int lda);
void cblas_dspr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *X,
const int incX, const double *Y, const int incY, double *A);
/*
* Routines with C and Z prefixes only
*/
void cblas_chemv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const void *alpha, const void *A,
const int lda, const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_chbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const int K, const void *alpha, const void *A,
const int lda, const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_chpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const void *alpha, const void *Ap,
const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_cgeru(const enum CBLAS_ORDER order, const int M, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_cgerc(const enum CBLAS_ORDER order, const int M, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_cher(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const void *X, const int incX,
void *A, const int lda);
void cblas_chpr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const void *X,
const int incX, void *A);
void cblas_cher2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_chpr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *Ap);
void cblas_zhemv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const void *alpha, const void *A,
const int lda, const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_zhbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const int K, const void *alpha, const void *A,
const int lda, const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_zhpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const void *alpha, const void *Ap,
const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_zgeru(const enum CBLAS_ORDER order, const int M, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_zgerc(const enum CBLAS_ORDER order, const int M, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_zher(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const void *X, const int incX,
void *A, const int lda);
void cblas_zhpr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const void *X,
const int incX, void *A);
void cblas_zher2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_zhpr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *Ap);
/*
* ===========================================================================
* Prototypes for level 3 BLAS
* ===========================================================================
*/
/*
* Routines with standard 4 prefixes (S, D, C, Z)
*/
void cblas_sgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const float alpha, const float *A,
const int lda, const float *B, const int ldb,
const float beta, float *C, const int ldc);
void cblas_ssymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const float alpha, const float *A, const int lda,
const float *B, const int ldb, const float beta,
float *C, const int ldc);
void cblas_ssyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const float alpha, const float *A, const int lda,
const float beta, float *C, const int ldc);
void cblas_ssyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const float alpha, const float *A, const int lda,
const float *B, const int ldb, const float beta,
float *C, const int ldc);
void cblas_strmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const float alpha, const float *A, const int lda,
float *B, const int ldb);
void cblas_strsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const float alpha, const float *A, const int lda,
float *B, const int ldb);
void cblas_dgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const double alpha, const double *A,
const int lda, const double *B, const int ldb,
const double beta, double *C, const int ldc);
void cblas_dsymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const double alpha, const double *A, const int lda,
const double *B, const int ldb, const double beta,
double *C, const int ldc);
void cblas_dsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const double alpha, const double *A, const int lda,
const double beta, double *C, const int ldc);
void cblas_dsyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const double alpha, const double *A, const int lda,
const double *B, const int ldb, const double beta,
double *C, const int ldc);
void cblas_dtrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const double alpha, const double *A, const int lda,
double *B, const int ldb);
void cblas_dtrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const double alpha, const double *A, const int lda,
double *B, const int ldb);
void cblas_cgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const void *alpha, const void *A,
const int lda, const void *B, const int ldb,
const void *beta, void *C, const int ldc);
void cblas_csymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_csyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *beta, void *C, const int ldc);
void cblas_csyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_ctrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const void *alpha, const void *A, const int lda,
void *B, const int ldb);
void cblas_ctrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const void *alpha, const void *A, const int lda,
void *B, const int ldb);
void cblas_zgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const void *alpha, const void *A,
const int lda, const void *B, const int ldb,
const void *beta, void *C, const int ldc);
void cblas_zsymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_zsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *beta, void *C, const int ldc);
void cblas_zsyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_ztrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const void *alpha, const void *A, const int lda,
void *B, const int ldb);
void cblas_ztrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const void *alpha, const void *A, const int lda,
void *B, const int ldb);
/*
* Routines with prefixes C and Z only
*/
void cblas_chemm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_cherk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const float alpha, const void *A, const int lda,
const float beta, void *C, const int ldc);
void cblas_cher2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const float beta,
void *C, const int ldc);
void cblas_zhemm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_zherk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const double alpha, const void *A, const int lda,
const double beta, void *C, const int ldc);
void cblas_zher2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const double beta,
void *C, const int ldc);
void cblas_xerbla(int p, const char *rout, const char *form, ...);
__END_DECLS
"""
def
_constant
(
f
):
"""Return a function that always returns its first call value
"""
def
rval
(
*
args
,
**
kwargs
):
if
not
hasattr
(
f
,
'rval'
):
f
.
rval
=
f
(
*
args
,
**
kwargs
)
return
f
.
rval
return
rval
@_constant
def
ldflags
():
"""Return a list of libraries against which an Op's object file should be
linked to benefit from a BLAS implementation.
Default: ['cblas','blas'], but environment variable OMEGA_BLAS_LDFLAGS overrides this.
"""
if
os
.
getenv
(
'OMEGA_BLAS_LDFLAGS'
):
return
os
.
getenv
(
'OMEGA_BLAS_LDFLAGS'
)
.
split
()
else
:
return
[
'cblas'
,
'blas'
]
def
gemm_code
(
check_ab
,
a_init
,
b_init
):
mod
=
'
%
'
return
"""
const char * error_string = NULL;
int type_num = _x->descr->type_num;
int type_size = _x->descr->elsize; // in bytes
npy_intp* Nx = _x->dimensions;
npy_intp* Ny = _y->dimensions;
npy_intp* Nz = _z->dimensions;
npy_intp* Sx = _x->strides;
npy_intp* Sy = _y->strides;
npy_intp* Sz = _z->strides;
size_t sx_0, sx_1, sy_0, sy_1, sz_0, sz_1;
int unit = 0;
if (_x->nd != 2) goto _dot_execute_fallback;
if (_y->nd != 2) goto _dot_execute_fallback;
if (_z->nd != 2) goto _dot_execute_fallback;
%(check_ab)
s
if ((_x->descr->type_num != PyArray_DOUBLE)
&& (_x->descr->type_num != PyArray_FLOAT))
goto _dot_execute_fallback;
if ((_y->descr->type_num != PyArray_DOUBLE)
&& (_y->descr->type_num != PyArray_FLOAT))
goto _dot_execute_fallback;
if ((_y->descr->type_num != PyArray_DOUBLE)
&& (_y->descr->type_num != PyArray_FLOAT))
goto _dot_execute_fallback;
if ((_x->descr->type_num != _y->descr->type_num)
||(_x->descr->type_num != _z->descr->type_num))
goto _dot_execute_fallback;
if ((Nx[0] != Nz[0]) || (Nx[1] != Ny[0]) || (Ny[1] != Nz[1]))
{
for (size_t j = 0; j < C.N; ++j)
error_string = "Input dimensions do not agree";
goto _dot_execute_fail;
}
if ((Sx[0] < 1) || (Sx[1] < 1) || (Sx[0]
%(mod)
s type_size) || (Sx[1]
%(mod)
s type_size)
|| (Sy[0] < 1) || (Sy[1] < 1) || (Sy[0]
%(mod)
s type_size) || (Sy[1]
%(mod)
s type_size)
|| (Sz[0] < 1) || (Sz[1] < 1) || (Sz[0]
%(mod)
s type_size) || (Sz[1]
%(mod)
s type_size))
{
goto _dot_execute_fallback;
}
/*
encode the stride structure of _x,_y,_z into a single integer
*/
unit |= ((Sx[1] == type_size) ? 0x0 : (Sx[0] == type_size) ? 0x1 : 0x2) << 0;
unit |= ((Sy[1] == type_size) ? 0x0 : (Sy[0] == type_size) ? 0x1 : 0x2) << 4;
unit |= ((Sz[1] == type_size) ? 0x0 : (Sz[0] == type_size) ? 0x1 : 0x2) << 8;
/* create appropriate strides for malformed matrices that are row or column
* vectors
*/
sx_0 = (Nx[0] > 1) ? Sx[0]/type_size : Nx[1];
sx_1 = (Nx[1] > 1) ? Sx[1]/type_size : Nx[0];
sy_0 = (Ny[0] > 1) ? Sy[0]/type_size : Ny[1];
sy_1 = (Ny[1] > 1) ? Sy[1]/type_size : Ny[0];
sz_0 = (Nz[0] > 1) ? Sz[0]/type_size : Nz[1];
sz_1 = (Nz[1] > 1) ? Sz[1]/type_size : Nz[0];
switch (type_num)
{
case PyArray_FLOAT:
{
C(i,j) *= b;
C(i,j) += a * cblas_ddot(A.N, &A(i,0), A.n, &B(0,j), B.m );
#define REAL float
float a =
%(a_init)
s;
float b =
%(b_init)
s;
float* x = (float*)PyArray_DATA(_x);
float* y = (float*)PyArray_DATA(_y);
float* z = (float*)PyArray_DATA(_z);
switch(unit)
{
case 0x000: cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_0); break;
case 0x001: cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_0); break;
case 0x010: cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_0); break;
case 0x011: cblas_sgemm(CblasRowMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_0); break;
case 0x100: cblas_sgemm(CblasColMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_1); break;
case 0x101: cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_1); break;
case 0x110: cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_1); break;
case 0x111: cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_1); break;
default: goto _dot_execute_fallback;
};
#undef REAL
}
break;
case PyArray_DOUBLE:
{
#define REAL double
double a =
%(a_init)
s;
double b =
%(b_init)
s;
double* x = (double*)PyArray_DATA(_x);
double* y = (double*)PyArray_DATA(_y);
double* z = (double*)PyArray_DATA(_z);
switch(unit)
{
case 0x000: cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_0); break;
case 0x001: cblas_dgemm(CblasRowMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_0); break;
case 0x010: cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_0); break;
case 0x011: cblas_dgemm(CblasRowMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_0); break;
case 0x100: cblas_dgemm(CblasColMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_1); break;
case 0x101: cblas_dgemm(CblasColMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_1); break;
case 0x110: cblas_dgemm(CblasColMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_1); break;
case 0x111: cblas_dgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_1); break;
default: goto _dot_execute_fallback;
};
#undef REAL
}
break;
}
return 0;
}
"""
_gemm_code_template
=
"""
if ((Nx[0] != Nz[0]) || (Nx[1] != Ny[0]) || (Ny[1] != Nz[1]))
{
PyErr_SetString(PyExc_ValueError, "mat_gemm input array size mismatch");
fprintf(stderr, "Should be calling mat_gemm_general, but quitting instead
\\
n");
exit(1);
}
if ((Sx[0] < 1) || (Sx[1] < 1)
|| (Sy[0] < 1) || (Sy[1] < 1)
|| (Sz[0] < 1) || (Sz[1] < 1))
{
fprintf(stderr, "Should be calling mat_gemm_general, but quitting instead
\\
n");
exit(1);
//return mat_gemm_general(a, A, B, b, C);
}
return 0; //success!
//TODO: OPTIMIZE for many special cases:
//- gemv
//- ger
//- ddot
//- others?
_dot_execute_fallback:
PyErr_SetString(PyExc_NotImplementedError,
"dot->execute() fallback");
return -1;
int unit = 0;
unit |= ((Sx[1] == sizeof(
%(dtype)
s)) ? 0x0 : (Sx[0] == sizeof(
%(dtype)
s)) ? 0x1 : 0x2) << 0;
unit |= ((Sy[1] == sizeof(
%(dtype)
s)) ? 0x0 : (Sy[0] == sizeof(
%(dtype)
s)) ? 0x1 : 0x2) << 4;
unit |= ((Sz[1] == sizeof(
%(dtype)
s)) ? 0x0 : (Sz[0] == sizeof(
%(dtype)
s)) ? 0x1 : 0x2) << 8;
_dot_execute_fail:
if (error_string == NULL)
PyErr_SetString(PyExc_ValueError,
"dot->execute() cant run on these inputs");
return -1;
/* create appropriate strides for malformed matrices that are row or column
* vectors
*/
size_t sx_0 = (Nx[0] > 1) ? Sx[0]/sizeof(
%(dtype)
s) : Nx[1];
size_t sx_1 = (Nx[1] > 1) ? Sx[1]/sizeof(
%(dtype)
s) : Nx[0];
size_t sy_0 = (Ny[0] > 1) ? Sy[0]/sizeof(
%(dtype)
s) : Ny[1];
size_t sy_1 = (Ny[1] > 1) ? Sy[1]/sizeof(
%(dtype)
s) : Ny[0];
size_t sz_0 = (Nz[0] > 1) ? Sz[0]/sizeof(
%(dtype)
s) : Nz[1];
size_t sz_1 = (Nz[1] > 1) ? Sz[1]/sizeof(
%(dtype)
s) : Nz[0];
switch(unit)
/* v 1 */
"""
%
locals
()
# currently unused, preferring the fallback method (throwing
# NotImplementedError) for when gemm won't work.
_templated_memaligned_gemm
=
"""
template <typename Ta, typename Tx, typename Ty, typename Tb, typename Tz>
int general_gemm(int zM, int zN, int xN,.
Ta a,
Tx * x, int xm, int xn,
Tx * y, int ym, int yn,
Tb b,
Tz * z, int zm, int zn)
{
for (int i = 0; i < zM; ++i)
{
case 0x000:
%(gemm)
s(CblasRowMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a[0], x, sx_0, y, sy_0, b[0], z, sz_0); break;
case 0x001:
%(gemm)
s(CblasRowMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a[0], x, sx_1, y, sy_0, b[0], z, sz_0); break;
case 0x010:
%(gemm)
s(CblasRowMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a[0], x, sx_0, y, sy_1, b[0], z, sz_0); break;
case 0x011:
%(gemm)
s(CblasRowMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a[0], x, sx_1, y, sy_1, b[0], z, sz_0); break;
case 0x100:
%(gemm)
s(CblasColMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a[0], x, sx_0, y, sy_0, b[0], z, sz_1); break;
case 0x101:
%(gemm)
s(CblasColMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a[0], x, sx_1, y, sy_0, b[0], z, sz_1); break;
case 0x110:
%(gemm)
s(CblasColMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a[0], x, sx_0, y, sy_1, b[0], z, sz_1); break;
case 0x111:
%(gemm)
s(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a[0], x, sx_1, y, sy_1, b[0], z, sz_1); break;
default:
fprintf(stderr, "Should be calling mat_gemm_general, but quitting instead
\\
n");
exit(1);
};
/* v 1 */
for (int j = 0; j < zN; ++j)
{
Tz zij = 0.0;
for (int k = 0; k < xN; ++k)
{
zij += x[i*xm+k*xn] * y[k*ym+j*yn];
}
z[i * zm + j * zn] *= b;
z[i * zm + j * zn] += a * zij;
}
}
}
"""
_gemm_code
=
{
'f'
:
_gemm_code_template
%
{
'gemm'
:
'cblas_sgemm'
,
'dtype'
:
'float'
},
'd'
:
_gemm_code_template
%
{
'gemm'
:
'cblas_dgemm'
,
'dtype'
:
'double'
}}
def
_gemm_rank2
(
a
,
x
,
y
,
b
,
z
):
weave
.
inline
(
_gemm_code
[
z
.
dtype
.
char
],
[
'a'
,
'x'
,
'y'
,
'b'
,
'z'
],
headers
=
[
'"/home/bergstra/cvs/lgcm/omega/cblas.h"'
],
libraries
=
[
'mkl'
,
'm'
])
def
_gemm
(
a
,
x
,
y
,
b
,
z
):
if
len
(
x
.
shape
)
==
2
and
len
(
y
.
shape
)
==
2
:
_gemm_rank2
(
a
,
x
,
y
,
b
,
z
)
else
:
if
b
==
0.0
:
if
a
==
1.0
:
z
[:]
=
numpy
.
dot
(
x
,
y
)
elif
a
==
-
1.0
:
z
[:]
=
-
numpy
.
dot
(
x
,
y
)
else
:
z
[:]
=
a
*
numpy
.
dot
(
x
,
y
)
elif
b
==
1.0
:
if
a
==
1.0
:
z
+=
numpy
.
dot
(
x
,
y
)
elif
a
==
-
1.0
:
z
-=
numpy
.
dot
(
x
,
y
)
else
:
z
+=
a
*
numpy
.
dot
(
x
,
y
)
else
:
z
*=
b
z
+=
a
*
numpy
.
dot
(
x
,
y
)
_gdot_coefs
=
{
'f'
:
(
numpy
.
ones
((),
dtype
=
'float32'
),
numpy
.
zeros
((),
dtype
=
'float32'
)),
'd'
:
(
numpy
.
ones
(()),
numpy
.
zeros
(()))}
def
_gdot
(
x
,
y
):
a
,
b
=
_gdot_coefs
[
x
.
dtype
.
char
]
z
=
numpy
.
ndarray
((
x
.
shape
[
0
],
y
.
shape
[
1
]),
dtype
=
x
.
dtype
)
_gemm
(
a
,
x
,
y
,
b
,
z
)
return
z
class
gemm
(
core
.
omega_op
,
core
.
inplace
):
def
impl
(
z
,
a
,
x
,
y
,
b
):
_gemm
(
a
,
x
,
y
,
b
,
z
)
return
z
[:]
def
grad
(
x
,
gz
):
raise
NotImplemented
class
gdot
(
core
.
omega_op
):
impl
=
_gdot
def
grad
(
x
,
gz
):
raise
NotImplemented
#TODO: put more optimizations in here (Trac # 18)
optimizations
=
[
pattern_opt
(
(
core
.
isub_elemwise
,
'z'
,
(
core
.
dot
,
'x'
,
'y'
)),
(
gemm
,
'z'
,
-
1.0
,
'x'
,
'y'
,
1.0
)),
pattern_opt
(
(
core
.
dot
,
'x'
,
'y'
),
(
gdot
,
'x'
,
'y'
))
]
core.py
浏览文件 @
215743e7
...
...
@@ -7,14 +7,16 @@ import inspect
import
md5
import
copy
import
numpy
from
scipy
import
weave
import
gof
from
gof
import
current_mode
,
set_mode
,
build_mode
,
eval_mode
,
build_eval_mode
,
pop_mode
,
UNCOMPUTED
,
UNDEFINED
,
PythonR
import
type_spec
import
cutils
import
blas
import
numpy
from
scipy
import
weave
# __all__ = ['set_mode', 'get_mode', 'NumpyR', 'NumpyOp']
...
...
@@ -44,40 +46,7 @@ literals_id_db = {}
# see TRAC(#31)
default_input_scalar_dtype
=
'float64'
def
_constant
(
f
):
"""Return a function that always returns its first call value
"""
def
rval
(
*
args
,
**
kwargs
):
if
not
hasattr
(
f
,
'rval'
):
f
.
rval
=
f
(
*
args
,
**
kwargs
)
return
f
.
rval
return
rval
@_constant
def
_blas_headers
():
"""Return a list of strings which should be #include-ed into C files.
Default: [''], but environment variable OMEGA_CBLAS_H overrides this.
"""
envvar
=
os
.
getenv
(
'OMEGA_CBLAS_H'
)
if
envvar
is
None
:
return
[]
else
:
return
[
envvar
]
@_constant
def
_blas_libs
():
"""Return a list of libraries against which an Op's object file should be
linked to benefit from a BLAS implementation.
Default: ['mkl','m'], but environment variable OMEGA_BLAS_LDFLAGS overrides this.
"""
if
os
.
getenv
(
'OMEGA_BLAS_LDFLAGS'
):
return
os
.
getenv
(
'OMEGA_BLAS_LDFLAGS'
)
.
split
()
else
:
return
[
'mkl'
,
'm'
]
@_constant
@blas._constant
# TODO: move this decorator to a utility script
def
_compile_dir
():
"""Return the directory in which scipy.weave should store code objects.
...
...
@@ -111,8 +80,6 @@ def _compile_dir():
sys
.
path
.
append
(
cachedir
)
return
cachedir
def
input
(
x
):
#NB:
# - automatically casting int to float seems wrong.
...
...
@@ -965,176 +932,6 @@ inv_elemwise_inplace = inv_elemwise.inplace_version()
## Dot product ##
class
blas_code
:
@staticmethod
def
gemm_xyz
(
check_ab
,
a_init
,
b_init
):
mod
=
'
%
'
return
"""
const char * error_string = NULL;
int type_num = _x->descr->type_num;
int type_size = _x->descr->elsize; // in bytes
npy_intp* Nx = _x->dimensions;
npy_intp* Ny = _y->dimensions;
npy_intp* Nz = _z->dimensions;
npy_intp* Sx = _x->strides;
npy_intp* Sy = _y->strides;
npy_intp* Sz = _z->strides;
size_t sx_0, sx_1, sy_0, sy_1, sz_0, sz_1;
int unit = 0;
if (_x->nd != 2) goto _dot_execute_fallback;
if (_y->nd != 2) goto _dot_execute_fallback;
if (_z->nd != 2) goto _dot_execute_fallback;
%(check_ab)
s
if ((_x->descr->type_num != PyArray_DOUBLE)
&& (_x->descr->type_num != PyArray_FLOAT))
goto _dot_execute_fallback;
if ((_y->descr->type_num != PyArray_DOUBLE)
&& (_y->descr->type_num != PyArray_FLOAT))
goto _dot_execute_fallback;
if ((_y->descr->type_num != PyArray_DOUBLE)
&& (_y->descr->type_num != PyArray_FLOAT))
goto _dot_execute_fallback;
if ((_x->descr->type_num != _y->descr->type_num)
||(_x->descr->type_num != _z->descr->type_num))
goto _dot_execute_fallback;
if ((Nx[0] != Nz[0]) || (Nx[1] != Ny[0]) || (Ny[1] != Nz[1]))
{
error_string = "Input dimensions do not agree";
goto _dot_execute_fail;
}
if ((Sx[0] < 1) || (Sx[1] < 1) || (Sx[0]
%(mod)
s type_size) || (Sx[1]
%(mod)
s type_size)
|| (Sy[0] < 1) || (Sy[1] < 1) || (Sy[0]
%(mod)
s type_size) || (Sy[1]
%(mod)
s type_size)
|| (Sz[0] < 1) || (Sz[1] < 1) || (Sz[0]
%(mod)
s type_size) || (Sz[1]
%(mod)
s type_size))
{
goto _dot_execute_fallback;
}
/*
encode the stride structure of _x,_y,_z into a single integer
*/
unit |= ((Sx[1] == type_size) ? 0x0 : (Sx[0] == type_size) ? 0x1 : 0x2) << 0;
unit |= ((Sy[1] == type_size) ? 0x0 : (Sy[0] == type_size) ? 0x1 : 0x2) << 4;
unit |= ((Sz[1] == type_size) ? 0x0 : (Sz[0] == type_size) ? 0x1 : 0x2) << 8;
/* create appropriate strides for malformed matrices that are row or column
* vectors
*/
sx_0 = (Nx[0] > 1) ? Sx[0]/type_size : Nx[1];
sx_1 = (Nx[1] > 1) ? Sx[1]/type_size : Nx[0];
sy_0 = (Ny[0] > 1) ? Sy[0]/type_size : Ny[1];
sy_1 = (Ny[1] > 1) ? Sy[1]/type_size : Ny[0];
sz_0 = (Nz[0] > 1) ? Sz[0]/type_size : Nz[1];
sz_1 = (Nz[1] > 1) ? Sz[1]/type_size : Nz[0];
switch (type_num)
{
case PyArray_FLOAT:
{
#define REAL float
float a =
%(a_init)
s;
float b =
%(b_init)
s;
float* x = (float*)PyArray_DATA(_x);
float* y = (float*)PyArray_DATA(_y);
float* z = (float*)PyArray_DATA(_z);
switch(unit)
{
case 0x000: cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_0); break;
case 0x001: cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_0); break;
case 0x010: cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_0); break;
case 0x011: cblas_sgemm(CblasRowMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_0); break;
case 0x100: cblas_sgemm(CblasColMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_1); break;
case 0x101: cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_1); break;
case 0x110: cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_1); break;
case 0x111: cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_1); break;
default: goto _dot_execute_fallback;
};
#undef REAL
}
break;
case PyArray_DOUBLE:
{
#define REAL double
double a =
%(a_init)
s;
double b =
%(b_init)
s;
double* x = (double*)PyArray_DATA(_x);
double* y = (double*)PyArray_DATA(_y);
double* z = (double*)PyArray_DATA(_z);
switch(unit)
{
case 0x000: cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_0); break;
case 0x001: cblas_dgemm(CblasRowMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_0); break;
case 0x010: cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_0); break;
case 0x011: cblas_dgemm(CblasRowMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_0); break;
case 0x100: cblas_dgemm(CblasColMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_1); break;
case 0x101: cblas_dgemm(CblasColMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_1); break;
case 0x110: cblas_dgemm(CblasColMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_1); break;
case 0x111: cblas_dgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_1); break;
default: goto _dot_execute_fallback;
};
#undef REAL
}
break;
}
return 0; //success!
_dot_execute_fallback:
PyErr_SetString(PyExc_NotImplementedError,
"dot->execute() fallback");
return -1;
_dot_execute_fail:
if (error_string == NULL)
PyErr_SetString(PyExc_ValueError,
"dot->execute() cant run on these inputs");
return -1;
/* v 1 */
"""
%
locals
()
# currently unused, preferring the fallback method (throwing
# NotImplementedError) for when gemm won't work.
_templated_memaligned_gemm
=
"""
template <typename Ta, typename Tx, typename Ty, typename Tb, typename Tz>
int general_gemm(int zM, int zN, int xN,.
Ta a,
Tx * x, int xm, int xn,
Tx * y, int ym, int yn,
Tb b,
Tz * z, int zm, int zn)
{
for (int i = 0; i < zM; ++i)
{
for (int j = 0; j < zN; ++j)
{
Tz zij = 0.0;
for (int k = 0; k < xN; ++k)
{
zij += x[i*xm+k*xn] * y[k*ym+j*yn];
}
z[i * zm + j * zn] *= b;
z[i * zm + j * zn] += a * zij;
}
}
}
"""
class
dot
(
omega_op
):
impl
=
numpy
.
dot
...
...
@@ -1145,12 +942,12 @@ class dot(omega_op):
assert
x
[
2
][
1
]
==
y
[
2
][
0
]
shape
=
(
x
[
2
][
0
],
y
[
2
][
1
])
return
(
numpy
.
ndarray
,
upcast
(
x
[
1
],
y
[
1
]),
shape
)
def
c_
headers
(
self
):
return
_blas_headers
()
def
c_
support_code
(
self
):
return
blas
.
cblas_header_text
()
def
c_libs
(
self
):
return
_blas_lib
s
()
return
blas
.
ldflag
s
()
def
c_impl
((
_x
,
_y
),
(
_z
,
)):
return
blas
_code
.
gemm_xyz
(
''
,
'1.0'
,
'0.0'
)
return
blas
.
gemm_code
(
''
,
'1.0'
,
'0.0'
)
class
gemm
(
omega_op
,
inplace
):
...
...
@@ -1181,10 +978,10 @@ class gemm(omega_op, inplace):
return
z
def
alloc
(
self
,
except_list
):
self
.
outputs
[
0
]
.
data
=
self
.
inputs
[
0
]
.
data
def
c_
headers
(
self
):
return
_blas_headers
()
def
c_
support_code
(
self
):
return
blas
.
cblas_header_text
()
def
c_libs
(
self
):
return
_blas_lib
s
()
return
blas
.
ldflag
s
()
def
c_impl
((
_zin
,
_a
,
_x
,
_y
,
_b
),
(
_z
,)):
check_ab
=
"""
{
...
...
@@ -1197,7 +994,7 @@ class gemm(omega_op, inplace):
goto _dot_execute_fallback;
}
"""
return
blas
_code
.
gemm_xyz
(
check_ab
,
return
blas
.
gemm_code
(
check_ab
,
'(_a->descr->type_num == PyArray_FLOAT) ? (REAL)(((float*)_a->data)[0]) : (REAL)(((double*)_a->data)[0])'
,
'(_b->descr->type_num == PyArray_FLOAT) ? (REAL)(((float*)_b->data)[0]) : (REAL)(((double*)_b->data)[0])'
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论