提交 9d73a2d5 authored 作者: Frederic's avatar Frederic

Use vectorized version of exp in the Softmax op.

上级 d3f724ab
...@@ -384,7 +384,7 @@ class Softmax(gof.Op): ...@@ -384,7 +384,7 @@ class Softmax(gof.Op):
return ['<iostream>', '<cmath>'] return ['<iostream>', '<cmath>']
@staticmethod @staticmethod
def c_code_template(): def c_code_template(dtype):
# this implementation was lifted from # this implementation was lifted from
# /u/bergstrj/cvs/bergstrj/src/feb07/nn.cxx # /u/bergstrj/cvs/bergstrj/src/feb07/nn.cxx
...@@ -470,22 +470,67 @@ class Softmax(gof.Op): ...@@ -470,22 +470,67 @@ class Softmax(gof.Op):
} }
""" """
# Get the vectorized version of exp if it exist
try:
vec_exp = theano.scalar.exp.c_code_contiguous_raw(dtype,
"Nx[1]", "sm_i", "sm_i")
inside_row_loop_contig = """
size_t row_max_j=0;
dtype_%%(sm)s row_max = x_i[0];
//std::cout << "0 " << row_max << "\\n";
// Get the maximum value of the row
for (j = 1; j < Nx[1]; ++j)
{
dtype_%%(sm)s row_ij = x_i[j * Sx1] ;
//std::cout << "1 " << row_ij << "\\n";
row_max_j = (row_ij > row_max) ? j : row_max_j;
row_max = (row_ij > row_max) ? row_ij : row_max;
}
for (j = 0; j < Nx[1]; ++j)
{
sm_i[j * Ssm1] = x_i[j * Sx1] - row_max;
}
%(vec_exp)s;
for (j = 0; j < Nx[1]; ++j)
{
sum += sm_i[j * Ssm1];
}
//cblas_dscal(x.N, 1.0 / sum, &mat_at(s,i,0), s.n);
double sum_inv = 1.0 / sum;
for (j = 0; j < Nx[1]; ++j)
{
sm_i[j * Ssm1] *= sum_inv;
}
""" % locals()
inside_row_loop = """
if(Ssm1 == 1){
%(inside_row_loop_contig)s
}else{
%(inside_row_loop)s
}
""" % locals()
except theano.gof.utils.MethodNotDefined:
pass
end_row_loop = """ end_row_loop = """
} }
""" """
return (init_decl, begin_row_loop, inside_row_loop, end_row_loop) return (init_decl, begin_row_loop, inside_row_loop, end_row_loop)
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
x, = inp x, = inp
sm, = out sm, = out
code_template = ''.join(self.c_code_template()) code_template = ''.join(self.c_code_template(
node.inputs[0].type.dtype_specs()[1]))
return code_template % dict(locals(), **sub) return code_template % dict(locals(), **sub)
@staticmethod @staticmethod
def c_code_cache_version(): def c_code_cache_version():
return (1,) return (2,)
softmax = Softmax() softmax = Softmax()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论