提交 488cef77 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Make GpuImages2Neibs work with strided output

上级 b5ef1599
...@@ -314,7 +314,7 @@ class GpuImages2Neibs(Images2Neibs, GpuOp): ...@@ -314,7 +314,7 @@ class GpuImages2Neibs(Images2Neibs, GpuOp):
dtype=ten4.type.dtype)()]) dtype=ten4.type.dtype)()])
def c_code_cache_version(self): def c_code_cache_version(self):
return (7,) return (8,)
def c_support_code_apply(self, node, nodename): def c_support_code_apply(self, node, nodename):
mode = self.mode mode = self.mode
...@@ -333,6 +333,7 @@ class GpuImages2Neibs(Images2Neibs, GpuOp): ...@@ -333,6 +333,7 @@ class GpuImages2Neibs(Images2Neibs, GpuOp):
const int grid_d, const int grid_d,
const int stride0, const int stride1, const int stride2, const int stride3, const int stride0, const int stride1, const int stride2, const int stride3,
float * global_ten4, float * global_ten4,
const int out_s0, const int out_s1,
float * global_out float * global_out
) )
{ {
...@@ -375,7 +376,7 @@ class GpuImages2Neibs(Images2Neibs, GpuOp): ...@@ -375,7 +376,7 @@ class GpuImages2Neibs(Images2Neibs, GpuOp):
int ten4_idx = stride3*ten4_3 + stride2*ten4_2 + stride1*s + stride0*n; int ten4_idx = stride3*ten4_3 + stride2*ten4_2 + stride1*s + stride0*n;
int z_col = j + d * i; int z_col = j + d * i;
int z_idx = z_col + c*d*z_row; int z_idx = z_col * out_s1 + z_row * out_s0;
global_out[z_idx] = global_ten4[ten4_idx]; global_out[z_idx] = global_ten4[ten4_idx];
} }
} }
...@@ -395,6 +396,7 @@ class GpuImages2Neibs(Images2Neibs, GpuOp): ...@@ -395,6 +396,7 @@ class GpuImages2Neibs(Images2Neibs, GpuOp):
const int grid_d, const int grid_d,
const int stride0, const int stride1, const int stride2, const int stride3, const int stride0, const int stride1, const int stride2, const int stride3,
float * global_ten4, float * global_ten4,
const int out_s0, const int out_s1,
float * global_out float * global_out
) )
{ {
...@@ -437,7 +439,7 @@ class GpuImages2Neibs(Images2Neibs, GpuOp): ...@@ -437,7 +439,7 @@ class GpuImages2Neibs(Images2Neibs, GpuOp):
int ten4_idx = stride3*ten4_3 + stride2*ten4_2 + stride1*s + stride0*n; int ten4_idx = stride3*ten4_3 + stride2*ten4_2 + stride1*s + stride0*n;
int z_col = j + d * i; int z_col = j + d * i;
int z_idx = z_col + c*d*z_row; int z_idx = z_col * out_s1 + z_row * out_s0;
global_out[z_idx] = global_ten4[ten4_idx]; global_out[z_idx] = global_ten4[ten4_idx];
} }
} }
...@@ -573,7 +575,9 @@ class GpuImages2Neibs(Images2Neibs, GpuOp): ...@@ -573,7 +575,9 @@ class GpuImages2Neibs(Images2Neibs, GpuOp):
int, int, int ,int, int, int, int ,int,
int, int, int, int,
int, int, int, int, int, int, int, int,
float*, float*); float*,
int, int,
float*);
if(n_threads.x==d && n_threads.y==c){ if(n_threads.x==d && n_threads.y==c){
f = k_multi_warp_less_%(name)s; f = k_multi_warp_less_%(name)s;
}else{ }else{
...@@ -591,6 +595,8 @@ class GpuImages2Neibs(Images2Neibs, GpuOp): ...@@ -591,6 +595,8 @@ class GpuImages2Neibs(Images2Neibs, GpuOp):
CudaNdarray_HOST_STRIDES(%(ten4)s)[2], CudaNdarray_HOST_STRIDES(%(ten4)s)[2],
CudaNdarray_HOST_STRIDES(%(ten4)s)[3], CudaNdarray_HOST_STRIDES(%(ten4)s)[3],
CudaNdarray_DEV_DATA(%(ten4)s), CudaNdarray_DEV_DATA(%(ten4)s),
CudaNdarray_HOST_STRIDES(%(z)s)[0],
CudaNdarray_HOST_STRIDES(%(z)s)[1],
CudaNdarray_DEV_DATA(%(z)s) CudaNdarray_DEV_DATA(%(z)s)
); );
CNDA_THREAD_SYNC; CNDA_THREAD_SYNC;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论