提交 39a2b2d2 authored 作者: notoraptor's avatar notoraptor

Added some simplifications.

上级 0cd7aa7b
......@@ -3,19 +3,17 @@ C Implementation of dgemm_ based on NumPy
Used instead of blas when Theano config flag blas.ldflags is empty.
**/
void alt_double_scalar_matrix_product_in_place(double scalar, double* matrix, int size_to_compute) {
int i;
for(i = 0; i < size_to_compute; ++i) {
for(int i = 0; i < size_to_compute; ++i) {
matrix[i] *= scalar;
}
}
void alt_double_matrix_sum_in_place(double* A, double* B, double* out, int size_to_compute) {
int i;
for(i = 0; i < size_to_compute; ++i) {
for(int i = 0; i < size_to_compute; ++i) {
out[i] = A[i] + B[i];
}
}
/* dgemm
* NB: See sgemm_ for same assumptions.
* NB: See sgemm_ (in alt_sgemm.c) for same assumptions.
* */
void dgemm_(char* TRANSA, char* TRANSB,
const int* M, const int* N, const int* K,
......
......@@ -12,14 +12,12 @@ inline PyObject* alt_matrix_matrix_product2(PyObject* o1, PyObject* o2, PyObject
return PyArray_MatrixProduct2(o1, o2, (PyArrayObject*)out);
}
void alt_scalar_matrix_product_in_place(float scalar, float* matrix, int size_to_compute) {
int i;
for(i = 0; i < size_to_compute; ++i) {
for(int i = 0; i < size_to_compute; ++i) {
matrix[i] *= scalar;
}
}
void alt_matrix_sum_in_place(float* A, float* B, float* out, int size_to_compute) {
int i;
for(i = 0; i < size_to_compute; ++i) {
for(int i = 0; i < size_to_compute; ++i) {
out[i] = A[i] + B[i];
}
}
......@@ -27,8 +25,8 @@ inline PyObject* alt_op(char* trans, PyArrayObject* matrix) {
return (*trans == 'N' || *trans == 'n') ? (PyObject*)matrix : alt_transpose(matrix);
}
/* sgemm
* We assume that none of these 13 pointers passed as arguments are null.
* NB: We can optimize this function again (for example, when alpha == 0 and/or beta == 0).
* We assume that none of these 13 pointers passed as arguments is null.
* NB: We can more optimize this function (for example, when alpha == 0).
* */
void sgemm_(char* TRANSA, char* TRANSB,
const int* M, const int* N, const int* K,
......@@ -40,9 +38,9 @@ void sgemm_(char* TRANSA, char* TRANSB,
if(C == NULL)
return;
/* Recall:
A is a *LDA by *ka matrix.
B is a *LDB by *kb matrix.
C is a *LDC By *N matrix.
* A is a *LDA by *ka matrix.
* B is a *LDB by *kb matrix.
* C is a *LDC By *N matrix.
*/
int ka, kb;
if(*TRANSA == 'N' || *TRANSA == 'n')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论