提交 87e10dc8 authored 作者: nouiz's avatar nouiz

Merge pull request #988 from pascanur/c_softmax

A c version for softmax almost identical to the bias softmax one
......@@ -382,6 +382,112 @@ class Softmax(gof.Op):
def infer_shape(self, node, shape):
return shape
def c_headers(self):
return ['<iostream>', '<cmath>']
@staticmethod
def c_code_template():
# this implementation was lifted from
# /u/bergstrj/cvs/bergstrj/src/feb07/nn.cxx
#TODO: put this into a templated function, in the support code
#TODO: declare the max of each row as an Op output
#TODO: set error messages for failures in this code
#TODO: use this to accept float32 and int32: node.inputs[0].type.dtype_specs()[1]
init_decl = """
npy_intp* Nx = %(x)s->dimensions;
if (%(x)s->nd != 2)
{
PyErr_SetString(PyExc_ValueError, "a not 2d tensor");
%(fail)s;
}
if ((%(x)s->descr->type_num != PyArray_DOUBLE) &&
(%(x)s->descr->type_num != PyArray_FLOAT))
{
PyErr_SetString(PyExc_TypeError, "a not float");
%(fail)s;
}
if ((NULL == %(sm)s)
|| (%(sm)s->dimensions[0] != %(x)s->dimensions[0])
|| (%(sm)s->dimensions[1] != %(x)s->dimensions[1]))
{
if (NULL != %(sm)s) Py_XDECREF(%(sm)s);
%(sm)s = (PyArrayObject*)PyArray_SimpleNew(2, PyArray_DIMS(%(x)s),
type_num_%(x)s);
if(!%(sm)s) {
PyErr_SetString(PyExc_MemoryError,
"failed to alloc sm output");
%(fail)s
}
}
"""
begin_row_loop = """
for (size_t i = 0; i < Nx[0]; ++i)
{
size_t j;
double sum = 0.0;
bool discount_max = false;
const dtype_%(x)s* __restrict__ x_i = (dtype_%(x)s*)(%(x)s->data + %(x)s->strides[0] * i);
dtype_%(sm) s* __restrict__ sm_i = (dtype_%(sm)s*)(%(sm)s->data + %(sm)s->strides[0] * i);
"""
inside_row_loop = """
npy_intp Sx = %(x)s->strides[1]/sizeof(dtype_%(x)s);
npy_intp Ssm = %(sm)s->strides[1]/sizeof(dtype_%(sm)s);
size_t row_max_j=0;
dtype_%(sm)s row_max = x_i[0];
//std::cout << "0 " << row_max << "\\n";
// Get the maximum value of the row
for (j = 1; j < Nx[1]; ++j)
{
dtype_%(sm)s row_ij = x_i[j * Sx] ;
//std::cout << "1 " << row_ij << "\\n";
row_max_j = (row_ij > row_max) ? j : row_max_j;
row_max = (row_ij > row_max) ? row_ij : row_max;
}
for (j = 0; j < Nx[1]; ++j)
{
dtype_%(sm)s row_ij = x_i[j * Sx] ;
//std::cout << "2 " << j << " " << row_ij << " " << row_max << "\\n";
dtype_%(sm)s sm_ij = exp(row_ij - row_max);
//std::cout << "3 " << j << " " << sm_ij << "\\n";
sum += sm_ij;
sm_i[j * Ssm] = sm_ij;
}
//cblas_dscal(x.N, 1.0 / sum, &mat_at(s,i,0), s.n);
double sum_inv = 1.0 / sum;
for (j = 0; j < Nx[1]; ++j)
{
sm_i[j * Ssm] *= sum_inv;
}
"""
end_row_loop = """
}
"""
return (init_decl, begin_row_loop, inside_row_loop, end_row_loop)
def c_code(self, node, name, inp, out, sub):
x, = inp
sm, = out
code_template = ''.join(self.c_code_template())
return code_template % dict(locals(), **sub)
@staticmethod
def c_code_cache_version():
return (1,)
softmax = Softmax()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论