提交 e964d5c9 authored 作者: Frederic's avatar Frederic

more pep8

上级 6fcc4739
...@@ -48,8 +48,10 @@ class GpuDot22(GpuOp): ...@@ -48,8 +48,10 @@ class GpuDot22(GpuOp):
%(fail)s; %(fail)s;
} }
if ((NULL == %(z)s) if ((NULL == %(z)s)
|| (CudaNdarray_HOST_DIMS(%(z)s)[0] != CudaNdarray_HOST_DIMS(%(x)s)[0]) || (CudaNdarray_HOST_DIMS(%(z)s)[0] !=
|| (CudaNdarray_HOST_DIMS(%(z)s)[1] != CudaNdarray_HOST_DIMS(%(y)s)[1])) CudaNdarray_HOST_DIMS(%(x)s)[0])
|| (CudaNdarray_HOST_DIMS(%(z)s)[1] !=
CudaNdarray_HOST_DIMS(%(y)s)[1]))
{ {
//if (%(z)s) Py_DECREF(%(z)s); //if (%(z)s) Py_DECREF(%(z)s);
Py_XDECREF(%(z)s); Py_XDECREF(%(z)s);
...@@ -57,7 +59,8 @@ class GpuDot22(GpuOp): ...@@ -57,7 +59,8 @@ class GpuDot22(GpuOp):
dims[0] = CudaNdarray_HOST_DIMS(%(x)s)[0]; dims[0] = CudaNdarray_HOST_DIMS(%(x)s)[0];
dims[1] = CudaNdarray_HOST_DIMS(%(y)s)[1]; dims[1] = CudaNdarray_HOST_DIMS(%(y)s)[1];
%(z)s = (CudaNdarray*)CudaNdarray_New(); %(z)s = (CudaNdarray*)CudaNdarray_New();
if ((NULL == %(z)s) || CudaNdarray_alloc_contiguous(%(z)s, 2, dims)) if ((NULL == %(z)s) ||
CudaNdarray_alloc_contiguous(%(z)s, 2, dims))
{ {
if (%(z)s) if (%(z)s)
{ {
...@@ -977,10 +980,14 @@ class GpuDownsampleFactorMaxGrad(GpuOp): ...@@ -977,10 +980,14 @@ class GpuDownsampleFactorMaxGrad(GpuOp):
%(fail)s; %(fail)s;
} }
if ((NULL == %(gx)s) if ((NULL == %(gx)s)
|| (CudaNdarray_HOST_DIMS(%(gx)s)[0] != CudaNdarray_HOST_DIMS(%(x)s)[0]) || (CudaNdarray_HOST_DIMS(%(gx)s)[0] !=
|| (CudaNdarray_HOST_DIMS(%(gx)s)[1] != CudaNdarray_HOST_DIMS(%(x)s)[1]) CudaNdarray_HOST_DIMS(%(x)s)[0])
|| (CudaNdarray_HOST_DIMS(%(gx)s)[2] != CudaNdarray_HOST_DIMS(%(x)s)[2]) || (CudaNdarray_HOST_DIMS(%(gx)s)[1] !=
|| (CudaNdarray_HOST_DIMS(%(gx)s)[3] != CudaNdarray_HOST_DIMS(%(x)s)[3])) CudaNdarray_HOST_DIMS(%(x)s)[1])
|| (CudaNdarray_HOST_DIMS(%(gx)s)[2] !=
CudaNdarray_HOST_DIMS(%(x)s)[2])
|| (CudaNdarray_HOST_DIMS(%(gx)s)[3] !=
CudaNdarray_HOST_DIMS(%(x)s)[3]))
{ {
Py_XDECREF(%(gx)s); Py_XDECREF(%(gx)s);
%(gx)s = (CudaNdarray*)CudaNdarray_New(); %(gx)s = (CudaNdarray*)CudaNdarray_New();
...@@ -995,7 +1002,8 @@ class GpuDownsampleFactorMaxGrad(GpuOp): ...@@ -995,7 +1002,8 @@ class GpuDownsampleFactorMaxGrad(GpuOp):
} }
{ {
//TODO: supporting more output columns than threads //TODO: supporting more output columns than threads
// make sure we cover every x row when ignore border isset and there's a border present to be ignored // make sure we cover every x row when ignore border isset and
// there's a border present to be ignored
int needs_extra_z_col = %(ignore_border)s && (CudaNdarray_HOST_DIMS(%(x)s)[2] %% %(ds0)s); int needs_extra_z_col = %(ignore_border)s && (CudaNdarray_HOST_DIMS(%(x)s)[2] %% %(ds0)s);
dim3 grid(CudaNdarray_HOST_DIMS(%(z)s)[0],CudaNdarray_HOST_DIMS(%(z)s)[2] + (needs_extra_z_col ? 1 : 0)); dim3 grid(CudaNdarray_HOST_DIMS(%(z)s)[0],CudaNdarray_HOST_DIMS(%(z)s)[2] + (needs_extra_z_col ? 1 : 0));
dim3 block(std::min(CudaNdarray_HOST_DIMS(%(x)s)[3], 512)); dim3 block(std::min(CudaNdarray_HOST_DIMS(%(x)s)[3], 512));
...@@ -1053,7 +1061,8 @@ class GpuDownsampleFactorMaxGrad(GpuOp): ...@@ -1053,7 +1061,8 @@ class GpuDownsampleFactorMaxGrad(GpuOp):
ignore_border = int(self.ignore_border) ignore_border = int(self.ignore_border)
return """ return """
template<int ds0, int ds1> // ds0 is the downsampling factor in rows, ds1 in columns // ds0 is the downsampling factor in rows, ds1 in columns
template<int ds0, int ds1>
__global__ void kDownsampleMaxGrad_%(nodename)s( __global__ void kDownsampleMaxGrad_%(nodename)s(
int D0, int D1, int D2, int D3, int xD2, int xD3, int D0, int D1, int D2, int D3, int xD2, int xD3,
const float * x, int xS0, int xS1, int xS2, int xS3, const float * x, int xS0, int xS1, int xS2, int xS3,
...@@ -1072,18 +1081,24 @@ class GpuDownsampleFactorMaxGrad(GpuOp): ...@@ -1072,18 +1081,24 @@ class GpuDownsampleFactorMaxGrad(GpuOp):
float cur_max, cur_x, my_z, my_gz; float cur_max, cur_x, my_z, my_gz;
int i0 = blockIdx.x; // image row int i0 = blockIdx.x; // image row
int i1 = 0; // image col int i1 = 0; // image col
int i2 = blockIdx.y; // row wrt z and/or gz, ranges from 0 to D2 - 1 OR D2 (as needed to cover all x rows) // row wrt z and/or gz, ranges from 0 to D2 - 1 OR D2
// (as needed to cover all x rows)
int i2 = blockIdx.y;
int x_col = threadIdx.x; // col wrt x, ranges from 0 to xD3 - 1 int x_col = threadIdx.x; // col wrt x, ranges from 0 to xD3 - 1
int z_col = x_col/ds1; // z_col corresponding to this x_col int z_col = x_col/ds1; // z_col corresponding to this x_col
//TODO: raise occupancy. Use threadIdx.y to run several iterations of this i1 loop //TODO: raise occupancy. Use threadIdx.y to run several
//in parallel // iterations of this i1 loop in parallel
for (i1 = 0; i1 < D1; ++i1) // loop over images (same for z and x) for (i1 = 0; i1 < D1; ++i1) // loop over images (same for z and x)
{ {
for(int col_iter = 0; col_iter * blockDim.x <= xD3 ; col_iter++){ for(int col_iter = 0;
//The if inside is to don't do the division if we need only 1 col_iter col_iter * blockDim.x <= xD3 ; col_iter++){
//The if inside is to don't do the division if we
// need only 1 col_iter
if(blockDim.x != xD3) if(blockDim.x != xD3)
{ {
x_col = threadIdx.x + col_iter * blockDim.x; x_col = threadIdx.x + col_iter * blockDim.x;
...@@ -1092,32 +1107,42 @@ class GpuDownsampleFactorMaxGrad(GpuOp): ...@@ -1092,32 +1107,42 @@ class GpuDownsampleFactorMaxGrad(GpuOp):
if (%(ignore_border)s && x_col >= ds1 * D3) if (%(ignore_border)s && x_col >= ds1 * D3)
{ {
// This happens only if x_col was ignored (via ignore_border) // This happens only if x_col was ignored
// TODO: if ignore_border is False, this is impossible and we don't even // (via ignore_border)
// need to generate this code. // TODO: if ignore_border is False, this is impossible
// and we don't even need to generate this code.
my_gz = 0.0f; my_gz = 0.0f;
//any fp number suffices for my_z, so we don't even need to set it to
//anything in particular. //any fp number suffices for my_z, so we don't even
//need to set it to anything in particular.
} }
else else
{ {
// this is effectively: // this is effectively:
// my_gz = gz[image_row][image_col][z_row][z_col] // my_gz = gz[image_row][image_col][z_row][z_col]
// my_z = z[image_row][image_col][z_row][z_col] // my_z = z[image_row][image_col][z_row][z_col]
my_gz = gz[i0 * gzS0 + i1 * gzS1 + i2 * gzS2 + z_col*gzS3]; my_gz = gz[i0 * gzS0 + i1 * gzS1 + i2 * gzS2 +
my_z = z[i0 * zS0 + i1 * zS1 + i2 * zS2 + z_col* zS3]; z_col*gzS3];
my_z = z[i0 * zS0 + i1 * zS1 + i2 * zS2 +
z_col* zS3];
} }
if(x_col<xD3){ if(x_col<xD3){
for (int x_row = i2*ds0; (x_row < i2*ds0+ds0) && (x_row < xD2); ++x_row) for (int x_row = i2*ds0;
(x_row < i2*ds0+ds0) && (x_row < xD2); ++x_row)
{ {
// this is effectively: // this is effectively:
// gx[image_row][image_col][x_row][x_col] // gx[image_row][image_col][x_row][x_col]
// = (my_z == x[image_row][image_col][x_row][x_col]) ? my_gz : 0.0f; // = (my_z == x[image_row][image_col][
gx[i0 * D1*xD2*xD3 + i1*xD2*xD3 + x_row*xD3 + x_col] // x_row][x_col]) ? my_gz : 0.0f;
= (my_z == x[i0*xS0 + i1*xS1 + x_row*xS2 + x_col*xS3]) ? my_gz : 0.0f; gx[i0 * D1*xD2*xD3 + i1*xD2*xD3 +
x_row*xD3 + x_col]
= (my_z == x[i0*xS0 + i1*xS1 + x_row*xS2 +
x_col*xS3]) ? my_gz : 0.0f;
} }
//gx[i0 * D1*xD2*xD3 + i1*xD2*xD3 + x_row*xD3 + x_col] = -999; //gx[i0 * D1*xD2*xD3 + i1*xD2*xD3 +
x_row*xD3 + x_col] = -999;
} }
} }
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论