提交 03d5a93c 作者: Frederic Bastien

Refactoring to reuse code: the "valid" and "wrap_centered" modes now share a single kernel template.

上级 5da5a5ee
...@@ -252,63 +252,10 @@ class GpuImages2Neibs(Images2Neibs): ...@@ -252,63 +252,10 @@ class GpuImages2Neibs(Images2Neibs):
dtype=ten4.type.dtype)()]) dtype=ten4.type.dtype)()])
def c_code_cache_version(self): def c_code_cache_version(self):
return () return (3,)
return (2,)
def c_support_code_apply(self, node, nodename): def c_support_code_apply(self, node, nodename):
if self.mode=="valid": mode = self.mode
return """
static __global__ void k_multi_warp_%(nodename)s(
const int nb_batch,
const int nb_stack,
const int height,
const int width,
const int c,
const int d,
const int step_x,
const int step_y,
const int grid_c,
const int grid_d,
const int stride0, const int stride1, const int stride2, const int stride3,
float * global_ten4,
float * global_out
)
{
for(int tblock = blockIdx.x;tblock<nb_batch*nb_stack*grid_c*grid_d;tblock+=gridDim.x){
const int b = tblock%%grid_d;
int left = tblock/grid_d;
const int a = left%%grid_c;
left = left/grid_c;
const int s = left%%nb_stack;
left = left/nb_stack;
const int n = left;
if(n>nb_batch)continue;
if(s>nb_stack)continue;
if(a>grid_c)continue;
if(b>grid_d)continue;
int z_row = b + grid_d*(a + grid_c*(s + nb_stack*n));
for (int i = 0; i < c; i++) // loop over c
{
int ten4_2 = i + a * step_x;
for (int j = threadIdx.x; j < d; j+=blockDim.x) // loop over d
{
int ten4_3 = j + b * step_y;
//int ten4_idx = ten4_3 + width*(ten4_2 + height*(s +nb_stack*n));
//int ten4_idx = stride3*ten4_3 + stride2*(ten4_2 + stride1*(s + stride0*n));
int ten4_idx = stride3*ten4_3 + stride2*ten4_2 + stride1*s + stride0*n;
int z_col = j + d * i;
int z_idx = z_col + c*d*z_row;
global_out[z_idx] = global_ten4[ten4_idx];
}
}
}
}
""" % locals()
if self.mode=="wrap_centered":
return """ return """
static __global__ void k_multi_warp_%(nodename)s( static __global__ void k_multi_warp_%(nodename)s(
const int nb_batch, const int nb_batch,
...@@ -346,16 +293,20 @@ class GpuImages2Neibs(Images2Neibs): ...@@ -346,16 +293,20 @@ class GpuImages2Neibs(Images2Neibs):
for (int i = 0; i < c; i++) // loop over c for (int i = 0; i < c; i++) // loop over c
{ {
int ten4_2 = i + a * step_x; int ten4_2 = i + a * step_x;
if("%(mode)s"=="wrap_centered"){
ten4_2 -= wrap_centered_idx_shift_x; ten4_2 -= wrap_centered_idx_shift_x;
if ( ten4_2 < 0 ) ten4_2 += height; if ( ten4_2 < 0 ) ten4_2 += height;
else if (ten4_2 >= height) ten4_2 -= height; else if (ten4_2 >= height) ten4_2 -= height;
}
for (int j = threadIdx.x; j < d; j+=blockDim.x) // loop over d for (int j = threadIdx.x; j < d; j+=blockDim.x) // loop over d
{ {
int ten4_3 = j + b * step_y; int ten4_3 = j + b * step_y;
if("%(mode)s"=="wrap_centered"){
ten4_3 -= wrap_centered_idx_shift_y; ten4_3 -= wrap_centered_idx_shift_y;
if ( ten4_3 < 0 ) ten4_3 += width; if ( ten4_3 < 0 ) ten4_3 += width;
else if (ten4_3 >= width) ten4_3 -= width; else if (ten4_3 >= width) ten4_3 -= width;
}
//int ten4_idx = ten4_3 + width*(ten4_2 + height*(s +nb_stack*n)); //int ten4_idx = ten4_3 + width*(ten4_2 + height*(s +nb_stack*n));
//int ten4_idx = stride3*ten4_3 + stride2*(ten4_2 + stride1*(s + stride0*n)); //int ten4_idx = stride3*ten4_3 + stride2*(ten4_2 + stride1*(s + stride0*n));
int ten4_idx = stride3*ten4_3 + stride2*ten4_2 + stride1*s + stride0*n; int ten4_idx = stride3*ten4_3 + stride2*ten4_2 + stride1*s + stride0*n;
...@@ -370,7 +321,6 @@ class GpuImages2Neibs(Images2Neibs): ...@@ -370,7 +321,6 @@ class GpuImages2Neibs(Images2Neibs):
""" % locals() """ % locals()
def c_code(self, node, name, (ten4, neib_shape, neib_step), (z,), sub): def c_code(self, node, name, (ten4, neib_shape, neib_step), (z,), sub):
fail = sub['fail'] fail = sub['fail']
mode = self.mode mode = self.mode
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论