提交 51950421 authored 作者: Frederic's avatar Frederic

Move the access to strides outside of the loop.

上级 4b7ca419
...@@ -396,6 +396,8 @@ class Softmax(gof.Op): ...@@ -396,6 +396,8 @@ class Softmax(gof.Op):
#TODO: use this to accept float32 and int32: node.inputs[0].type.dtype_specs()[1] #TODO: use this to accept float32 and int32: node.inputs[0].type.dtype_specs()[1]
init_decl = """ init_decl = """
npy_intp* Nx = PyArray_DIMS(%(x)s); npy_intp* Nx = PyArray_DIMS(%(x)s);
npy_intp Sx1 = 0;
npy_intp Ssm1 = 0;
if (PyArray_NDIM(%(x)s) != 2) if (PyArray_NDIM(%(x)s) != 2)
{ {
...@@ -422,6 +424,8 @@ class Softmax(gof.Op): ...@@ -422,6 +424,8 @@ class Softmax(gof.Op):
%(fail)s %(fail)s
} }
} }
Sx1 = PyArray_STRIDES(%(x)s)[1]/sizeof(dtype_%(x)s);
Ssm1 = PyArray_STRIDES(%(sm)s)[1]/sizeof(dtype_%(sm)s);
""" """
begin_row_loop = """ begin_row_loop = """
...@@ -436,16 +440,13 @@ class Softmax(gof.Op): ...@@ -436,16 +440,13 @@ class Softmax(gof.Op):
""" """
inside_row_loop = """ inside_row_loop = """
npy_intp Sx = PyArray_STRIDES(%(x)s)[1]/sizeof(dtype_%(x)s);
npy_intp Ssm = PyArray_STRIDES(%(sm)s)[1]/sizeof(dtype_%(sm)s);
size_t row_max_j=0; size_t row_max_j=0;
dtype_%(sm)s row_max = x_i[0]; dtype_%(sm)s row_max = x_i[0];
//std::cout << "0 " << row_max << "\\n"; //std::cout << "0 " << row_max << "\\n";
// Get the maximum value of the row // Get the maximum value of the row
for (j = 1; j < Nx[1]; ++j) for (j = 1; j < Nx[1]; ++j)
{ {
dtype_%(sm)s row_ij = x_i[j * Sx] ; dtype_%(sm)s row_ij = x_i[j * Sx1] ;
//std::cout << "1 " << row_ij << "\\n"; //std::cout << "1 " << row_ij << "\\n";
row_max_j = (row_ij > row_max) ? j : row_max_j; row_max_j = (row_ij > row_max) ? j : row_max_j;
row_max = (row_ij > row_max) ? row_ij : row_max; row_max = (row_ij > row_max) ? row_ij : row_max;
...@@ -453,19 +454,19 @@ class Softmax(gof.Op): ...@@ -453,19 +454,19 @@ class Softmax(gof.Op):
for (j = 0; j < Nx[1]; ++j) for (j = 0; j < Nx[1]; ++j)
{ {
dtype_%(sm)s row_ij = x_i[j * Sx] ; dtype_%(sm)s row_ij = x_i[j * Sx1] ;
//std::cout << "2 " << j << " " << row_ij << " " << row_max << "\\n"; //std::cout << "2 " << j << " " << row_ij << " " << row_max << "\\n";
dtype_%(sm)s sm_ij = exp(row_ij - row_max); dtype_%(sm)s sm_ij = exp(row_ij - row_max);
//std::cout << "3 " << j << " " << sm_ij << "\\n"; //std::cout << "3 " << j << " " << sm_ij << "\\n";
sum += sm_ij; sum += sm_ij;
sm_i[j * Ssm] = sm_ij; sm_i[j * Ssm1] = sm_ij;
} }
//cblas_dscal(x.N, 1.0 / sum, &mat_at(s,i,0), s.n); //cblas_dscal(x.N, 1.0 / sum, &mat_at(s,i,0), s.n);
double sum_inv = 1.0 / sum; double sum_inv = 1.0 / sum;
for (j = 0; j < Nx[1]; ++j) for (j = 0; j < Nx[1]; ++j)
{ {
sm_i[j * Ssm] *= sum_inv; sm_i[j * Ssm1] *= sum_inv;
} }
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论