提交 f00420e6 authored 作者: notoraptor's avatar notoraptor

Add fallback for [sd]dot_.

上级 d7c352f7
/** C Implementation (with NumPy back-end) of BLAS functions used in Theano.
* Used instead of BLAS when Theano flag ``blas.ldflags`` is empty.
* This file contains some useful header code not templated.
* File alt_blas_template.c contains template code for [sd]gemm_ and [sd]gemv_. **/
* File alt_blas_template.c currently contains template code for:
* - [sd]gemm_
* - [sd]gemv_
* - [sd]dot_
**/
#define alt_fatal_error(message) { if (PyErr_Occurred()) PyErr_Print(); if(message != NULL) fprintf(stderr, message); exit(-1); }
......
......@@ -80,7 +80,7 @@ inline PyObject* alt_wrap_fortran_writeable_matrix_%(float_type)s(
return PyArray_New(&PyArray_Type, 2, dims, %(npy_float)s, strides, matrix, 0, NPY_ARRAY_WRITEABLE, NULL);
}
/* gemm_ template code */
/* gemm */
void %(precision)sgemm_(
char* TRANSA, char* TRANSB, const int* M, const int* N, const int* K,
const %(float_type)s* ALPHA, %(float_type)s* A, const int* LDA,
......@@ -203,7 +203,7 @@ void %(precision)sgemv_(
x, y are vectors
ALPHA, BETA are scalars
**/
if (*M < 0 || *N < 0 || *LDA < 0 | *incx < 0 || *incy < 0)
if (*M < 0 || *N < 0 || *LDA < 0 || *incx < 0 || *incy < 0)
alt_fatal_error("The integer arguments passed to %(precision)sgemv_ must all be at least 0.");
int transpose = alt_trans_to_bool(TRANS);
if (*M == 0 || *N == 0) {
......@@ -222,6 +222,8 @@ void %(precision)sgemv_(
matrixX = alt_op_%(float_type)s(1, x, 1, *N, *incx, 0);
matrixY = alt_op_%(float_type)s(1, y, 1, *M, *incy, NPY_ARRAY_WRITEABLE);
};
if (matrixA == NULL || matrixX == NULL || matrixY == NULL)
alt_fatal_error("NumPy %(precision)sgemv_: unable to wrap A, x or y arrays.")
if (*ALPHA == 0) {
// Just BETA * y
alt_numpy_scale_matrix_inplace_%(float_type)s(BETA, (PyArrayObject*)matrixY);
......@@ -257,3 +259,33 @@ void %(precision)sgemv_(
Py_XDECREF(matrixX);
Py_XDECREF(matrixA);
}
/* dot */
%(float_type)s %(precision)sdot_(
const int* N,
%(float_type)s *SX,
const int *INCX,
%(float_type)s *SY,
const int *INCY
) {
if (*N < 0 || *INCX < 0 || *INCY < 0)
alt_fatal_error("The integer arguments passed to %(precision)sdot_ must all be at least 0.");
// Create vector_x with shape (1, N)
PyObject* vector_x = alt_op_%(float_type)s(0, SX, 1, *N, *INCX, 0);
// Create vector_y with shape (N, 1)
PyObject* vector_y = alt_op_%(float_type)s(1, SY, 1, *N, *INCY, 0);
if (vector_x == NULL || vector_y == NULL)
alt_fatal_error("NumPy %(precision)sdot_: unables to wrap x and y arrays.");
// Make matrix product: (1, N) * (N, 1) => (1, 1)
PyObject* dot_product = PyArray_MatrixProduct(vector_x, vector_y);
if (dot_product == NULL)
alt_fatal_error("NumPy %(precision)sdot_: unables to compute dot.");
// Get result.
%(float_type)s result = *(%(float_type)s*)PyArray_DATA((PyArrayObject*)dot_product);
Py_XDECREF(dot_product);
Py_XDECREF(vector_y);
Py_XDECREF(vector_x);
return result;
}
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论