提交 0cff557b authored 作者: abergeron's avatar abergeron 提交者: GitHub

Merge pull request #6041 from xiaoqie/pool-fix2

Fix parameter types in ave_pool kernels, remove static_cast, add GLOBAL_MEM
...@@ -1632,7 +1632,7 @@ class GpuEye(GpuKernelBase, Op): ...@@ -1632,7 +1632,7 @@ class GpuEye(GpuKernelBase, Op):
code = """ code = """
KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size a_off, KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size a_off,
ga_size n, ga_size m, ga_ssize k) { ga_size n, ga_size m, ga_ssize k) {
a = (GLOBAL_MEM %(ctype)s *)(((char *)a) + a_off); a = (GLOBAL_MEM %(ctype)s *)(((GLOBAL_MEM char *)a) + a_off);
ga_ssize coff = max(k, (ga_ssize) 0); ga_ssize coff = max(k, (ga_ssize) 0);
ga_ssize roff = -min(k, (ga_ssize) 0); ga_ssize roff = -min(k, (ga_ssize) 0);
ga_size nb = (ga_size) min(n - roff, m - coff); ga_size nb = (ga_size) min(n - roff, m - coff);
...@@ -1706,4 +1706,4 @@ KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size a_off, ...@@ -1706,4 +1706,4 @@ KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size a_off,
return s return s
def c_code_cache_version(self): def c_code_cache_version(self):
return (8,) return (9,)
...@@ -47,8 +47,8 @@ KERNEL void dilated_im3d2col_kernel(const ga_size n, ...@@ -47,8 +47,8 @@ KERNEL void dilated_im3d2col_kernel(const ga_size n,
const ga_size height_col, const ga_size width_col, const ga_size depth_col, const ga_size height_col, const ga_size width_col, const ga_size depth_col,
GLOBAL_MEM DTYPE_INPUT_0 * data_col, GLOBAL_MEM DTYPE_INPUT_0 * data_col,
const ga_size offset_col) { const ga_size offset_col) {
data_im = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)data_im) + offset_im); data_im = (GLOBAL_MEM DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)data_im) + offset_im);
data_col = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)data_col) + offset_col); data_col = (GLOBAL_MEM DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)data_col) + offset_col);
// grid stride looping // grid stride looping
for (ga_size index = GID_0 * LDIM_0 + LID_0; for (ga_size index = GID_0 * LDIM_0 + LID_0;
index < (n); index += LDIM_0 * GDIM_0) { index < (n); index += LDIM_0 * GDIM_0) {
...@@ -100,8 +100,8 @@ KERNEL void im3d2col_kernel(const ga_size n, ...@@ -100,8 +100,8 @@ KERNEL void im3d2col_kernel(const ga_size n,
const ga_size height_col, const ga_size width_col, const ga_size depth_col, const ga_size height_col, const ga_size width_col, const ga_size depth_col,
GLOBAL_MEM DTYPE_INPUT_0 * data_col, GLOBAL_MEM DTYPE_INPUT_0 * data_col,
const ga_size offset_col) { const ga_size offset_col) {
data_im = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)data_im) + offset_im); data_im = (GLOBAL_MEM DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)data_im) + offset_im);
data_col = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)data_col) + offset_col); data_col = (GLOBAL_MEM DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)data_col) + offset_col);
// grid stride looping // grid stride looping
for (ga_size index = GID_0 * LDIM_0 + LID_0; for (ga_size index = GID_0 * LDIM_0 + LID_0;
index < (n); index += LDIM_0 * GDIM_0) { index < (n); index += LDIM_0 * GDIM_0) {
...@@ -115,10 +115,10 @@ KERNEL void im3d2col_kernel(const ga_size n, ...@@ -115,10 +115,10 @@ KERNEL void im3d2col_kernel(const ga_size n,
const ga_size h_offset = h_col * stride_h - pad_h; const ga_size h_offset = h_col * stride_h - pad_h;
const ga_size w_offset = w_col * stride_w - pad_w; const ga_size w_offset = w_col * stride_w - pad_w;
const ga_size d_offset = d_col * stride_d - pad_d; const ga_size d_offset = d_col * stride_d - pad_d;
DTYPE_INPUT_0 * data_col_ptr = data_col; GLOBAL_MEM DTYPE_INPUT_0 * data_col_ptr = data_col;
data_col_ptr += c_col * (height_col * width_col * depth_col) + data_col_ptr += c_col * (height_col * width_col * depth_col) +
h_col * (width_col * depth_col) + w_col * depth_col + d_col; h_col * (width_col * depth_col) + w_col * depth_col + d_col;
const DTYPE_INPUT_0 * data_im_ptr = data_im + data_im_offset; GLOBAL_MEM const DTYPE_INPUT_0 * data_im_ptr = data_im + data_im_offset;
data_im_ptr += c_im * (height * width * depth) + data_im_ptr += c_im * (height * width * depth) +
h_offset * (width * depth) + w_offset * depth + d_offset; h_offset * (width * depth) + w_offset * depth + d_offset;
for (ga_size i = 0; i < kernel_h; ++i) { for (ga_size i = 0; i < kernel_h; ++i) {
...@@ -154,8 +154,8 @@ KERNEL void dilated_col2im3d_kernel(const ga_size n, ...@@ -154,8 +154,8 @@ KERNEL void dilated_col2im3d_kernel(const ga_size n,
const ga_size data_im_offset) { const ga_size data_im_offset) {
// offset_im is the pointer offset for data_im. // offset_im is the pointer offset for data_im.
// data_im_offset is an offset of elements in the array // data_im_offset is an offset of elements in the array
data_im = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)data_im) + offset_im); data_im = (GLOBAL_MEM DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)data_im) + offset_im);
data_col = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)data_col) + offset_col); data_col = (GLOBAL_MEM DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)data_col) + offset_col);
// grid stride looping // grid stride looping
for (ga_size index = GID_0 * LDIM_0 + LID_0; for (ga_size index = GID_0 * LDIM_0 + LID_0;
index < (n); index += LDIM_0 * GDIM_0) { index < (n); index += LDIM_0 * GDIM_0) {
...@@ -222,8 +222,8 @@ KERNEL void col2im3d_kernel(const ga_size n, ...@@ -222,8 +222,8 @@ KERNEL void col2im3d_kernel(const ga_size n,
const ga_size data_im_offset) { const ga_size data_im_offset) {
// offset_im is the pointer offset for data_im. // offset_im is the pointer offset for data_im.
// data_im_offset is an offset of elements in the array // data_im_offset is an offset of elements in the array
data_im = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)data_im) + offset_im); data_im = (GLOBAL_MEM DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)data_im) + offset_im);
data_col = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)data_col) + offset_col); data_col = (GLOBAL_MEM DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)data_col) + offset_col);
// grid stride looping // grid stride looping
for (ga_size index = GID_0 * LDIM_0 + LID_0; for (ga_size index = GID_0 * LDIM_0 + LID_0;
index < (n); index += LDIM_0 * GDIM_0) { index < (n); index += LDIM_0 * GDIM_0) {
......
...@@ -47,8 +47,8 @@ KERNEL void dilated_im2col_kernel(const ga_size n, ...@@ -47,8 +47,8 @@ KERNEL void dilated_im2col_kernel(const ga_size n,
const ga_size height_col, const ga_size width_col, const ga_size height_col, const ga_size width_col,
GLOBAL_MEM DTYPE_INPUT_0 * data_col, GLOBAL_MEM DTYPE_INPUT_0 * data_col,
const ga_size offset_col) { const ga_size offset_col) {
data_im = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)data_im) + offset_im); data_im = (GLOBAL_MEM DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)data_im) + offset_im);
data_col = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)data_col) + offset_col); data_col = (GLOBAL_MEM DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)data_col) + offset_col);
// grid stride looping // grid stride looping
for (ga_size index = GID_0 * LDIM_0 + LID_0; for (ga_size index = GID_0 * LDIM_0 + LID_0;
index < (n); index += LDIM_0 * GDIM_0) { index < (n); index += LDIM_0 * GDIM_0) {
...@@ -91,8 +91,8 @@ KERNEL void im2col_kernel(const ga_size n, ...@@ -91,8 +91,8 @@ KERNEL void im2col_kernel(const ga_size n,
const ga_size height_col, const ga_size width_col, const ga_size height_col, const ga_size width_col,
GLOBAL_MEM DTYPE_INPUT_0 * data_col, GLOBAL_MEM DTYPE_INPUT_0 * data_col,
const ga_size offset_col) { const ga_size offset_col) {
data_im = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)data_im) + offset_im); data_im = (GLOBAL_MEM DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)data_im) + offset_im);
data_col = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)data_col) + offset_col); data_col = (GLOBAL_MEM DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)data_col) + offset_col);
// grid stride looping // grid stride looping
for (ga_size index = GID_0 * LDIM_0 + LID_0; for (ga_size index = GID_0 * LDIM_0 + LID_0;
index < (n); index += LDIM_0 * GDIM_0) { index < (n); index += LDIM_0 * GDIM_0) {
...@@ -135,8 +135,8 @@ KERNEL void dilated_col2im_kernel(const ga_size n, ...@@ -135,8 +135,8 @@ KERNEL void dilated_col2im_kernel(const ga_size n,
const ga_size data_im_offset) { const ga_size data_im_offset) {
// offset_im is the pointer offset for data_im. // offset_im is the pointer offset for data_im.
// data_im_offset is an offset of elements in the array // data_im_offset is an offset of elements in the array
data_col = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)data_col) + offset_col); data_col = (GLOBAL_MEM DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)data_col) + offset_col);
data_im = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)data_im) + offset_im); data_im = (GLOBAL_MEM DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)data_im) + offset_im);
// grid stride looping // grid stride looping
for (ga_size index = GID_0 * LDIM_0 + LID_0; for (ga_size index = GID_0 * LDIM_0 + LID_0;
index < (n); index += LDIM_0 * GDIM_0) { index < (n); index += LDIM_0 * GDIM_0) {
...@@ -185,8 +185,8 @@ KERNEL void col2im_kernel(const ga_size n, ...@@ -185,8 +185,8 @@ KERNEL void col2im_kernel(const ga_size n,
const ga_size data_im_offset) { const ga_size data_im_offset) {
// offset_im is the pointer offset for data_im. // offset_im is the pointer offset for data_im.
// data_im_offset is an offset of elements in the array // data_im_offset is an offset of elements in the array
data_col = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)data_col) + offset_col); data_col = (GLOBAL_MEM DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)data_col) + offset_col);
data_im = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)data_im) + offset_im); data_im = (GLOBAL_MEM DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)data_im) + offset_im);
// grid stride looping // grid stride looping
for (ga_size index = GID_0 * LDIM_0 + LID_0; for (ga_size index = GID_0 * LDIM_0 + LID_0;
index < (n); index += LDIM_0 * GDIM_0) { index < (n); index += LDIM_0 * GDIM_0) {
......
...@@ -84,9 +84,9 @@ KERNEL void k_multi_warp_multinomial( ...@@ -84,9 +84,9 @@ KERNEL void k_multi_warp_multinomial(
const ga_ssize outs_col_stride const ga_ssize outs_col_stride
) )
{ {
global_pvals = (GLOBAL_MEM %(in_ctype)s *)(((char *)global_pvals) + global_pvals_offset); global_pvals = (GLOBAL_MEM %(in_ctype)s *)(((GLOBAL_MEM char *)global_pvals) + global_pvals_offset);
global_unis = (GLOBAL_MEM %(in_ctype)s *)(((char *)global_unis) + global_unis_offset); global_unis = (GLOBAL_MEM %(in_ctype)s *)(((GLOBAL_MEM char *)global_unis) + global_unis_offset);
global_outs = (GLOBAL_MEM %(out_ctype)s *)(((char *)global_outs) + global_outs_offset); global_outs = (GLOBAL_MEM %(out_ctype)s *)(((GLOBAL_MEM char *)global_outs) + global_outs_offset);
// each thread takes care of one multinomial draw // each thread takes care of one multinomial draw
int n = LDIM_0*GID_0 + LID_0; int n = LDIM_0*GID_0 + LID_0;
if (n < nb_multi) if (n < nb_multi)
...@@ -220,7 +220,7 @@ KERNEL void k_multi_warp_multinomial( ...@@ -220,7 +220,7 @@ KERNEL void k_multi_warp_multinomial(
return s return s
def c_code_cache_version(self): def c_code_cache_version(self):
return (4,) return (5,)
class GPUAChoiceFromUniform(GpuKernelBase, Op): class GPUAChoiceFromUniform(GpuKernelBase, Op):
...@@ -297,9 +297,9 @@ KERNEL void k_multi_warp_multinomial_wor( ...@@ -297,9 +297,9 @@ KERNEL void k_multi_warp_multinomial_wor(
const ga_ssize outs_col_stride const ga_ssize outs_col_stride
) )
{ {
global_pvals_copy = (GLOBAL_MEM float *)(((char *)global_pvals_copy) + global_pvals_offset); global_pvals_copy = (GLOBAL_MEM float *)(((GLOBAL_MEM char *)global_pvals_copy) + global_pvals_offset);
global_unis = (GLOBAL_MEM float *)(((char *)global_unis) + global_unis_offset); global_unis = (GLOBAL_MEM float *)(((GLOBAL_MEM char *)global_unis) + global_unis_offset);
global_outs = (GLOBAL_MEM ga_long *)(((char *)global_outs) + global_outs_offset); global_outs = (GLOBAL_MEM ga_long *)(((GLOBAL_MEM char *)global_outs) + global_outs_offset);
// each thread takes care of one multinomial-wor n_samples-draw // each thread takes care of one multinomial-wor n_samples-draw
int n = LDIM_0*GID_0 + LID_0; int n = LDIM_0*GID_0 + LID_0;
...@@ -455,7 +455,7 @@ KERNEL void k_multi_warp_multinomial_wor( ...@@ -455,7 +455,7 @@ KERNEL void k_multi_warp_multinomial_wor(
return s return s
def c_code_cache_version(self): def c_code_cache_version(self):
return (8,) return (9,)
@register_opt('fast_compile') @register_opt('fast_compile')
......
...@@ -10,8 +10,8 @@ KERNEL void max_pool2d_kernel(const ga_size nthreads, ...@@ -10,8 +10,8 @@ KERNEL void max_pool2d_kernel(const ga_size nthreads,
const ga_size stride_h, const ga_size stride_w, const ga_size pad_h, const ga_size pad_w, const ga_size stride_h, const ga_size stride_w, const ga_size pad_h, const ga_size pad_w,
GLOBAL_MEM DTYPE_OUTPUT_0 *z, const ga_size z_off) GLOBAL_MEM DTYPE_OUTPUT_0 *z, const ga_size z_off)
{ {
x = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)x) + x_off); x = (GLOBAL_MEM DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)x) + x_off);
z = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((char *)z) + z_off); z = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((GLOBAL_MEM char *)z) + z_off);
// grid stride looping // grid stride looping
for (ga_size index = GID_0 * LDIM_0 + LID_0; for (ga_size index = GID_0 * LDIM_0 + LID_0;
index < nthreads; index < nthreads;
...@@ -55,8 +55,8 @@ KERNEL void max_pool3d_kernel(const ga_size nthreads, ...@@ -55,8 +55,8 @@ KERNEL void max_pool3d_kernel(const ga_size nthreads,
const ga_size stride_w, const ga_size pad_d, const ga_size pad_h, const ga_size pad_w, const ga_size stride_w, const ga_size pad_d, const ga_size pad_h, const ga_size pad_w,
GLOBAL_MEM DTYPE_OUTPUT_0 *z, const ga_size z_off) GLOBAL_MEM DTYPE_OUTPUT_0 *z, const ga_size z_off)
{ {
x = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)x) + x_off); x = (GLOBAL_MEM DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)x) + x_off);
z = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((char *)z) + z_off); z = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((GLOBAL_MEM char *)z) + z_off);
// grid stride looping // grid stride looping
for (ga_size index = GID_0 * LDIM_0 + LID_0; for (ga_size index = GID_0 * LDIM_0 + LID_0;
index < nthreads; index < nthreads;
...@@ -94,7 +94,7 @@ KERNEL void max_pool3d_kernel(const ga_size nthreads, ...@@ -94,7 +94,7 @@ KERNEL void max_pool3d_kernel(const ga_size nthreads,
} }
} }
#kernel ave_pool2d_kernel : size, size, size, size, size, size, size, *, size, size, size, size, size, size, size, size, size, *, size: #kernel ave_pool2d_kernel : size, size, size, size, size, size, size, *, size, size, size, size, size, size, size, bool, bool, *, size:
// (adopted from Caffe: https://github.com/BVLC/caffe/blob/master/src/caffe/layers/pooling_layer.cu) // (adopted from Caffe: https://github.com/BVLC/caffe/blob/master/src/caffe/layers/pooling_layer.cu)
KERNEL void ave_pool2d_kernel(const ga_size nthreads, KERNEL void ave_pool2d_kernel(const ga_size nthreads,
...@@ -105,8 +105,8 @@ KERNEL void ave_pool2d_kernel(const ga_size nthreads, ...@@ -105,8 +105,8 @@ KERNEL void ave_pool2d_kernel(const ga_size nthreads,
const ga_bool inc_pad, const ga_bool sum_mode, const ga_bool inc_pad, const ga_bool sum_mode,
GLOBAL_MEM DTYPE_OUTPUT_0 *z, const ga_size z_off) GLOBAL_MEM DTYPE_OUTPUT_0 *z, const ga_size z_off)
{ {
x = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)x) + x_off); x = (GLOBAL_MEM DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)x) + x_off);
z = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((char *)z) + z_off); z = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((GLOBAL_MEM char *)z) + z_off);
// grid stride looping // grid stride looping
for (ga_size index = GID_0 * LDIM_0 + LID_0; for (ga_size index = GID_0 * LDIM_0 + LID_0;
index < nthreads; index < nthreads;
...@@ -149,7 +149,7 @@ KERNEL void ave_pool2d_kernel(const ga_size nthreads, ...@@ -149,7 +149,7 @@ KERNEL void ave_pool2d_kernel(const ga_size nthreads,
} }
} }
#kernel ave_pool3d_kernel : size, size, size, size, size, size, size, size, size, *, size, size, size, size, size, size, size, size, size, size, size, size, *, size : #kernel ave_pool3d_kernel : size, size, size, size, size, size, size, size, size, *, size, size, size, size, size, size, size, size, size, size, bool, bool, *, size :
// (adopted from Caffe: https://github.com/BVLC/caffe/blob/master/src/caffe/layers/pooling_layer.cu) // (adopted from Caffe: https://github.com/BVLC/caffe/blob/master/src/caffe/layers/pooling_layer.cu)
KERNEL void ave_pool3d_kernel(const ga_size nthreads, KERNEL void ave_pool3d_kernel(const ga_size nthreads,
...@@ -163,8 +163,8 @@ KERNEL void ave_pool3d_kernel(const ga_size nthreads, ...@@ -163,8 +163,8 @@ KERNEL void ave_pool3d_kernel(const ga_size nthreads,
GLOBAL_MEM DTYPE_OUTPUT_0 *z, const ga_size z_off) GLOBAL_MEM DTYPE_OUTPUT_0 *z, const ga_size z_off)
{ {
// grid stride looping // grid stride looping
x = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)x) + x_off); x = (GLOBAL_MEM DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)x) + x_off);
z = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((char *)z) + z_off); z = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((GLOBAL_MEM char *)z) + z_off);
for (ga_size index = GID_0 * LDIM_0 + LID_0; for (ga_size index = GID_0 * LDIM_0 + LID_0;
index < nthreads; index < nthreads;
index += LDIM_0 * GDIM_0) { index += LDIM_0 * GDIM_0) {
......
#section kernels #section kernels
#kernel ave_pool2d_grad_kernel : size, size, size, size, size, size, size, *, size, *, size, size, size, size, size, size, size, size, size, *, size : #kernel ave_pool2d_grad_kernel : size, size, size, size, size, size, size, *, size, *, size, size, size, size, size, size, size, bool, bool, *, size :
// (adopted from Caffe: https://github.com/BVLC/caffe/blob/master/src/caffe/layers/pooling_layer.cu) // (adopted from Caffe: https://github.com/BVLC/caffe/blob/master/src/caffe/layers/pooling_layer.cu)
KERNEL void ave_pool2d_grad_kernel(const ga_size nthreads, KERNEL void ave_pool2d_grad_kernel(const ga_size nthreads,
...@@ -11,9 +11,9 @@ KERNEL void ave_pool2d_grad_kernel(const ga_size nthreads, ...@@ -11,9 +11,9 @@ KERNEL void ave_pool2d_grad_kernel(const ga_size nthreads,
const ga_size pad_h, const ga_size pad_w, const ga_bool inc_pad, const ga_bool sum_mode, const ga_size pad_h, const ga_size pad_w, const ga_bool inc_pad, const ga_bool sum_mode,
GLOBAL_MEM DTYPE_OUTPUT_0 *gx, const ga_size gx_off) GLOBAL_MEM DTYPE_OUTPUT_0 *gx, const ga_size gx_off)
{ {
x = (GLOBAL_MEM const DTYPE_INPUT_0 *)(((char *)x) + x_off); x = (GLOBAL_MEM const DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)x) + x_off);
gz = (GLOBAL_MEM const DTYPE_INPUT_1 *)(((char *)gz) + gz_off); gz = (GLOBAL_MEM const DTYPE_INPUT_1 *)(((GLOBAL_MEM char *)gz) + gz_off);
gx = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((char *)gx) + gx_off); gx = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((GLOBAL_MEM char *)gx) + gx_off);
// grid stride looping // grid stride looping
for (ga_size index = GID_0 * LDIM_0 + LID_0; for (ga_size index = GID_0 * LDIM_0 + LID_0;
index < nthreads; index += LDIM_0 * GDIM_0) { index < nthreads; index += LDIM_0 * GDIM_0) {
...@@ -49,7 +49,7 @@ KERNEL void ave_pool2d_grad_kernel(const ga_size nthreads, ...@@ -49,7 +49,7 @@ KERNEL void ave_pool2d_grad_kernel(const ga_size nthreads,
} }
} }
#kernel ave_pool3d_grad_kernel : size, size, size, size, size, size, size, size, size, *, size, *, size, size, size, size, size, size, size, size, size, size, size, size, *, size : #kernel ave_pool3d_grad_kernel : size, size, size, size, size, size, size, size, size, *, size, *, size, size, size, size, size, size, size, size, size, size, bool, bool, *, size :
// (adopted from Caffe: https://github.com/BVLC/caffe/blob/master/src/caffe/layers/pooling_layer.cu) // (adopted from Caffe: https://github.com/BVLC/caffe/blob/master/src/caffe/layers/pooling_layer.cu)
KERNEL void ave_pool3d_grad_kernel(const ga_size nthreads, KERNEL void ave_pool3d_grad_kernel(const ga_size nthreads,
...@@ -62,9 +62,9 @@ KERNEL void ave_pool3d_grad_kernel(const ga_size nthreads, ...@@ -62,9 +62,9 @@ KERNEL void ave_pool3d_grad_kernel(const ga_size nthreads,
const ga_size pad_d, const ga_size pad_h, const ga_size pad_w, const ga_size pad_d, const ga_size pad_h, const ga_size pad_w,
const ga_bool inc_pad, const ga_bool sum_mode, GLOBAL_MEM DTYPE_OUTPUT_0 *gx, const ga_size gx_off) const ga_bool inc_pad, const ga_bool sum_mode, GLOBAL_MEM DTYPE_OUTPUT_0 *gx, const ga_size gx_off)
{ {
x = (GLOBAL_MEM const DTYPE_INPUT_0 *)(((char *)x) + x_off); x = (GLOBAL_MEM const DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)x) + x_off);
gz = (GLOBAL_MEM const DTYPE_INPUT_1 *)(((char *)gz) + gz_off); gz = (GLOBAL_MEM const DTYPE_INPUT_1 *)(((GLOBAL_MEM char *)gz) + gz_off);
gx = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((char *)gx) + gx_off); gx = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((GLOBAL_MEM char *)gx) + gx_off);
// grid stride looping // grid stride looping
for (ga_size index = GID_0 * LDIM_0 + LID_0; for (ga_size index = GID_0 * LDIM_0 + LID_0;
index < nthreads; index += LDIM_0 * GDIM_0) { index < nthreads; index += LDIM_0 * GDIM_0) {
......
...@@ -10,10 +10,10 @@ KERNEL void max_pool2d_grad_grad_kernel(const ga_size nthreads, ...@@ -10,10 +10,10 @@ KERNEL void max_pool2d_grad_grad_kernel(const ga_size nthreads,
const ga_size pad_h, const ga_size pad_w, const ga_size pad_h, const ga_size pad_w,
GLOBAL_MEM DTYPE_OUTPUT_0 *gz, const ga_size gz_off) GLOBAL_MEM DTYPE_OUTPUT_0 *gz, const ga_size gz_off)
{ {
x = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)x) + x_off); x = (GLOBAL_MEM DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)x) + x_off);
z = (GLOBAL_MEM DTYPE_INPUT_1 *)(((char *)z) + z_off); z = (GLOBAL_MEM DTYPE_INPUT_1 *)(((GLOBAL_MEM char *)z) + z_off);
gx = (GLOBAL_MEM DTYPE_INPUT_2 *)(((char *)gx) + gx_off); gx = (GLOBAL_MEM DTYPE_INPUT_2 *)(((GLOBAL_MEM char *)gx) + gx_off);
gz = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((char *)gz) + gz_off); gz = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((GLOBAL_MEM char *)gz) + gz_off);
// grid stride looping // grid stride looping
for (ga_size index = GID_0 * LDIM_0 + LID_0; for (ga_size index = GID_0 * LDIM_0 + LID_0;
index < nthreads; index += LDIM_0 * GDIM_0) { index < nthreads; index += LDIM_0 * GDIM_0) {
...@@ -21,9 +21,9 @@ KERNEL void max_pool2d_grad_grad_kernel(const ga_size nthreads, ...@@ -21,9 +21,9 @@ KERNEL void max_pool2d_grad_grad_kernel(const ga_size nthreads,
const ga_size ph = (index / pooled_width) % pooled_height; const ga_size ph = (index / pooled_width) % pooled_height;
const ga_size c = (index / pooled_width / pooled_height) % channels; const ga_size c = (index / pooled_width / pooled_height) % channels;
const ga_size n = (index / pooled_width / pooled_height / channels); const ga_size n = (index / pooled_width / pooled_height / channels);
ga_int hstart = static_cast<ga_int>(ph*stride_h) - static_cast<ga_int>(pad_h); ga_int hstart = (ga_int)(ph*stride_h) - (ga_int)(pad_h);
const ga_size hend = min(hstart + kernel_h, height); const ga_size hend = min(hstart + kernel_h, height);
ga_int wstart = static_cast<ga_int>(pw*stride_w) - static_cast<ga_int>(pad_w); ga_int wstart = (ga_int)(pw*stride_w) - (ga_int)(pad_w);
const ga_size wend = min(wstart + kernel_w, width); const ga_size wend = min(wstart + kernel_w, width);
hstart = max(hstart, 0); hstart = max(hstart, 0);
wstart = max(wstart, 0); wstart = max(wstart, 0);
...@@ -58,10 +58,10 @@ KERNEL void max_pool3d_grad_grad_kernel(const ga_size nthreads, ...@@ -58,10 +58,10 @@ KERNEL void max_pool3d_grad_grad_kernel(const ga_size nthreads,
const ga_size pad_d, const ga_size pad_h, const ga_size pad_w, const ga_size pad_d, const ga_size pad_h, const ga_size pad_w,
GLOBAL_MEM DTYPE_OUTPUT_0 *gz, const ga_size gz_off) GLOBAL_MEM DTYPE_OUTPUT_0 *gz, const ga_size gz_off)
{ {
x = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)x) + x_off); x = (GLOBAL_MEM DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)x) + x_off);
z = (GLOBAL_MEM DTYPE_INPUT_1 *)(((char *)z) + z_off); z = (GLOBAL_MEM DTYPE_INPUT_1 *)(((GLOBAL_MEM char *)z) + z_off);
gx = (GLOBAL_MEM DTYPE_INPUT_2 *)(((char *)gx) + gx_off); gx = (GLOBAL_MEM DTYPE_INPUT_2 *)(((GLOBAL_MEM char *)gx) + gx_off);
gz = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((char *)gz) + gz_off); gz = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((GLOBAL_MEM char *)gz) + gz_off);
// grid stride looping // grid stride looping
for (ga_size index = GID_0 * LDIM_0 + LID_0; for (ga_size index = GID_0 * LDIM_0 + LID_0;
index < nthreads; index += LDIM_0 * GDIM_0) { index < nthreads; index += LDIM_0 * GDIM_0) {
...@@ -70,11 +70,11 @@ KERNEL void max_pool3d_grad_grad_kernel(const ga_size nthreads, ...@@ -70,11 +70,11 @@ KERNEL void max_pool3d_grad_grad_kernel(const ga_size nthreads,
const ga_size pd = (index / pooled_width / pooled_height) % pooled_depth; const ga_size pd = (index / pooled_width / pooled_height) % pooled_depth;
const ga_size c = (index / pooled_width / pooled_height / pooled_depth) % channels; const ga_size c = (index / pooled_width / pooled_height / pooled_depth) % channels;
const ga_size n = (index / pooled_width / pooled_height / pooled_depth / channels); const ga_size n = (index / pooled_width / pooled_height / pooled_depth / channels);
ga_int dstart = static_cast<ga_int>(pd*stride_d) - static_cast<ga_int>(pad_d); ga_int dstart = (ga_int)(pd*stride_d) - (ga_int)(pad_d);
const ga_size dend = min(dstart + kernel_d, depth); const ga_size dend = min(dstart + kernel_d, depth);
ga_int hstart = static_cast<ga_int>(ph*stride_h) - static_cast<ga_int>(pad_h); ga_int hstart = (ga_int)(ph*stride_h) - (ga_int)(pad_h);
const ga_size hend = min(hstart + kernel_h, height); const ga_size hend = min(hstart + kernel_h, height);
ga_int wstart = static_cast<ga_int>(pw*stride_w) - static_cast<ga_int>(pad_w); ga_int wstart = (ga_int)(pw*stride_w) - (ga_int)(pad_w);
const ga_size wend = min(wstart + kernel_w, width); const ga_size wend = min(wstart + kernel_w, width);
dstart = max(dstart, 0); dstart = max(dstart, 0);
hstart = max(hstart, 0); hstart = max(hstart, 0);
......
...@@ -10,10 +10,10 @@ KERNEL void max_pool2d_grad_kernel(const ga_size nthreads, ...@@ -10,10 +10,10 @@ KERNEL void max_pool2d_grad_kernel(const ga_size nthreads,
const ga_size kernel_h, const ga_size kernel_w, const ga_size stride_h, const ga_size stride_w, const ga_size kernel_h, const ga_size kernel_w, const ga_size stride_h, const ga_size stride_w,
const ga_size pad_h, const ga_size pad_w, GLOBAL_MEM DTYPE_OUTPUT_0 *gx, const ga_size gx_off) const ga_size pad_h, const ga_size pad_w, GLOBAL_MEM DTYPE_OUTPUT_0 *gx, const ga_size gx_off)
{ {
x = (GLOBAL_MEM const DTYPE_INPUT_0 *)(((char *)x) + x_off); x = (GLOBAL_MEM const DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)x) + x_off);
z = (GLOBAL_MEM const DTYPE_INPUT_1 *)(((char *)z) + z_off); z = (GLOBAL_MEM const DTYPE_INPUT_1 *)(((GLOBAL_MEM char *)z) + z_off);
gz = (GLOBAL_MEM const DTYPE_INPUT_2 *)(((char *)gz) + gz_off); gz = (GLOBAL_MEM const DTYPE_INPUT_2 *)(((GLOBAL_MEM char *)gz) + gz_off);
gx = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((char *)gx) + gx_off); gx = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((GLOBAL_MEM char *)gx) + gx_off);
// grid stride looping // grid stride looping
for (ga_size index = GID_0 * LDIM_0 + LID_0; for (ga_size index = GID_0 * LDIM_0 + LID_0;
index < nthreads; index += LDIM_0 * GDIM_0) { index < nthreads; index += LDIM_0 * GDIM_0) {
...@@ -55,10 +55,10 @@ KERNEL void max_pool3d_grad_kernel(const ga_size nthreads, ...@@ -55,10 +55,10 @@ KERNEL void max_pool3d_grad_kernel(const ga_size nthreads,
const ga_size pad_d, const ga_size pad_h, const ga_size pad_w, const ga_size pad_d, const ga_size pad_h, const ga_size pad_w,
GLOBAL_MEM DTYPE_OUTPUT_0 *gx, const ga_size gx_off) GLOBAL_MEM DTYPE_OUTPUT_0 *gx, const ga_size gx_off)
{ {
x = (GLOBAL_MEM const DTYPE_INPUT_0 *)(((char *)x) + x_off); x = (GLOBAL_MEM const DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)x) + x_off);
z = (GLOBAL_MEM const DTYPE_INPUT_1 *)(((char *)z) + z_off); z = (GLOBAL_MEM const DTYPE_INPUT_1 *)(((GLOBAL_MEM char *)z) + z_off);
gz = (GLOBAL_MEM const DTYPE_INPUT_2 *)(((char *)gz) + gz_off); gz = (GLOBAL_MEM const DTYPE_INPUT_2 *)(((GLOBAL_MEM char *)gz) + gz_off);
gx = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((char *)gx) + gx_off); gx = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((GLOBAL_MEM char *)gx) + gx_off);
// grid stride looping // grid stride looping
for (ga_size index = GID_0 * LDIM_0 + LID_0; for (ga_size index = GID_0 * LDIM_0 + LID_0;
index < nthreads; index += LDIM_0 * GDIM_0) { index < nthreads; index += LDIM_0 * GDIM_0) {
......
...@@ -12,9 +12,9 @@ KERNEL void max_pool2d_rop_kernel(const ga_size nthreads, ...@@ -12,9 +12,9 @@ KERNEL void max_pool2d_rop_kernel(const ga_size nthreads,
const ga_size pad_h, const ga_size pad_w, const ga_size pad_h, const ga_size pad_w,
GLOBAL_MEM DTYPE_OUTPUT_0 *z, const ga_size z_off) GLOBAL_MEM DTYPE_OUTPUT_0 *z, const ga_size z_off)
{ {
x = (GLOBAL_MEM DTYPE_INPUT_0 *x)(((char *)x) + x_off); x = (GLOBAL_MEM DTYPE_INPUT_0 *x)(((GLOBAL_MEM char *)x) + x_off);
ex = (GLOBAL_MEM DTYPE_INPUT_1 *x)(((char *)ex) + ex_off); ex = (GLOBAL_MEM DTYPE_INPUT_1 *x)(((GLOBAL_MEM char *)ex) + ex_off);
z = (GLOBAL_MEM DTYPE_OUTPUT_0 *x)(((char *)z) + z_off); z = (GLOBAL_MEM DTYPE_OUTPUT_0 *x)(((GLOBAL_MEM char *)z) + z_off);
// grid stride looping // grid stride looping
for (ga_size index = GID_0 * LDIM_0 + LID_0; for (ga_size index = GID_0 * LDIM_0 + LID_0;
index < nthreads; index < nthreads;
...@@ -62,9 +62,9 @@ KERNEL void max_pool3d_rop_kernel(const ga_size nthreads, ...@@ -62,9 +62,9 @@ KERNEL void max_pool3d_rop_kernel(const ga_size nthreads,
const ga_size pad_d, const ga_size pad_h, const ga_size pad_w, const ga_size pad_d, const ga_size pad_h, const ga_size pad_w,
GLOBAL_MEM DTYPE_OUTPUT_0 *z, const ga_size x_off) GLOBAL_MEM DTYPE_OUTPUT_0 *z, const ga_size x_off)
{ {
x = (GLOBAL_MEM DTYPE_INPUT_0 *x)(((char *)x) + x_off); x = (GLOBAL_MEM DTYPE_INPUT_0 *x)(((GLOBAL_MEM char *)x) + x_off);
ex = (GLOBAL_MEM DTYPE_INPUT_1 *x)(((char *)ex) + ex_off); ex = (GLOBAL_MEM DTYPE_INPUT_1 *x)(((GLOBAL_MEM char *)ex) + ex_off);
z = (GLOBAL_MEM DTYPE_OUTPUT_0 *x)(((char *)z) + z_off); z = (GLOBAL_MEM DTYPE_OUTPUT_0 *x)(((GLOBAL_MEM char *)z) + z_off);
// grid stride looping // grid stride looping
for (ga_size index = GID_0 * LDIM_0 + LID_0; for (ga_size index = GID_0 * LDIM_0 + LID_0;
index < nthreads; index < nthreads;
......
...@@ -81,8 +81,8 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base): ...@@ -81,8 +81,8 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base):
const ga_uint Nsamples, const ga_uint Nsamples,
const ga_uint Nstreams_used) const ga_uint Nstreams_used)
{ {
sample_data = (GLOBAL_MEM %(otype)s *)(((char *)sample_data) + sample_offset); sample_data = (GLOBAL_MEM %(otype)s *)(((GLOBAL_MEM char *)sample_data) + sample_offset);
state_data = (GLOBAL_MEM ga_int *)(((char *)state_data) + state_offset); state_data = (GLOBAL_MEM ga_int *)(((GLOBAL_MEM char *)state_data) + state_offset);
/* /*
* The cluda backend makes sure that ga_int corresponds to * The cluda backend makes sure that ga_int corresponds to
* a 32 bit signed type on the target device. It is not a * a 32 bit signed type on the target device. It is not a
...@@ -288,7 +288,7 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base): ...@@ -288,7 +288,7 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base):
""" % locals() """ % locals()
def c_code_cache_version(self): def c_code_cache_version(self):
return (13,) return (14,)
@register_opt2([mrg_uniform], 'fast_compile') @register_opt2([mrg_uniform], 'fast_compile')
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
*/ */
KERNEL void eye(GLOBAL_MEM DTYPE_OUTPUT_0 *a, ga_size a_off, ga_size n, ga_size m) { KERNEL void eye(GLOBAL_MEM DTYPE_OUTPUT_0 *a, ga_size a_off, ga_size n, ga_size m) {
a = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((char *)a) + a_off); a = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((GLOBAL_MEM char *)a) + a_off);
ga_size nb = n < m ? n : m; ga_size nb = n < m ? n : m;
for (ga_size i = LID_0; i < nb; i += LDIM_0) { for (ga_size i = LID_0; i < nb; i += LDIM_0) {
a[i*m + i] = 1; a[i*m + i] = 1;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论