提交 413f1b91 authored 作者: xiaoqie's avatar xiaoqie

Fix parameter types in ave_pool kernels, remove static_cast, add GLOBAL_MEM

上级 7509fa75
......@@ -1632,7 +1632,7 @@ class GpuEye(GpuKernelBase, Op):
code = """
KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size a_off,
ga_size n, ga_size m, ga_ssize k) {
a = (GLOBAL_MEM %(ctype)s *)(((char *)a) + a_off);
a = (GLOBAL_MEM %(ctype)s *)(((GLOBAL_MEM char *)a) + a_off);
ga_ssize coff = max(k, (ga_ssize) 0);
ga_ssize roff = -min(k, (ga_ssize) 0);
ga_size nb = (ga_size) min(n - roff, m - coff);
......
......@@ -47,8 +47,8 @@ KERNEL void dilated_im3d2col_kernel(const ga_size n,
const ga_size height_col, const ga_size width_col, const ga_size depth_col,
GLOBAL_MEM DTYPE_INPUT_0 * data_col,
const ga_size offset_col) {
data_im = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)data_im) + offset_im);
data_col = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)data_col) + offset_col);
data_im = (GLOBAL_MEM DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)data_im) + offset_im);
data_col = (GLOBAL_MEM DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)data_col) + offset_col);
// grid stride looping
for (ga_size index = GID_0 * LDIM_0 + LID_0;
index < (n); index += LDIM_0 * GDIM_0) {
......@@ -100,8 +100,8 @@ KERNEL void im3d2col_kernel(const ga_size n,
const ga_size height_col, const ga_size width_col, const ga_size depth_col,
GLOBAL_MEM DTYPE_INPUT_0 * data_col,
const ga_size offset_col) {
data_im = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)data_im) + offset_im);
data_col = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)data_col) + offset_col);
data_im = (GLOBAL_MEM DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)data_im) + offset_im);
data_col = (GLOBAL_MEM DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)data_col) + offset_col);
// grid stride looping
for (ga_size index = GID_0 * LDIM_0 + LID_0;
index < (n); index += LDIM_0 * GDIM_0) {
......@@ -115,10 +115,10 @@ KERNEL void im3d2col_kernel(const ga_size n,
const ga_size h_offset = h_col * stride_h - pad_h;
const ga_size w_offset = w_col * stride_w - pad_w;
const ga_size d_offset = d_col * stride_d - pad_d;
DTYPE_INPUT_0 * data_col_ptr = data_col;
GLOBAL_MEM DTYPE_INPUT_0 * data_col_ptr = data_col;
data_col_ptr += c_col * (height_col * width_col * depth_col) +
h_col * (width_col * depth_col) + w_col * depth_col + d_col;
const DTYPE_INPUT_0 * data_im_ptr = data_im + data_im_offset;
GLOBAL_MEM const DTYPE_INPUT_0 * data_im_ptr = data_im + data_im_offset;
data_im_ptr += c_im * (height * width * depth) +
h_offset * (width * depth) + w_offset * depth + d_offset;
for (ga_size i = 0; i < kernel_h; ++i) {
......@@ -154,8 +154,8 @@ KERNEL void dilated_col2im3d_kernel(const ga_size n,
const ga_size data_im_offset) {
// offset_im is the pointer offset for data_im.
// data_im_offset is an offset of elements in the array
data_im = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)data_im) + offset_im);
data_col = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)data_col) + offset_col);
data_im = (GLOBAL_MEM DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)data_im) + offset_im);
data_col = (GLOBAL_MEM DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)data_col) + offset_col);
// grid stride looping
for (ga_size index = GID_0 * LDIM_0 + LID_0;
index < (n); index += LDIM_0 * GDIM_0) {
......@@ -222,8 +222,8 @@ KERNEL void col2im3d_kernel(const ga_size n,
const ga_size data_im_offset) {
// offset_im is the pointer offset for data_im.
// data_im_offset is an offset of elements in the array
data_im = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)data_im) + offset_im);
data_col = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)data_col) + offset_col);
data_im = (GLOBAL_MEM DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)data_im) + offset_im);
data_col = (GLOBAL_MEM DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)data_col) + offset_col);
// grid stride looping
for (ga_size index = GID_0 * LDIM_0 + LID_0;
index < (n); index += LDIM_0 * GDIM_0) {
......
......@@ -47,8 +47,8 @@ KERNEL void dilated_im2col_kernel(const ga_size n,
const ga_size height_col, const ga_size width_col,
GLOBAL_MEM DTYPE_INPUT_0 * data_col,
const ga_size offset_col) {
data_im = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)data_im) + offset_im);
data_col = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)data_col) + offset_col);
data_im = (GLOBAL_MEM DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)data_im) + offset_im);
data_col = (GLOBAL_MEM DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)data_col) + offset_col);
// grid stride looping
for (ga_size index = GID_0 * LDIM_0 + LID_0;
index < (n); index += LDIM_0 * GDIM_0) {
......@@ -91,8 +91,8 @@ KERNEL void im2col_kernel(const ga_size n,
const ga_size height_col, const ga_size width_col,
GLOBAL_MEM DTYPE_INPUT_0 * data_col,
const ga_size offset_col) {
data_im = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)data_im) + offset_im);
data_col = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)data_col) + offset_col);
data_im = (GLOBAL_MEM DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)data_im) + offset_im);
data_col = (GLOBAL_MEM DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)data_col) + offset_col);
// grid stride looping
for (ga_size index = GID_0 * LDIM_0 + LID_0;
index < (n); index += LDIM_0 * GDIM_0) {
......@@ -135,8 +135,8 @@ KERNEL void dilated_col2im_kernel(const ga_size n,
const ga_size data_im_offset) {
// offset_im is the pointer offset for data_im.
// data_im_offset is an offset of elements in the array
data_col = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)data_col) + offset_col);
data_im = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)data_im) + offset_im);
data_col = (GLOBAL_MEM DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)data_col) + offset_col);
data_im = (GLOBAL_MEM DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)data_im) + offset_im);
// grid stride looping
for (ga_size index = GID_0 * LDIM_0 + LID_0;
index < (n); index += LDIM_0 * GDIM_0) {
......@@ -185,8 +185,8 @@ KERNEL void col2im_kernel(const ga_size n,
const ga_size data_im_offset) {
// offset_im is the pointer offset for data_im.
// data_im_offset is an offset of elements in the array
data_col = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)data_col) + offset_col);
data_im = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)data_im) + offset_im);
data_col = (GLOBAL_MEM DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)data_col) + offset_col);
data_im = (GLOBAL_MEM DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)data_im) + offset_im);
// grid stride looping
for (ga_size index = GID_0 * LDIM_0 + LID_0;
index < (n); index += LDIM_0 * GDIM_0) {
......
......@@ -84,9 +84,9 @@ KERNEL void k_multi_warp_multinomial(
const ga_ssize outs_col_stride
)
{
global_pvals = (GLOBAL_MEM %(in_ctype)s *)(((char *)global_pvals) + global_pvals_offset);
global_unis = (GLOBAL_MEM %(in_ctype)s *)(((char *)global_unis) + global_unis_offset);
global_outs = (GLOBAL_MEM %(out_ctype)s *)(((char *)global_outs) + global_outs_offset);
global_pvals = (GLOBAL_MEM %(in_ctype)s *)(((GLOBAL_MEM char *)global_pvals) + global_pvals_offset);
global_unis = (GLOBAL_MEM %(in_ctype)s *)(((GLOBAL_MEM char *)global_unis) + global_unis_offset);
global_outs = (GLOBAL_MEM %(out_ctype)s *)(((GLOBAL_MEM char *)global_outs) + global_outs_offset);
// each thread takes care of one multinomial draw
int n = LDIM_0*GID_0 + LID_0;
if (n < nb_multi)
......@@ -297,9 +297,9 @@ KERNEL void k_multi_warp_multinomial_wor(
const ga_ssize outs_col_stride
)
{
global_pvals_copy = (GLOBAL_MEM float *)(((char *)global_pvals_copy) + global_pvals_offset);
global_unis = (GLOBAL_MEM float *)(((char *)global_unis) + global_unis_offset);
global_outs = (GLOBAL_MEM ga_long *)(((char *)global_outs) + global_outs_offset);
global_pvals_copy = (GLOBAL_MEM float *)(((GLOBAL_MEM char *)global_pvals_copy) + global_pvals_offset);
global_unis = (GLOBAL_MEM float *)(((GLOBAL_MEM char *)global_unis) + global_unis_offset);
global_outs = (GLOBAL_MEM ga_long *)(((GLOBAL_MEM char *)global_outs) + global_outs_offset);
// each thread takes care of one multinomial-wor n_samples-draw
int n = LDIM_0*GID_0 + LID_0;
......
......@@ -10,8 +10,8 @@ KERNEL void max_pool2d_kernel(const ga_size nthreads,
const ga_size stride_h, const ga_size stride_w, const ga_size pad_h, const ga_size pad_w,
GLOBAL_MEM DTYPE_OUTPUT_0 *z, const ga_size z_off)
{
x = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)x) + x_off);
z = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((char *)z) + z_off);
x = (GLOBAL_MEM DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)x) + x_off);
z = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((GLOBAL_MEM char *)z) + z_off);
// grid stride looping
for (ga_size index = GID_0 * LDIM_0 + LID_0;
index < nthreads;
......@@ -55,8 +55,8 @@ KERNEL void max_pool3d_kernel(const ga_size nthreads,
const ga_size stride_w, const ga_size pad_d, const ga_size pad_h, const ga_size pad_w,
GLOBAL_MEM DTYPE_OUTPUT_0 *z, const ga_size z_off)
{
x = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)x) + x_off);
z = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((char *)z) + z_off);
x = (GLOBAL_MEM DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)x) + x_off);
z = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((GLOBAL_MEM char *)z) + z_off);
// grid stride looping
for (ga_size index = GID_0 * LDIM_0 + LID_0;
index < nthreads;
......@@ -94,7 +94,7 @@ KERNEL void max_pool3d_kernel(const ga_size nthreads,
}
}
#kernel ave_pool2d_kernel : size, size, size, size, size, size, size, *, size, size, size, size, size, size, size, size, size, *, size:
#kernel ave_pool2d_kernel : size, size, size, size, size, size, size, *, size, size, size, size, size, size, size, bool, bool, *, size:
// (adopted from Caffe: https://github.com/BVLC/caffe/blob/master/src/caffe/layers/pooling_layer.cu)
KERNEL void ave_pool2d_kernel(const ga_size nthreads,
......@@ -105,8 +105,8 @@ KERNEL void ave_pool2d_kernel(const ga_size nthreads,
const ga_bool inc_pad, const ga_bool sum_mode,
GLOBAL_MEM DTYPE_OUTPUT_0 *z, const ga_size z_off)
{
x = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)x) + x_off);
z = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((char *)z) + z_off);
x = (GLOBAL_MEM DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)x) + x_off);
z = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((GLOBAL_MEM char *)z) + z_off);
// grid stride looping
for (ga_size index = GID_0 * LDIM_0 + LID_0;
index < nthreads;
......@@ -149,7 +149,7 @@ KERNEL void ave_pool2d_kernel(const ga_size nthreads,
}
}
#kernel ave_pool3d_kernel : size, size, size, size, size, size, size, size, size, *, size, size, size, size, size, size, size, size, size, size, size, size, *, size :
#kernel ave_pool3d_kernel : size, size, size, size, size, size, size, size, size, *, size, size, size, size, size, size, size, size, size, size, bool, bool, *, size :
// (adopted from Caffe: https://github.com/BVLC/caffe/blob/master/src/caffe/layers/pooling_layer.cu)
KERNEL void ave_pool3d_kernel(const ga_size nthreads,
......@@ -163,8 +163,8 @@ KERNEL void ave_pool3d_kernel(const ga_size nthreads,
GLOBAL_MEM DTYPE_OUTPUT_0 *z, const ga_size z_off)
{
// grid stride looping
x = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)x) + x_off);
z = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((char *)z) + z_off);
x = (GLOBAL_MEM DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)x) + x_off);
z = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((GLOBAL_MEM char *)z) + z_off);
for (ga_size index = GID_0 * LDIM_0 + LID_0;
index < nthreads;
index += LDIM_0 * GDIM_0) {
......
#section kernels
#kernel ave_pool2d_grad_kernel : size, size, size, size, size, size, size, *, size, *, size, size, size, size, size, size, size, size, size, *, size :
#kernel ave_pool2d_grad_kernel : size, size, size, size, size, size, size, *, size, *, size, size, size, size, size, size, size, bool, bool, *, size :
// (adopted from Caffe: https://github.com/BVLC/caffe/blob/master/src/caffe/layers/pooling_layer.cu)
KERNEL void ave_pool2d_grad_kernel(const ga_size nthreads,
......@@ -11,9 +11,9 @@ KERNEL void ave_pool2d_grad_kernel(const ga_size nthreads,
const ga_size pad_h, const ga_size pad_w, const ga_bool inc_pad, const ga_bool sum_mode,
GLOBAL_MEM DTYPE_OUTPUT_0 *gx, const ga_size gx_off)
{
x = (GLOBAL_MEM const DTYPE_INPUT_0 *)(((char *)x) + x_off);
gz = (GLOBAL_MEM const DTYPE_INPUT_1 *)(((char *)gz) + gz_off);
gx = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((char *)gx) + gx_off);
x = (GLOBAL_MEM const DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)x) + x_off);
gz = (GLOBAL_MEM const DTYPE_INPUT_1 *)(((GLOBAL_MEM char *)gz) + gz_off);
gx = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((GLOBAL_MEM char *)gx) + gx_off);
// grid stride looping
for (ga_size index = GID_0 * LDIM_0 + LID_0;
index < nthreads; index += LDIM_0 * GDIM_0) {
......@@ -49,7 +49,7 @@ KERNEL void ave_pool2d_grad_kernel(const ga_size nthreads,
}
}
#kernel ave_pool3d_grad_kernel : size, size, size, size, size, size, size, size, size, *, size, *, size, size, size, size, size, size, size, size, size, size, size, size, *, size :
#kernel ave_pool3d_grad_kernel : size, size, size, size, size, size, size, size, size, *, size, *, size, size, size, size, size, size, size, size, size, size, bool, bool, *, size :
// (adopted from Caffe: https://github.com/BVLC/caffe/blob/master/src/caffe/layers/pooling_layer.cu)
KERNEL void ave_pool3d_grad_kernel(const ga_size nthreads,
......@@ -62,9 +62,9 @@ KERNEL void ave_pool3d_grad_kernel(const ga_size nthreads,
const ga_size pad_d, const ga_size pad_h, const ga_size pad_w,
const ga_bool inc_pad, const ga_bool sum_mode, GLOBAL_MEM DTYPE_OUTPUT_0 *gx, const ga_size gx_off)
{
x = (GLOBAL_MEM const DTYPE_INPUT_0 *)(((char *)x) + x_off);
gz = (GLOBAL_MEM const DTYPE_INPUT_1 *)(((char *)gz) + gz_off);
gx = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((char *)gx) + gx_off);
x = (GLOBAL_MEM const DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)x) + x_off);
gz = (GLOBAL_MEM const DTYPE_INPUT_1 *)(((GLOBAL_MEM char *)gz) + gz_off);
gx = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((GLOBAL_MEM char *)gx) + gx_off);
// grid stride looping
for (ga_size index = GID_0 * LDIM_0 + LID_0;
index < nthreads; index += LDIM_0 * GDIM_0) {
......
......@@ -10,10 +10,10 @@ KERNEL void max_pool2d_grad_grad_kernel(const ga_size nthreads,
const ga_size pad_h, const ga_size pad_w,
GLOBAL_MEM DTYPE_OUTPUT_0 *gz, const ga_size gz_off)
{
x = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)x) + x_off);
z = (GLOBAL_MEM DTYPE_INPUT_1 *)(((char *)z) + z_off);
gx = (GLOBAL_MEM DTYPE_INPUT_2 *)(((char *)gx) + gx_off);
gz = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((char *)gz) + gz_off);
x = (GLOBAL_MEM DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)x) + x_off);
z = (GLOBAL_MEM DTYPE_INPUT_1 *)(((GLOBAL_MEM char *)z) + z_off);
gx = (GLOBAL_MEM DTYPE_INPUT_2 *)(((GLOBAL_MEM char *)gx) + gx_off);
gz = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((GLOBAL_MEM char *)gz) + gz_off);
// grid stride looping
for (ga_size index = GID_0 * LDIM_0 + LID_0;
index < nthreads; index += LDIM_0 * GDIM_0) {
......@@ -21,9 +21,9 @@ KERNEL void max_pool2d_grad_grad_kernel(const ga_size nthreads,
const ga_size ph = (index / pooled_width) % pooled_height;
const ga_size c = (index / pooled_width / pooled_height) % channels;
const ga_size n = (index / pooled_width / pooled_height / channels);
ga_int hstart = static_cast<ga_int>(ph*stride_h) - static_cast<ga_int>(pad_h);
ga_int hstart = (ga_int)(ph*stride_h) - (ga_int)(pad_h);
const ga_size hend = min(hstart + kernel_h, height);
ga_int wstart = static_cast<ga_int>(pw*stride_w) - static_cast<ga_int>(pad_w);
ga_int wstart = (ga_int)(pw*stride_w) - (ga_int)(pad_w);
const ga_size wend = min(wstart + kernel_w, width);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
......@@ -58,10 +58,10 @@ KERNEL void max_pool3d_grad_grad_kernel(const ga_size nthreads,
const ga_size pad_d, const ga_size pad_h, const ga_size pad_w,
GLOBAL_MEM DTYPE_OUTPUT_0 *gz, const ga_size gz_off)
{
x = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)x) + x_off);
z = (GLOBAL_MEM DTYPE_INPUT_1 *)(((char *)z) + z_off);
gx = (GLOBAL_MEM DTYPE_INPUT_2 *)(((char *)gx) + gx_off);
gz = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((char *)gz) + gz_off);
x = (GLOBAL_MEM DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)x) + x_off);
z = (GLOBAL_MEM DTYPE_INPUT_1 *)(((GLOBAL_MEM char *)z) + z_off);
gx = (GLOBAL_MEM DTYPE_INPUT_2 *)(((GLOBAL_MEM char *)gx) + gx_off);
gz = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((GLOBAL_MEM char *)gz) + gz_off);
// grid stride looping
for (ga_size index = GID_0 * LDIM_0 + LID_0;
index < nthreads; index += LDIM_0 * GDIM_0) {
......@@ -70,11 +70,11 @@ KERNEL void max_pool3d_grad_grad_kernel(const ga_size nthreads,
const ga_size pd = (index / pooled_width / pooled_height) % pooled_depth;
const ga_size c = (index / pooled_width / pooled_height / pooled_depth) % channels;
const ga_size n = (index / pooled_width / pooled_height / pooled_depth / channels);
ga_int dstart = static_cast<ga_int>(pd*stride_d) - static_cast<ga_int>(pad_d);
ga_int dstart = (ga_int)(pd*stride_d) - (ga_int)(pad_d);
const ga_size dend = min(dstart + kernel_d, depth);
ga_int hstart = static_cast<ga_int>(ph*stride_h) - static_cast<ga_int>(pad_h);
ga_int hstart = (ga_int)(ph*stride_h) - (ga_int)(pad_h);
const ga_size hend = min(hstart + kernel_h, height);
ga_int wstart = static_cast<ga_int>(pw*stride_w) - static_cast<ga_int>(pad_w);
ga_int wstart = (ga_int)(pw*stride_w) - (ga_int)(pad_w);
const ga_size wend = min(wstart + kernel_w, width);
dstart = max(dstart, 0);
hstart = max(hstart, 0);
......
......@@ -10,10 +10,10 @@ KERNEL void max_pool2d_grad_kernel(const ga_size nthreads,
const ga_size kernel_h, const ga_size kernel_w, const ga_size stride_h, const ga_size stride_w,
const ga_size pad_h, const ga_size pad_w, GLOBAL_MEM DTYPE_OUTPUT_0 *gx, const ga_size gx_off)
{
x = (GLOBAL_MEM const DTYPE_INPUT_0 *)(((char *)x) + x_off);
z = (GLOBAL_MEM const DTYPE_INPUT_1 *)(((char *)z) + z_off);
gz = (GLOBAL_MEM const DTYPE_INPUT_2 *)(((char *)gz) + gz_off);
gx = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((char *)gx) + gx_off);
x = (GLOBAL_MEM const DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)x) + x_off);
z = (GLOBAL_MEM const DTYPE_INPUT_1 *)(((GLOBAL_MEM char *)z) + z_off);
gz = (GLOBAL_MEM const DTYPE_INPUT_2 *)(((GLOBAL_MEM char *)gz) + gz_off);
gx = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((GLOBAL_MEM char *)gx) + gx_off);
// grid stride looping
for (ga_size index = GID_0 * LDIM_0 + LID_0;
index < nthreads; index += LDIM_0 * GDIM_0) {
......@@ -55,10 +55,10 @@ KERNEL void max_pool3d_grad_kernel(const ga_size nthreads,
const ga_size pad_d, const ga_size pad_h, const ga_size pad_w,
GLOBAL_MEM DTYPE_OUTPUT_0 *gx, const ga_size gx_off)
{
x = (GLOBAL_MEM const DTYPE_INPUT_0 *)(((char *)x) + x_off);
z = (GLOBAL_MEM const DTYPE_INPUT_1 *)(((char *)z) + z_off);
gz = (GLOBAL_MEM const DTYPE_INPUT_2 *)(((char *)gz) + gz_off);
gx = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((char *)gx) + gx_off);
x = (GLOBAL_MEM const DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)x) + x_off);
z = (GLOBAL_MEM const DTYPE_INPUT_1 *)(((GLOBAL_MEM char *)z) + z_off);
gz = (GLOBAL_MEM const DTYPE_INPUT_2 *)(((GLOBAL_MEM char *)gz) + gz_off);
gx = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((GLOBAL_MEM char *)gx) + gx_off);
// grid stride looping
for (ga_size index = GID_0 * LDIM_0 + LID_0;
index < nthreads; index += LDIM_0 * GDIM_0) {
......
......@@ -12,9 +12,9 @@ KERNEL void max_pool2d_rop_kernel(const ga_size nthreads,
const ga_size pad_h, const ga_size pad_w,
GLOBAL_MEM DTYPE_OUTPUT_0 *z, const ga_size z_off)
{
x = (GLOBAL_MEM DTYPE_INPUT_0 *x)(((char *)x) + x_off);
ex = (GLOBAL_MEM DTYPE_INPUT_1 *x)(((char *)ex) + ex_off);
z = (GLOBAL_MEM DTYPE_OUTPUT_0 *x)(((char *)z) + z_off);
x = (GLOBAL_MEM DTYPE_INPUT_0 *x)(((GLOBAL_MEM char *)x) + x_off);
ex = (GLOBAL_MEM DTYPE_INPUT_1 *x)(((GLOBAL_MEM char *)ex) + ex_off);
z = (GLOBAL_MEM DTYPE_OUTPUT_0 *x)(((GLOBAL_MEM char *)z) + z_off);
// grid stride looping
for (ga_size index = GID_0 * LDIM_0 + LID_0;
index < nthreads;
......@@ -62,9 +62,9 @@ KERNEL void max_pool3d_rop_kernel(const ga_size nthreads,
const ga_size pad_d, const ga_size pad_h, const ga_size pad_w,
GLOBAL_MEM DTYPE_OUTPUT_0 *z, const ga_size x_off)
{
x = (GLOBAL_MEM DTYPE_INPUT_0 *x)(((char *)x) + x_off);
ex = (GLOBAL_MEM DTYPE_INPUT_1 *x)(((char *)ex) + ex_off);
z = (GLOBAL_MEM DTYPE_OUTPUT_0 *x)(((char *)z) + z_off);
x = (GLOBAL_MEM DTYPE_INPUT_0 *x)(((GLOBAL_MEM char *)x) + x_off);
ex = (GLOBAL_MEM DTYPE_INPUT_1 *x)(((GLOBAL_MEM char *)ex) + ex_off);
z = (GLOBAL_MEM DTYPE_OUTPUT_0 *x)(((GLOBAL_MEM char *)z) + z_off);
// grid stride looping
for (ga_size index = GID_0 * LDIM_0 + LID_0;
index < nthreads;
......
......@@ -81,8 +81,8 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base):
const ga_uint Nsamples,
const ga_uint Nstreams_used)
{
sample_data = (GLOBAL_MEM %(otype)s *)(((char *)sample_data) + sample_offset);
state_data = (GLOBAL_MEM ga_int *)(((char *)state_data) + state_offset);
sample_data = (GLOBAL_MEM %(otype)s *)(((GLOBAL_MEM char *)sample_data) + sample_offset);
state_data = (GLOBAL_MEM ga_int *)(((GLOBAL_MEM char *)state_data) + state_offset);
/*
* The cluda backend makes sure that ga_int corresponds to
* a 32 bit signed type on the target device. It is not a
......
......@@ -8,7 +8,7 @@
*/
KERNEL void eye(GLOBAL_MEM DTYPE_OUTPUT_0 *a, ga_size a_off, ga_size n, ga_size m) {
a = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((char *)a) + a_off);
a = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((GLOBAL_MEM char *)a) + a_off);
ga_size nb = n < m ? n : m;
for (ga_size i = LID_0; i < nb; i += LDIM_0) {
a[i*m + i] = 1;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论