提交 2d39527c 作者: Pascal Lamblin

Fix C code of GpuDownsampleFactorMaxGrad.

It was causing errors when running with cuda-memcheck, or when running on Mac OS X.
上级 cc4fc100
......@@ -932,7 +932,7 @@ class GpuDownsampleFactorMaxGrad(GpuOp):
return Apply(self, [x, z, gz], [x.type()])
def c_code_cache_version(self):
return (7,)
return (8,)
def c_code(self, node, nodename, inp, out, sub):
x, z, gz = inp
......@@ -1056,6 +1056,7 @@ class GpuDownsampleFactorMaxGrad(GpuOp):
// Cast threadIdx.x into a signed int, to avoid problems with
// indexing with negative offsets.
int tx = threadIdx.x;
int bx = blockDim.x;
for(int i0 = blockIdx.x;
i0 < D0;
......@@ -1075,20 +1076,20 @@ class GpuDownsampleFactorMaxGrad(GpuOp):
for (i1 = 0; i1 < D1; ++i1) // loop over images (same for z and x)
{
for(int col_iter = 0;
col_iter * blockDim.x <= xD3 ; col_iter++){
(tx + col_iter * bx < xD3) ; col_iter++){
// The if inside avoids doing the division when we
// need only 1 col_iter
if(blockDim.x != xD3)
if(tx + bx < xD3)
{
x_col = tx + col_iter * blockDim.x;
x_col = tx + col_iter * bx;
z_col = x_col/ds1;
}
if (%(ignore_border)s && x_col >= ds1 * D3)
if (%(ignore_border)s && ((x_col >= ds1 * D3) || (i2 >= D2)))
{
// This happens only if x_col was ignored
// This happens only if x_col was ignored, or if i2*ds0 was
// (via ignore_border)
// TODO: if ignore_border is False, this is impossible
// and we don't even need to generate this code.
......@@ -1109,18 +1110,16 @@ class GpuDownsampleFactorMaxGrad(GpuOp):
my_z = z[i0 * zS0 + i1 * zS1 + i2 * zS2 +
z_col* zS3];
}
if(x_col<xD3){
for (int x_row = i2*ds0;
(x_row < i2*ds0+ds0) && (x_row < xD2); ++x_row)
{
// this is effectively:
// gx[image_row][image_col][x_row][x_col]
// = (my_z == x[image_row][image_col][
// x_row][x_col]) ? my_gz : 0.0f;
gx[i0*gxS0 + i1*gxS1 + x_row*gxS2 + x_col*gxS3]
= (my_z == x[i0*xS0 + i1*xS1 + x_row*xS2 +
x_col*xS3]) ? my_gz : 0.0f;
}
for (int x_row = i2*ds0;
(x_row < i2*ds0+ds0) && (x_row < xD2); ++x_row)
{
// this is effectively:
// gx[image_row][image_col][x_row][x_col]
// = (my_z == x[image_row][image_col][
// x_row][x_col]) ? my_gz : 0.0f;
gx[i0*gxS0 + i1*gxS1 + x_row*gxS2 + x_col*gxS3]
= (my_z == x[i0*xS0 + i1*xS1 + x_row*xS2 +
x_col*xS3]) ? my_gz : 0.0f;
}
}
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论