提交 925eca70 作者: Frederic Bastien

Made the ConvOp gemm implementation work with float32 (C float, via sgemm_) in addition to float64 (C double, via dgemm_).

上级 e73c5328
......@@ -88,7 +88,6 @@ class ConvOp(Op):
def make_node(self, inputs, kerns):
# TODO: find a way to make ConvOp work for N-D (after NIPS09)
outdim = kerns.ndim
print inputs.type.dtype, kerns.type.dtype
if inputs.type.dtype != kerns.type.dtype:
raise Exception("The image and the kernel must have the same type."
"inputs(%s), kerns(%s)"%(inputs.dtype, kerns.dtype))
......@@ -229,6 +228,8 @@ using namespace std;
if node.inputs[0].type.dtype=="float32": d["type"]="float"
elif node.inputs[0].type.dtype=="float64": d["type"]="double"
else: raise Exception("Type %s not implemented"%node.inputs[0].type.dtype)
d["gemm"]='dgemm_' if d["type"]=="double" else 'sgemm_'
if self.unroll_batch>0 or self.unroll_kern>0:
if self.unroll_batch<=0: self.unroll_batch=1
if self.unroll_kern<=0: self.unroll_kern=1
......@@ -238,8 +239,10 @@ using namespace std;
#TODO: should we choose the unroll size automatically with the bigger divisor under 5?
if self.out_mode == 'valid':
# print "return gemm version"
return _conv_op_code_valid_gemm % d
else:
# print "return no gemm version"
return _conv_op_code_a % d
def convolve2(kerns, kshp, nkern, images, imshp, bsize, step=(1,1),
......@@ -630,12 +633,12 @@ for(int b=0;b< %(self_bsize)s;b++){
int Nz0 = Os[0];
int Nz1 = NKERN;
int K = kerns_dim[3];
double alpha = 1.0;
double beta = stackidx ? 1.0 : 0.0;
%(type)s alpha = 1.0;
%(type)s beta = stackidx ? 1.0 : 0.0;
int imgview_stride = dim_im[1];
int filter_rows_stride =kerns_dim[1]*kerns_dim[2]*kerns_dim[3];
//remember, Fortran wants a column-major interpretation
assert(img2d->strides[3] == sizeof(double));
assert(img2d->strides[3] == sizeof(%(type)s));
if (0){
std::cerr << "b " << b << " img_col " << img_col << " filterrow " << filter_row << " stackidx " <<stackidx << "\\n";
......@@ -657,7 +660,7 @@ for(int b=0;b< %(self_bsize)s;b++){
std::cerr << Nz1 << " " << Nz0 << " " << K << "\\n" ;
}
dgemm_(&T, &N,
%(gemm)s(&T, &N,
&Nz1, &Nz0, &K,
&alpha,
filter_rows, &filter_rows_stride,
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论