提交 c7aa3b75 authored 作者: Iban Harlouchet's avatar Iban Harlouchet

flake8 for theano/tensor/blas_headers.py

上级 2b0880dd
...@@ -79,8 +79,10 @@ def detect_macos_sdot_bug(): ...@@ -79,8 +79,10 @@ def detect_macos_sdot_bug():
""") """)
_logger.debug('Trying to compile and run test case.') _logger.debug('Trying to compile and run test case.')
compilation_ok, run_ok = GCC_compiler.try_compile_tmp(test_code, compilation_ok, run_ok = GCC_compiler.try_compile_tmp(
tmp_prefix='detect_macos_sdot_bug_', flags=flags, try_run=True) test_code,
tmp_prefix='detect_macos_sdot_bug_',
flags=flags, try_run=True)
detect_macos_sdot_bug.tested = True detect_macos_sdot_bug.tested = True
# If compilation failed, we consider there is a bug, # If compilation failed, we consider there is a bug,
...@@ -124,13 +126,13 @@ def detect_macos_sdot_bug(): ...@@ -124,13 +126,13 @@ def detect_macos_sdot_bug():
_logger.debug('Trying to compile and run tentative workaround.') _logger.debug('Trying to compile and run tentative workaround.')
compilation_fix_ok, run_fix_ok = GCC_compiler.try_compile_tmp( compilation_fix_ok, run_fix_ok = GCC_compiler.try_compile_tmp(
test_fix_code, test_fix_code,
tmp_prefix='detect_macos_sdot_bug_testfix_', tmp_prefix='detect_macos_sdot_bug_testfix_',
flags=flags, flags=flags,
try_run=True) try_run=True)
_logger.info("Status of tentative fix -- compilation OK: %s, works: %s", _logger.info("Status of tentative fix -- compilation OK: %s, works: %s",
compilation_fix_ok, run_fix_ok) compilation_fix_ok, run_fix_ok)
detect_macos_sdot_bug.fix_works = run_fix_ok detect_macos_sdot_bug.fix_works = run_fix_ok
return detect_macos_sdot_bug.present return detect_macos_sdot_bug.present
...@@ -224,39 +226,39 @@ def cblas_header_text(): ...@@ -224,39 +226,39 @@ def cblas_header_text():
* =========================================================================== * ===========================================================================
*/ */
/* /*
* Routines with standard 4 prefixes (s, d, c, z) * Routines with standard 4 prefixes (s, d, c, z)
*/ */
void cblas_sswap(const int N, float *X, const int incX, void cblas_sswap(const int N, float *X, const int incX,
float *Y, const int incY); float *Y, const int incY);
void cblas_scopy(const int N, const float *X, const int incX, void cblas_scopy(const int N, const float *X, const int incX,
float *Y, const int incY); float *Y, const int incY);
void cblas_saxpy(const int N, const float alpha, const float *X, void cblas_saxpy(const int N, const float alpha, const float *X,
const int incX, float *Y, const int incY); const int incX, float *Y, const int incY);
void cblas_dswap(const int N, double *X, const int incX, void cblas_dswap(const int N, double *X, const int incX,
double *Y, const int incY); double *Y, const int incY);
void cblas_dcopy(const int N, const double *X, const int incX, void cblas_dcopy(const int N, const double *X, const int incX,
double *Y, const int incY); double *Y, const int incY);
void cblas_daxpy(const int N, const double alpha, const double *X, void cblas_daxpy(const int N, const double alpha, const double *X,
const int incX, double *Y, const int incY); const int incX, double *Y, const int incY);
void cblas_cswap(const int N, void *X, const int incX, void cblas_cswap(const int N, void *X, const int incX,
void *Y, const int incY); void *Y, const int incY);
void cblas_ccopy(const int N, const void *X, const int incX, void cblas_ccopy(const int N, const void *X, const int incX,
void *Y, const int incY); void *Y, const int incY);
void cblas_caxpy(const int N, const void *alpha, const void *X, void cblas_caxpy(const int N, const void *alpha, const void *X,
const int incX, void *Y, const int incY); const int incX, void *Y, const int incY);
void cblas_zswap(const int N, void *X, const int incX, void cblas_zswap(const int N, void *X, const int incX,
void *Y, const int incY); void *Y, const int incY);
void cblas_zcopy(const int N, const void *X, const int incX, void cblas_zcopy(const int N, const void *X, const int incX,
void *Y, const int incY); void *Y, const int incY);
void cblas_zaxpy(const int N, const void *alpha, const void *X, void cblas_zaxpy(const int N, const void *alpha, const void *X,
const int incX, void *Y, const int incY); const int incX, void *Y, const int incY);
/* /*
* Routines with S and D prefix only * Routines with S and D prefix only
*/ */
void cblas_srotg(float *a, float *b, float *c, float *s); void cblas_srotg(float *a, float *b, float *c, float *s);
...@@ -274,7 +276,7 @@ def cblas_header_text(): ...@@ -274,7 +276,7 @@ def cblas_header_text():
double *Y, const int incY, const double *P); double *Y, const int incY, const double *P);
/* /*
* Routines with S D C Z CS and ZD prefixes * Routines with S D C Z CS and ZD prefixes
*/ */
void cblas_sscal(const int N, const float alpha, float *X, const int incX); void cblas_sscal(const int N, const float alpha, float *X, const int incX);
...@@ -290,7 +292,7 @@ def cblas_header_text(): ...@@ -290,7 +292,7 @@ def cblas_header_text():
* =========================================================================== * ===========================================================================
*/ */
/* /*
* Routines with standard 4 prefixes (S, D, C, Z) * Routines with standard 4 prefixes (S, D, C, Z)
*/ */
void cblas_sgemv(const enum CBLAS_ORDER order, void cblas_sgemv(const enum CBLAS_ORDER order,
...@@ -305,11 +307,11 @@ def cblas_header_text(): ...@@ -305,11 +307,11 @@ def cblas_header_text():
const int incX, const float beta, float *Y, const int incY); const int incX, const float beta, float *Y, const int incY);
void cblas_strmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, void cblas_strmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const float *A, const int lda, const int N, const float *A, const int lda,
float *X, const int incX); float *X, const int incX);
void cblas_stbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, void cblas_stbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const float *A, const int lda, const int N, const int K, const float *A, const int lda,
float *X, const int incX); float *X, const int incX);
void cblas_stpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, void cblas_stpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
...@@ -338,11 +340,11 @@ def cblas_header_text(): ...@@ -338,11 +340,11 @@ def cblas_header_text():
const int incX, const double beta, double *Y, const int incY); const int incX, const double beta, double *Y, const int incY);
void cblas_dtrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, void cblas_dtrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const double *A, const int lda, const int N, const double *A, const int lda,
double *X, const int incX); double *X, const int incX);
void cblas_dtbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, void cblas_dtbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const double *A, const int lda, const int N, const int K, const double *A, const int lda,
double *X, const int incX); double *X, const int incX);
void cblas_dtpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, void cblas_dtpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
...@@ -371,11 +373,11 @@ def cblas_header_text(): ...@@ -371,11 +373,11 @@ def cblas_header_text():
const int incX, const void *beta, void *Y, const int incY); const int incX, const void *beta, void *Y, const int incY);
void cblas_ctrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, void cblas_ctrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *A, const int lda, const int N, const void *A, const int lda,
void *X, const int incX); void *X, const int incX);
void cblas_ctbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, void cblas_ctbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const void *A, const int lda, const int N, const int K, const void *A, const int lda,
void *X, const int incX); void *X, const int incX);
void cblas_ctpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, void cblas_ctpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
...@@ -404,11 +406,11 @@ def cblas_header_text(): ...@@ -404,11 +406,11 @@ def cblas_header_text():
const int incX, const void *beta, void *Y, const int incY); const int incX, const void *beta, void *Y, const int incY);
void cblas_ztrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, void cblas_ztrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *A, const int lda, const int N, const void *A, const int lda,
void *X, const int incX); void *X, const int incX);
void cblas_ztbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, void cblas_ztbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const void *A, const int lda, const int N, const int K, const void *A, const int lda,
void *X, const int incX); void *X, const int incX);
void cblas_ztpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, void cblas_ztpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
...@@ -426,7 +428,7 @@ def cblas_header_text(): ...@@ -426,7 +428,7 @@ def cblas_header_text():
const int N, const void *Ap, void *X, const int incX); const int N, const void *Ap, void *X, const int incX);
/* /*
* Routines with S and D prefixes only * Routines with S and D prefixes only
*/ */
void cblas_ssymv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, void cblas_ssymv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
...@@ -488,7 +490,7 @@ def cblas_header_text(): ...@@ -488,7 +490,7 @@ def cblas_header_text():
const int incX, const double *Y, const int incY, double *A); const int incX, const double *Y, const int incY, double *A);
/* /*
* Routines with C and Z prefixes only * Routines with C and Z prefixes only
*/ */
void cblas_chemv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, void cblas_chemv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
...@@ -559,7 +561,7 @@ def cblas_header_text(): ...@@ -559,7 +561,7 @@ def cblas_header_text():
* =========================================================================== * ===========================================================================
*/ */
/* /*
* Routines with standard 4 prefixes (S, D, C, Z) * Routines with standard 4 prefixes (S, D, C, Z)
*/ */
void cblas_sgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, void cblas_sgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
...@@ -683,7 +685,7 @@ def cblas_header_text(): ...@@ -683,7 +685,7 @@ def cblas_header_text():
void *B, const int ldb); void *B, const int ldb);
/* /*
* Routines with prefixes C and Z only * Routines with prefixes C and Z only
*/ */
void cblas_chemm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, void cblas_chemm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
...@@ -737,7 +739,7 @@ def blas_header_text(): ...@@ -737,7 +739,7 @@ def blas_header_text():
/* Single Precision */ /* Single Precision */
void srot_(const int*, float *, const int*, float *, const int*, const float *, const float *); void srot_(const int*, float *, const int*, float *, const int*, const float *, const float *);
void srotg_(float *,float *,float *,float *); void srotg_(float *,float *,float *,float *);
void srotm_( const int*, float *, const int*, float *, const int*, const float *); void srotm_( const int*, float *, const int*, float *, const int*, const float *);
void srotmg_(float *,float *,float *,const float *, float *); void srotmg_(float *,float *,float *,const float *, float *);
void sswap_( const int*, float *, const int*, float *, const int*); void sswap_( const int*, float *, const int*, float *, const int*);
...@@ -754,7 +756,7 @@ def blas_header_text(): ...@@ -754,7 +756,7 @@ def blas_header_text():
/* Double Precision */ /* Double Precision */
void drot_(const int*, double *, const int*, double *, const int*, const double *, const double *); void drot_(const int*, double *, const int*, double *, const int*, const double *, const double *);
void drotg_(double *,double *,double *,double *); void drotg_(double *,double *,double *,double *);
void drotm_( const int*, double *, const int*, double *, const int*, const double *); void drotm_( const int*, double *, const int*, double *, const int*, const double *);
void drotmg_(double *,double *,double *,const double *, double *); void drotmg_(double *,double *,double *,const double *, double *);
void dswap_( const int*, double *, const int*, double *, const int*); void dswap_( const int*, double *, const int*, double *, const int*);
...@@ -816,8 +818,8 @@ def blas_header_text(): ...@@ -816,8 +818,8 @@ def blas_header_text():
void stpsv_( char*, char*, char*, const int*, const float *, float *, const int*); void stpsv_( char*, char*, char*, const int*, const float *, float *, const int*);
void sger_( const int*, const int*, const float *, const float *, const int*, const float *, const int*, float *, const int*); void sger_( const int*, const int*, const float *, const float *, const int*, const float *, const int*, float *, const int*);
void ssyr_(char*, const int*, const float *, const float *, const int*, float *, const int*); void ssyr_(char*, const int*, const float *, const float *, const int*, float *, const int*);
void sspr_(char*, const int*, const float *, const float *, const int*, float *); void sspr_(char*, const int*, const float *, const float *, const int*, float *);
void sspr2_(char*, const int*, const float *, const float *, const int*, const float *, const int*, float *); void sspr2_(char*, const int*, const float *, const float *, const int*, const float *, const int*, float *);
void ssyr2_(char*, const int*, const float *, const float *, const int*, const float *, const int*, float *, const int*); void ssyr2_(char*, const int*, const float *, const float *, const int*, const float *, const int*, float *, const int*);
/* Double Precision */ /* Double Precision */
...@@ -835,8 +837,8 @@ def blas_header_text(): ...@@ -835,8 +837,8 @@ def blas_header_text():
void dtpsv_( char*, char*, char*, const int*, const double *, double *, const int*); void dtpsv_( char*, char*, char*, const int*, const double *, double *, const int*);
void dger_( const int*, const int*, const double *, const double *, const int*, const double *, const int*, double *, const int*); void dger_( const int*, const int*, const double *, const double *, const int*, const double *, const int*, double *, const int*);
void dsyr_(char*, const int*, const double *, const double *, const int*, double *, const int*); void dsyr_(char*, const int*, const double *, const double *, const int*, double *, const int*);
void dspr_(char*, const int*, const double *, const double *, const int*, double *); void dspr_(char*, const int*, const double *, const double *, const int*, double *);
void dspr2_(char*, const int*, const double *, const double *, const int*, const double *, const int*, double *); void dspr2_(char*, const int*, const double *, const double *, const int*, const double *, const int*, double *);
void dsyr2_(char*, const int*, const double *, const double *, const int*, const double *, const int*, double *, const int*); void dsyr2_(char*, const int*, const double *, const double *, const int*, const double *, const int*, double *, const int*);
/* Single Complex Precision */ /* Single Complex Precision */
...@@ -996,15 +998,15 @@ def ____gemm_code(check_ab, a_init, b_init): ...@@ -996,15 +998,15 @@ def ____gemm_code(check_ab, a_init, b_init):
%(check_ab)s %(check_ab)s
if ((PyArray_DESCR(_x)->type_num != NPY_DOUBLE) if ((PyArray_DESCR(_x)->type_num != NPY_DOUBLE)
&& (PyArray_DESCR(_x)->type_num != NPY_FLOAT)) && (PyArray_DESCR(_x)->type_num != NPY_FLOAT))
goto _dot_execute_fallback; goto _dot_execute_fallback;
if ((PyArray_DESCR(_y)->type_num != NPY_DOUBLE) if ((PyArray_DESCR(_y)->type_num != NPY_DOUBLE)
&& (PyArray_DESCR(_y)->type_num != NPY_FLOAT)) && (PyArray_DESCR(_y)->type_num != NPY_FLOAT))
goto _dot_execute_fallback; goto _dot_execute_fallback;
if ((PyArray_DESCR(_y)->type_num != NPY_DOUBLE) if ((PyArray_DESCR(_y)->type_num != NPY_DOUBLE)
&& (PyArray_DESCR(_y)->type_num != NPY_FLOAT)) && (PyArray_DESCR(_y)->type_num != NPY_FLOAT))
goto _dot_execute_fallback; goto _dot_execute_fallback;
...@@ -1098,13 +1100,13 @@ def ____gemm_code(check_ab, a_init, b_init): ...@@ -1098,13 +1100,13 @@ def ____gemm_code(check_ab, a_init, b_init):
return 0; //success! return 0; //success!
_dot_execute_fallback: _dot_execute_fallback:
PyErr_SetString(PyExc_NotImplementedError, PyErr_SetString(PyExc_NotImplementedError,
"dot->execute() fallback"); "dot->execute() fallback");
return -1; return -1;
_dot_execute_fail: _dot_execute_fail:
if (error_string == NULL) if (error_string == NULL)
PyErr_SetString(PyExc_ValueError, PyErr_SetString(PyExc_ValueError,
"dot->execute() cant run on these inputs"); "dot->execute() cant run on these inputs");
return -1; return -1;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论