提交 1c06a20a 作者: Yann N. Dauphin

Added missing BLAS dot operations (sdot_ and ddot_)

上级 6d7cb0c5
......@@ -742,10 +742,10 @@ class SamplingDotCsr(gof.Op):
if dot_out == "float32":
conv_type = "float"
cdot = "sdot_sub_"
cdot = "sdot_"
else:
conv_type = "double"
cdot = "ddot_sub_"
cdot = "ddot_"
# retrieve dtype number
typenum_x = node.inputs[0].type.dtype_specs()[-1]
......@@ -839,9 +839,7 @@ class SamplingDotCsr(gof.Op):
const dtype_%(y)s* y_col = (dtype_%(y)s*)(%(y)s->data + %(y)s->strides[0] * n);
%(cdot)s((int*)&K, (const %(conv_type)s*)x_row, (int*)&Sdx, (const %(conv_type)s*)y_col, (int*)&Sdy, &Dzd[n_idx * Sdzd]);
Dzd[n_idx * Sdzd] *= Dpd[n_idx * Sdpd];
Dzd[n_idx * Sdzd] = Dpd[n_idx * Sdpd] * %(cdot)s((int*)&K, (const %(conv_type)s*)x_row, (int*)&Sdx, (const %(conv_type)s*)y_col, (int*)&Sdy);
}
}
}
......
......@@ -604,6 +604,7 @@ def blas_header_text():
void sswap_( const int*, float *, const int*, float *, const int*);
void scopy_( const int*, const float *, const int*, float *, const int*);
void saxpy_( const int*, const float *, const float *, const int*, float *, const int*);
float sdot_(const int*, const float *, const int*, const float *, const int*);
void sdot_sub_(const int*, const float *, const int*, const float *, const int*, float *);
void sdsdot_sub_( const int*, const float *, const float *, const int*, const float *, const int*, float *);
void sscal_( const int*, const float *, float *, const int*);
......@@ -621,6 +622,7 @@ def blas_header_text():
void dcopy_( const int*, const double *, const int*, double *, const int*);
void daxpy_( const int*, const double *, const double *, const int*, double *, const int*);
void dswap_( const int*, double *, const int*, double *, const int*);
double ddot_(const int*, const double *, const int*, const double *, const int*);
void dsdot_sub_(const int*, const float *, const int*, const float *, const int*, double *);
void ddot_sub_( const int*, const double *, const int*, const double *, const int*, double *);
void dscal_( const int*, const double *, double *, const int*);
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论