提交 95bf6e34 authored 作者: abergeron's avatar abergeron

Merge pull request #1774 from nouiz/softmax

Softmax optimization
...@@ -909,7 +909,22 @@ class UnaryScalarOp(ScalarOp): ...@@ -909,7 +909,22 @@ class UnaryScalarOp(ScalarOp):
node.inputs[0].type != node.outputs[0].type): node.inputs[0].type != node.outputs[0].type):
raise theano.gof.utils.MethodNotDefined() raise theano.gof.utils.MethodNotDefined()
dtype = node.inputs[0].dtype dtype = node.inputs[0].type.dtype_specs()[1]
fct_call = self.c_code_contiguous_raw(dtype, 'n', 'x', 'z')
return """
{
npy_intp n = PyArray_SIZE(%(z)s);
%(dtype)s * x = (%(dtype)s*) PyArray_DATA(%(x)s);
%(dtype)s * z = (%(dtype)s*) PyArray_DATA(%(z)s);
%(fct_call)s;
}
""" % locals()
def c_code_contiguous_raw(self, dtype, n, i, o):
if not config.lib.amdlibm:
raise theano.gof.utils.MethodNotDefined()
if dtype.startswith('npy_'):
dtype = dtype[4:]
if dtype == 'float32' and self.amd_float32 is not None: if dtype == 'float32' and self.amd_float32 is not None:
dtype = 'float' dtype = 'float'
fct = self.amd_float32 fct = self.amd_float32
...@@ -918,12 +933,7 @@ class UnaryScalarOp(ScalarOp): ...@@ -918,12 +933,7 @@ class UnaryScalarOp(ScalarOp):
fct = self.amd_float64 fct = self.amd_float64
else: else:
raise theano.gof.utils.MethodNotDefined() raise theano.gof.utils.MethodNotDefined()
return """ return "%(fct)s(%(n)s, %(i)s, %(o)s)" % locals()
npy_intp n = PyArray_SIZE(%(z)s);
%(dtype)s * x = (%(dtype)s*) PyArray_DATA(%(x)s);
%(dtype)s * z = (%(dtype)s*) PyArray_DATA(%(z)s);
%(fct)s(n, x, z);
""" % locals()
class BinaryScalarOp(ScalarOp): class BinaryScalarOp(ScalarOp):
......
...@@ -95,7 +95,7 @@ class SoftmaxWithBias(gof.Op): ...@@ -95,7 +95,7 @@ class SoftmaxWithBias(gof.Op):
return ['<iostream>', '<cmath>'] return ['<iostream>', '<cmath>']
@staticmethod @staticmethod
def c_code_template(): def c_code_template(dtype):
# this implementation was lifted from # this implementation was lifted from
# /u/bergstrj/cvs/bergstrj/src/feb07/nn.cxx # /u/bergstrj/cvs/bergstrj/src/feb07/nn.cxx
...@@ -107,6 +107,10 @@ class SoftmaxWithBias(gof.Op): ...@@ -107,6 +107,10 @@ class SoftmaxWithBias(gof.Op):
#TODO: use this to accept float32 and int32: node.inputs[0].type.dtype_specs()[1] #TODO: use this to accept float32 and int32: node.inputs[0].type.dtype_specs()[1]
init_decl = """ init_decl = """
npy_intp* Nx = PyArray_DIMS(%(x)s); npy_intp* Nx = PyArray_DIMS(%(x)s);
npy_intp Sx = 0;
npy_intp Sb = 0;
npy_intp Ssm = 0;
if (PyArray_NDIM(%(x)s) != 2) if (PyArray_NDIM(%(x)s) != 2)
{ {
...@@ -151,6 +155,10 @@ class SoftmaxWithBias(gof.Op): ...@@ -151,6 +155,10 @@ class SoftmaxWithBias(gof.Op):
%(fail)s %(fail)s
} }
} }
Sx = PyArray_STRIDES(%(x)s)[1]/sizeof(dtype_%(x)s);
Sb = PyArray_STRIDES(%(b)s)[0]/sizeof(dtype_%(b)s);
Ssm = PyArray_STRIDES(%(sm)s)[1]/sizeof(dtype_%(sm)s);
""" """
begin_row_loop = """ begin_row_loop = """
...@@ -163,9 +171,7 @@ class SoftmaxWithBias(gof.Op): ...@@ -163,9 +171,7 @@ class SoftmaxWithBias(gof.Op):
const dtype_%(x)s* __restrict__ x_i = (dtype_%(x)s*)(PyArray_BYTES(%(x)s) + PyArray_STRIDES(%(x)s)[0] * i); const dtype_%(x)s* __restrict__ x_i = (dtype_%(x)s*)(PyArray_BYTES(%(x)s) + PyArray_STRIDES(%(x)s)[0] * i);
const dtype_%(b)s* __restrict__ b_i = (dtype_%(b)s*)(PyArray_BYTES(%(b)s)); const dtype_%(b)s* __restrict__ b_i = (dtype_%(b)s*)(PyArray_BYTES(%(b)s));
dtype_%(sm) s* __restrict__ sm_i = (dtype_%(sm)s*)(PyArray_BYTES(%(sm)s) + PyArray_STRIDES(%(sm)s)[0] * i); dtype_%(sm) s* __restrict__ sm_i = (dtype_%(sm)s*)(PyArray_BYTES(%(sm)s) + PyArray_STRIDES(%(sm)s)[0] * i);
"""
inside_row_loop = """
npy_intp Sx = PyArray_STRIDES(%(x)s)[1]/sizeof(dtype_%(x)s); npy_intp Sx = PyArray_STRIDES(%(x)s)[1]/sizeof(dtype_%(x)s);
npy_intp Sb = PyArray_STRIDES(%(b)s)[0]/sizeof(dtype_%(b)s); npy_intp Sb = PyArray_STRIDES(%(b)s)[0]/sizeof(dtype_%(b)s);
npy_intp Ssm = PyArray_STRIDES(%(sm)s)[1]/sizeof(dtype_%(sm)s); npy_intp Ssm = PyArray_STRIDES(%(sm)s)[1]/sizeof(dtype_%(sm)s);
...@@ -182,6 +188,9 @@ class SoftmaxWithBias(gof.Op): ...@@ -182,6 +188,9 @@ class SoftmaxWithBias(gof.Op):
row_max = (row_ij > row_max) ? row_ij : row_max; row_max = (row_ij > row_max) ? row_ij : row_max;
} }
"""
inside_row_loop = """
for (j = 0; j < Nx[1]; ++j) for (j = 0; j < Nx[1]; ++j)
{ {
dtype_%(sm)s row_ij = x_i[j * Sx] + b_i[j * Sb]; dtype_%(sm)s row_ij = x_i[j * Sx] + b_i[j * Sb];
...@@ -201,6 +210,42 @@ class SoftmaxWithBias(gof.Op): ...@@ -201,6 +210,42 @@ class SoftmaxWithBias(gof.Op):
""" """
# Get the vectorized version of exp if it exist
try:
vec_exp = theano.scalar.exp.c_code_contiguous_raw(dtype,
"Nx[1]", "sm_i", "sm_i")
inside_row_loop_contig = """
for (j = 0; j < Nx[1]; ++j)
{
dtype_%%(sm)s row_ij = x_i[j * Sx] + b_i[j * Sb];
//std::cout << "2 " << j << " " << row_ij << " " << row_max << "\\n";
dtype_%%(sm)s sm_ij = row_ij - row_max;
//std::cout << "3 " << j << " " << sm_ij << "\\n";
sm_i[j * Ssm] = sm_ij;
}
%(vec_exp)s;
for (j = 0; j < Nx[1]; ++j)
{
sum += sm_i[j * Ssm];
}
//cblas_dscal(x.N, 1.0 / sum, &mat_at(s,i,0), s.n);
double sum_inv = 1.0 / sum;
for (j = 0; j < Nx[1]; ++j)
{
sm_i[j * Ssm] *= sum_inv;
}
""" % locals()
inside_row_loop = """
if(Ssm == 1){
%(inside_row_loop_contig)s
}else{
%(inside_row_loop)s
}
""" % locals()
except theano.gof.utils.MethodNotDefined:
pass
end_row_loop = """ end_row_loop = """
} }
""" """
...@@ -210,12 +255,13 @@ class SoftmaxWithBias(gof.Op): ...@@ -210,12 +255,13 @@ class SoftmaxWithBias(gof.Op):
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
x, b = inp x, b = inp
sm, = out sm, = out
code_template = ''.join(self.c_code_template()) code_template = ''.join(self.c_code_template(
node.inputs[0].type.dtype_specs()[1]))
return code_template % dict(locals(), **sub) return code_template % dict(locals(), **sub)
@staticmethod @staticmethod
def c_code_cache_version(): def c_code_cache_version():
return (6,) return (8,)
softmax_with_bias = SoftmaxWithBias() softmax_with_bias = SoftmaxWithBias()
...@@ -384,7 +430,7 @@ class Softmax(gof.Op): ...@@ -384,7 +430,7 @@ class Softmax(gof.Op):
return ['<iostream>', '<cmath>'] return ['<iostream>', '<cmath>']
@staticmethod @staticmethod
def c_code_template(): def c_code_template(dtype):
# this implementation was lifted from # this implementation was lifted from
# /u/bergstrj/cvs/bergstrj/src/feb07/nn.cxx # /u/bergstrj/cvs/bergstrj/src/feb07/nn.cxx
...@@ -396,6 +442,8 @@ class Softmax(gof.Op): ...@@ -396,6 +442,8 @@ class Softmax(gof.Op):
#TODO: use this to accept float32 and int32: node.inputs[0].type.dtype_specs()[1] #TODO: use this to accept float32 and int32: node.inputs[0].type.dtype_specs()[1]
init_decl = """ init_decl = """
npy_intp* Nx = PyArray_DIMS(%(x)s); npy_intp* Nx = PyArray_DIMS(%(x)s);
npy_intp Sx1 = 0;
npy_intp Ssm1 = 0;
if (PyArray_NDIM(%(x)s) != 2) if (PyArray_NDIM(%(x)s) != 2)
{ {
...@@ -413,7 +461,7 @@ class Softmax(gof.Op): ...@@ -413,7 +461,7 @@ class Softmax(gof.Op):
|| (PyArray_DIMS(%(sm)s)[0] != PyArray_DIMS(%(x)s)[0]) || (PyArray_DIMS(%(sm)s)[0] != PyArray_DIMS(%(x)s)[0])
|| (PyArray_DIMS(%(sm)s)[1] != PyArray_DIMS(%(x)s)[1])) || (PyArray_DIMS(%(sm)s)[1] != PyArray_DIMS(%(x)s)[1]))
{ {
if (NULL != %(sm)s) Py_XDECREF(%(sm)s); Py_XDECREF(%(sm)s);
%(sm)s = (PyArrayObject*)PyArray_SimpleNew(2, PyArray_DIMS(%(x)s), %(sm)s = (PyArrayObject*)PyArray_SimpleNew(2, PyArray_DIMS(%(x)s),
type_num_%(x)s); type_num_%(x)s);
if(!%(sm)s) { if(!%(sm)s) {
...@@ -422,6 +470,8 @@ class Softmax(gof.Op): ...@@ -422,6 +470,8 @@ class Softmax(gof.Op):
%(fail)s %(fail)s
} }
} }
Sx1 = PyArray_STRIDES(%(x)s)[1]/sizeof(dtype_%(x)s);
Ssm1 = PyArray_STRIDES(%(sm)s)[1]/sizeof(dtype_%(sm)s);
""" """
begin_row_loop = """ begin_row_loop = """
...@@ -433,11 +483,6 @@ class Softmax(gof.Op): ...@@ -433,11 +483,6 @@ class Softmax(gof.Op):
const dtype_%(x)s* __restrict__ x_i = (dtype_%(x)s*)(PyArray_BYTES(%(x)s) + PyArray_STRIDES(%(x)s)[0] * i); const dtype_%(x)s* __restrict__ x_i = (dtype_%(x)s*)(PyArray_BYTES(%(x)s) + PyArray_STRIDES(%(x)s)[0] * i);
dtype_%(sm) s* __restrict__ sm_i = (dtype_%(sm)s*)(PyArray_BYTES(%(sm)s) + PyArray_STRIDES(%(sm)s)[0] * i); dtype_%(sm) s* __restrict__ sm_i = (dtype_%(sm)s*)(PyArray_BYTES(%(sm)s) + PyArray_STRIDES(%(sm)s)[0] * i);
"""
inside_row_loop = """
npy_intp Sx = PyArray_STRIDES(%(x)s)[1]/sizeof(dtype_%(x)s);
npy_intp Ssm = PyArray_STRIDES(%(sm)s)[1]/sizeof(dtype_%(sm)s);
size_t row_max_j=0; size_t row_max_j=0;
dtype_%(sm)s row_max = x_i[0]; dtype_%(sm)s row_max = x_i[0];
...@@ -445,46 +490,82 @@ class Softmax(gof.Op): ...@@ -445,46 +490,82 @@ class Softmax(gof.Op):
// Get the maximum value of the row // Get the maximum value of the row
for (j = 1; j < Nx[1]; ++j) for (j = 1; j < Nx[1]; ++j)
{ {
dtype_%(sm)s row_ij = x_i[j * Sx] ; dtype_%(sm)s row_ij = x_i[j * Sx1] ;
//std::cout << "1 " << row_ij << "\\n"; //std::cout << "1 " << row_ij << "\\n";
row_max_j = (row_ij > row_max) ? j : row_max_j; row_max_j = (row_ij > row_max) ? j : row_max_j;
row_max = (row_ij > row_max) ? row_ij : row_max; row_max = (row_ij > row_max) ? row_ij : row_max;
} }
"""
inside_row_loop = """
for (j = 0; j < Nx[1]; ++j) for (j = 0; j < Nx[1]; ++j)
{ {
dtype_%(sm)s row_ij = x_i[j * Sx] ; dtype_%(sm)s row_ij = x_i[j * Sx1] ;
//std::cout << "2 " << j << " " << row_ij << " " << row_max << "\\n"; //std::cout << "2 " << j << " " << row_ij << " " << row_max << "\\n";
dtype_%(sm)s sm_ij = exp(row_ij - row_max); dtype_%(sm)s sm_ij = exp(row_ij - row_max);
//std::cout << "3 " << j << " " << sm_ij << "\\n"; //std::cout << "3 " << j << " " << sm_ij << "\\n";
sum += sm_ij; sum += sm_ij;
sm_i[j * Ssm] = sm_ij; sm_i[j * Ssm1] = sm_ij;
} }
//cblas_dscal(x.N, 1.0 / sum, &mat_at(s,i,0), s.n); //cblas_dscal(x.N, 1.0 / sum, &mat_at(s,i,0), s.n);
double sum_inv = 1.0 / sum; double sum_inv = 1.0 / sum;
for (j = 0; j < Nx[1]; ++j) for (j = 0; j < Nx[1]; ++j)
{ {
sm_i[j * Ssm] *= sum_inv; sm_i[j * Ssm1] *= sum_inv;
} }
""" """
# Get the vectorized version of exp if it exist
try:
vec_exp = theano.scalar.exp.c_code_contiguous_raw(dtype,
"Nx[1]", "sm_i", "sm_i")
inside_row_loop_contig = """
for (j = 0; j < Nx[1]; ++j)
{
sm_i[j * Ssm1] = x_i[j * Sx1] - row_max;
}
%(vec_exp)s;
for (j = 0; j < Nx[1]; ++j)
{
sum += sm_i[j * Ssm1];
}
//cblas_dscal(x.N, 1.0 / sum, &mat_at(s,i,0), s.n);
double sum_inv = 1.0 / sum;
for (j = 0; j < Nx[1]; ++j)
{
sm_i[j * Ssm1] *= sum_inv;
}
""" % locals()
inside_row_loop = """
if(Ssm1 == 1){
%(inside_row_loop_contig)s
}else{
%(inside_row_loop)s
}
""" % locals()
except theano.gof.utils.MethodNotDefined:
pass
end_row_loop = """ end_row_loop = """
} }
""" """
return (init_decl, begin_row_loop, inside_row_loop, end_row_loop) return (init_decl, begin_row_loop, inside_row_loop, end_row_loop)
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
x, = inp x, = inp
sm, = out sm, = out
code_template = ''.join(self.c_code_template()) code_template = ''.join(self.c_code_template(
node.inputs[0].type.dtype_specs()[1]))
return code_template % dict(locals(), **sub) return code_template % dict(locals(), **sub)
@staticmethod @staticmethod
def c_code_cache_version(): def c_code_cache_version():
return (1,) return (3,)
softmax = Softmax() softmax = Softmax()
...@@ -863,7 +944,7 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op): ...@@ -863,7 +944,7 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op):
return ['<iostream>', '<cmath>'] return ['<iostream>', '<cmath>']
@staticmethod @staticmethod
def c_code_template(): def c_code_template(dtype):
# this implementation was lifted from # this implementation was lifted from
# /u/bergstrj/cvs/bergstrj/src/feb07/nn.cxx # /u/bergstrj/cvs/bergstrj/src/feb07/nn.cxx
...@@ -874,7 +955,7 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op): ...@@ -874,7 +955,7 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op):
#TODO: use this to accept float32 and int32: node.inputs[0].type.dtype_specs()[1] #TODO: use this to accept float32 and int32: node.inputs[0].type.dtype_specs()[1]
(init_decl, begin_row_loop, inside_row_loop, end_row_loop) = \ (init_decl, begin_row_loop, inside_row_loop, end_row_loop) = \
SoftmaxWithBias.c_code_template() SoftmaxWithBias.c_code_template(dtype)
return (init_decl, return (init_decl,
""" """
if (PyArray_NDIM(%(y_idx)s) != 1) if (PyArray_NDIM(%(y_idx)s) != 1)
...@@ -947,7 +1028,8 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op): ...@@ -947,7 +1028,8 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op):
nll, sm, am = out nll, sm, am = out
y_idx_type = node.inputs[2].type.dtype_specs()[1] y_idx_type = node.inputs[2].type.dtype_specs()[1]
am_type = y_idx_type am_type = y_idx_type
code_template = ''.join(self.c_code_template()) dtype = node.inputs[0].type.dtype_specs()[1]
code_template = ''.join(self.c_code_template(dtype))
return code_template % dict(locals(), **sub) return code_template % dict(locals(), **sub)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论