Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
9cedf22c
提交
9cedf22c
authored
11月 08, 2016
作者:
Alexander Matyasko
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add gpu max and average pooling gradient
上级
ea45d835
隐藏空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
487 行增加
和
0 行删除
+487
-0
blas.py
theano/gpuarray/blas.py
+97
-0
pool_ave_grad.c
theano/gpuarray/pool_ave_grad.c
+203
-0
pool_max_grad.c
theano/gpuarray/pool_max_grad.c
+187
-0
没有找到文件。
theano/gpuarray/blas.py
浏览文件 @
9cedf22c
...
@@ -1602,6 +1602,103 @@ class GpuPool(CGpuKernelBase):
...
@@ -1602,6 +1602,103 @@ class GpuPool(CGpuKernelBase):
(
'SUM_MODE'
,
sum_mode
)]
(
'SUM_MODE'
,
sum_mode
)]
class
GpuMaxPoolGrad
(
CGpuKernelBase
):
"""
Implement the grad of max pooling on the gpu.
"""
__props__
=
(
'ignore_border'
,
'mode'
,
'ndim'
)
def
__init__
(
self
,
ignore_border
,
mode
=
'max'
,
ndim
=
2
):
self
.
ndim
=
ndim
self
.
ignore_border
=
ignore_border
self
.
mode
=
mode
CGpuKernelBase
.
__init__
(
self
,
[
'pool_max_grad.c'
],
'APPLY_SPECIFIC(max_pool_grad)'
)
assert
mode
==
'max'
assert
ndim
in
[
2
,
3
]
def
c_headers
(
self
):
return
[
'gpuarray_api.h'
,
'gpuarray_helper.h'
,
'numpy_compat.h'
]
def
c_header_dirs
(
self
):
return
[
os
.
path
.
dirname
(
__file__
),
pygpu
.
get_include
()]
def
make_node
(
self
,
inp
,
out
,
out_grad
,
ws
,
stride
,
pad
):
ctx_name
=
infer_context_name
(
inp
,
out
,
out_grad
)
inp
=
as_gpuarray_variable
(
inp
,
ctx_name
)
assert
(
inp
.
ndim
in
[
4
,
5
])
out
=
as_gpuarray_variable
(
out
,
ctx_name
)
assert
(
out
.
ndim
in
[
4
,
5
])
out_grad
=
as_gpuarray_variable
(
out_grad
,
ctx_name
)
assert
(
out_grad
.
ndim
in
[
4
,
5
])
assert
(
out_grad
.
ndim
==
inp
.
ndim
)
assert
(
inp
.
ndim
==
out
.
ndim
)
ws
=
as_tensor_variable
(
ws
)
stride
=
as_tensor_variable
(
stride
)
pad
=
as_tensor_variable
(
pad
)
assert
ws
.
type
.
ndim
==
stride
.
type
.
ndim
and
ws
.
type
.
ndim
==
pad
.
type
.
ndim
assert
ws
.
type
.
ndim
==
1
return
Apply
(
self
,
[
inp
,
out
,
out_grad
,
ws
,
stride
,
pad
],
[
inp
.
type
()])
def
get_params
(
self
,
node
):
return
node
.
inputs
[
0
]
.
type
.
context
class
GpuAveragePoolGrad
(
CGpuKernelBase
):
"""
Implement the grad of average pooling on the gpu.
"""
__props__
=
(
'ignore_border'
,
'mode'
,
'ndim'
)
def
__init__
(
self
,
ignore_border
,
mode
=
'max'
,
ndim
=
2
):
self
.
ndim
=
ndim
self
.
ignore_border
=
ignore_border
if
mode
==
'average'
:
mode
=
'average_inc_pad'
self
.
mode
=
mode
CGpuKernelBase
.
__init__
(
self
,
[
'pool_ave_grad.c'
],
'APPLY_SPECIFIC(ave_pool_grad)'
)
assert
mode
in
(
'sum'
,
'average_inc_pad'
,
'average_exc_pad'
)
assert
ndim
in
[
2
,
3
]
def
c_headers
(
self
):
return
[
'gpuarray_api.h'
,
'gpuarray_helper.h'
,
'numpy_compat.h'
]
def
c_header_dirs
(
self
):
return
[
os
.
path
.
dirname
(
__file__
),
pygpu
.
get_include
()]
def
make_node
(
self
,
inp
,
out_grad
,
ws
,
stride
,
pad
):
ctx_name
=
infer_context_name
(
inp
,
out_grad
)
inp
=
as_gpuarray_variable
(
inp
,
ctx_name
)
assert
(
inp
.
ndim
in
[
4
,
5
])
out_grad
=
as_gpuarray_variable
(
out_grad
,
ctx_name
)
assert
(
out_grad
.
ndim
in
[
4
,
5
])
assert
(
out_grad
.
ndim
==
inp
.
ndim
)
ws
=
as_tensor_variable
(
ws
)
stride
=
as_tensor_variable
(
stride
)
pad
=
as_tensor_variable
(
pad
)
assert
ws
.
type
.
ndim
==
stride
.
type
.
ndim
and
ws
.
type
.
ndim
==
pad
.
type
.
ndim
assert
ws
.
type
.
ndim
==
1
return
Apply
(
self
,
[
inp
,
out_grad
,
ws
,
stride
,
pad
],
[
inp
.
type
()])
def
get_params
(
self
,
node
):
return
node
.
inputs
[
0
]
.
type
.
context
def
get_op_params
(
self
):
inc_pad
=
int
(
self
.
mode
==
'average_inc_pad'
)
sum_mode
=
int
(
self
.
mode
==
'sum'
)
return
[(
'INC_PAD'
,
inc_pad
),
(
'SUM_MODE'
,
sum_mode
)]
class
GpuDownsampleFactorMaxGradGrad
(
CGpuKernelBase
):
class
GpuDownsampleFactorMaxGradGrad
(
CGpuKernelBase
):
"""
"""
Implement the grad of downsample with max on the gpu.
Implement the grad of downsample with max on the gpu.
...
...
theano/gpuarray/pool_ave_grad.c
0 → 100644
浏览文件 @
9cedf22c
#section kernels
#kernel ave_pool2d_grad_kernel : size, size, size, size, size, size, size, *, *, size, size, size, size, size, size, size, size, * :
// (adopted from Caffe: https://github.com/BVLC/caffe/blob/master/src/caffe/layers/pooling_layer.cu)
KERNEL
void
ave_pool2d_grad_kernel
(
const
ga_size
nthreads
,
const
ga_size
num
,
const
ga_size
channels
,
const
ga_size
height
,
const
ga_size
width
,
const
ga_size
pooled_height
,
const
ga_size
pooled_width
,
GLOBAL_MEM
const
DTYPE_i0
*
x
,
GLOBAL_MEM
const
DTYPE_i1
*
gz
,
const
ga_size
kernel_h
,
const
ga_size
kernel_w
,
const
ga_size
stride_h
,
const
ga_size
stride_w
,
const
ga_size
pad_h
,
const
ga_size
pad_w
,
const
ga_bool
inc_pad
,
const
ga_bool
sum_mode
,
GLOBAL_MEM
DTYPE_o0
*
gx
)
{
// grid stride looping
for
(
ga_size
index
=
GID_0
*
LDIM_0
+
LID_0
;
index
<
nthreads
;
index
+=
LDIM_0
*
GDIM_0
)
{
const
ga_size
w
=
index
%
width
;
const
ga_size
h
=
(
index
/
width
)
%
height
;
const
ga_size
c
=
(
index
/
width
/
height
)
%
channels
;
const
ga_size
n
=
(
index
/
width
/
height
/
channels
);
const
ga_size
phstart
=
(
h
+
pad_h
<
kernel_h
)
?
0
:
(
h
+
pad_h
-
kernel_h
)
/
stride_h
+
1
;
const
ga_size
phend
=
min
((
h
+
pad_h
)
/
stride_h
+
1
,
pooled_height
);
const
ga_size
pwstart
=
(
w
+
pad_w
<
kernel_w
)
?
0
:
(
w
+
pad_w
-
kernel_w
)
/
stride_w
+
1
;
const
ga_size
pwend
=
min
((
w
+
pad_w
)
/
stride_w
+
1
,
pooled_width
);
const
ga_size
offset
=
(
n
*
channels
+
c
)
*
pooled_height
*
pooled_width
;
const
DTYPE_i1
*
gz_slice
=
gz
+
offset
;
DTYPE_o0
collector
=
0
;
for
(
ga_size
ph
=
phstart
;
ph
<
phend
;
++
ph
)
{
for
(
ga_size
pw
=
pwstart
;
pw
<
pwend
;
++
pw
)
{
if
(
sum_mode
)
{
collector
+=
gz
[
ph
*
pooled_width
+
pw
];
}
else
{
// figure out the pooling size
const
ga_size
hstart
=
ph
*
stride_h
-
pad_h
;
const
ga_size
wstart
=
pw
*
stride_w
-
pad_w
;
const
ga_size
hend
=
min
(
hstart
+
kernel_h
,
height
+
pad_h
);
const
ga_size
wend
=
min
(
wstart
+
kernel_w
,
width
+
pad_w
);
const
ga_size
pool_size
=
(
hend
-
hstart
)
*
(
wend
-
wstart
);
collector
+=
gz_slice
[
ph
*
pooled_width
+
pw
]
/
pool_size
;
}
}
}
gx
[
index
]
=
collector
;
}
}
#kernel ave_pool3d_grad_kernel : size, size, size, size, size, size, size, size, size, *, *, size, size, size, size, size, size, size, size, size, size, size, * :
// (adopted from Caffe: https://github.com/BVLC/caffe/blob/master/src/caffe/layers/pooling_layer.cu)
KERNEL
void
ave_pool3d_grad_kernel
(
const
ga_size
nthreads
,
const
ga_size
num
,
const
ga_size
channels
,
const
ga_size
depth
,
const
ga_size
height
,
const
ga_size
width
,
const
ga_size
pooled_depth
,
const
ga_size
pooled_height
,
const
ga_size
pooled_width
,
GLOBAL_MEM
const
DTYPE_i0
*
x
,
GLOBAL_MEM
const
DTYPE_i1
*
gz
,
const
ga_size
kernel_d
,
const
ga_size
kernel_h
,
const
ga_size
kernel_w
,
const
ga_size
stride_d
,
const
ga_size
stride_h
,
const
ga_size
stride_w
,
const
ga_size
pad_d
,
const
ga_size
pad_h
,
const
ga_size
pad_w
,
const
ga_bool
inc_pad
,
const
ga_bool
sum_mode
,
GLOBAL_MEM
DTYPE_o0
*
gx
)
{
// grid stride looping
for
(
ga_size
index
=
GID_0
*
LDIM_0
+
LID_0
;
index
<
nthreads
;
index
+=
LDIM_0
*
GDIM_0
)
{
const
ga_size
w
=
index
%
width
;
const
ga_size
h
=
(
index
/
width
)
%
height
;
const
ga_size
d
=
(
index
/
width
/
height
)
%
depth
;
const
ga_size
c
=
(
index
/
width
/
height
/
depth
)
%
channels
;
const
ga_size
n
=
(
index
/
width
/
height
/
depth
/
channels
);
const
ga_size
pdstart
=
(
d
+
pad_d
<
kernel_d
)
?
0
:
(
d
+
pad_d
-
kernel_d
)
/
stride_d
+
1
;
const
ga_size
pdend
=
min
((
d
+
pad_d
)
/
stride_d
+
1
,
pooled_depth
);
const
ga_size
phstart
=
(
h
+
pad_h
<
kernel_h
)
?
0
:
(
h
+
pad_h
-
kernel_h
)
/
stride_h
+
1
;
const
ga_size
phend
=
min
((
h
+
pad_h
)
/
stride_h
+
1
,
pooled_height
);
const
ga_size
pwstart
=
(
w
+
pad_w
<
kernel_w
)
?
0
:
(
w
+
pad_w
-
kernel_w
)
/
stride_w
+
1
;
const
ga_size
pwend
=
min
((
w
+
pad_w
)
/
stride_w
+
1
,
pooled_width
);
const
ga_size
offset
=
(
n
*
channels
+
c
)
*
pooled_depth
*
pooled_height
*
pooled_width
;
const
DTYPE_i1
*
gz_slice
=
gz
+
offset
;
DTYPE_o0
collector
=
0
;
for
(
ga_size
pd
=
pdstart
;
pd
<
pdend
;
++
pd
)
{
for
(
ga_size
ph
=
phstart
;
ph
<
phend
;
++
ph
)
{
for
(
ga_size
pw
=
pwstart
;
pw
<
pwend
;
++
pw
)
{
if
(
sum_mode
)
{
collector
+=
gz
[
ph
*
pooled_width
+
pw
];
}
else
{
// figure out the pooling size
const
ga_size
dstart
=
pd
*
stride_d
-
pad_d
;
const
ga_size
hstart
=
ph
*
stride_h
-
pad_h
;
const
ga_size
wstart
=
pw
*
stride_w
-
pad_w
;
const
ga_size
dend
=
min
(
dstart
+
kernel_h
,
depth
+
pad_d
);
const
ga_size
hend
=
min
(
hstart
+
kernel_h
,
height
+
pad_h
);
const
ga_size
wend
=
min
(
wstart
+
kernel_w
,
width
+
pad_w
);
const
ga_size
pool_size
=
(
dend
-
dstart
)
*
(
hend
-
hstart
)
*
(
wend
-
wstart
);
collector
+=
gz
[
ph
*
pooled_width
+
pw
]
/
pool_size
;
}
}
}
}
gx
[
index
]
=
collector
;
}
}
#section support_code
// CUDA: number of blocks for threads.
inline
int
GET_BLOCKS
(
const
int
nkernels
,
const
int
nthreads
)
{
return
(
nkernels
+
nthreads
-
1
)
/
nthreads
;
}
#section support_code_struct
int
APPLY_SPECIFIC
(
ave_pool_grad
)(
PyGpuArrayObject
*
x
,
PyGpuArrayObject
*
gz
,
PyArrayObject
*
ws
,
PyArrayObject
*
stride
,
PyArrayObject
*
pad
,
PyGpuArrayObject
**
gx
,
PyGpuContextObject
*
ctx
)
{
if
(
!
GpuArray_IS_C_CONTIGUOUS
(
&
x
->
ga
)
||
!
GpuArray_IS_C_CONTIGUOUS
(
&
gz
->
ga
))
{
PyErr_Format
(
PyExc_ValueError
,
"GpuMaxPoolGrad: requires data to be C-contiguous"
);
return
1
;
}
size_t
ndims
=
PyArray_DIM
(
ws
,
0
);
if
(
PyGpuArray_NDIM
(
x
)
!=
ndims
+
2
||
PyGpuArray_NDIM
(
gz
)
!=
ndims
+
2
)
{
PyErr_SetString
(
PyExc_ValueError
,
"GpuMaxPoolGrad: rank error"
);
return
1
;
}
if
(
theano_prep_output
(
gx
,
PyGpuArray_NDIM
(
x
),
PyGpuArray_DIMS
(
x
),
x
->
ga
.
typecode
,
GA_C_ORDER
,
ctx
)
!=
0
)
{
PyErr_SetString
(
PyExc_RuntimeError
,
"GpuMaxPoolGrad: failed to allocate memory"
);
return
1
;
}
{
// scope for running kernel
size_t
w
[
3
];
size_t
s
[
3
];
size_t
p
[
3
];
for
(
int
i
=
0
;
i
<
ndims
;
i
++
)
{
w
[
i
]
=
*
((
npy_intp
*
)
PyArray_GETPTR1
(
ws
,
i
));
s
[
i
]
=
*
((
npy_intp
*
)
PyArray_GETPTR1
(
stride
,
i
));
p
[
i
]
=
*
((
npy_intp
*
)
PyArray_GETPTR1
(
pad
,
i
));
}
size_t
max_threads_dim
;
int
err
;
const
size_t
*
z_dims
=
PyGpuArray_DIMS
(
gz
);
const
size_t
*
x_dims
=
PyGpuArray_DIMS
(
x
);
// Get the max threads per blocks
err
=
gpucontext_property
(
ctx
->
ctx
,
GA_CTX_PROP_MAXLSIZE0
,
&
max_threads_dim
);
if
(
err
!=
GA_NO_ERROR
){
PyErr_SetString
(
PyExc_RuntimeError
,
"Could not fetch max_threads_dims"
);
return
1
;
}
size_t
threads_per_block
=
max_threads_dim
;
if
(
ndims
==
2
)
{
size_t
num_kernels
=
x_dims
[
0
]
*
x_dims
[
1
]
*
x_dims
[
2
]
*
x_dims
[
3
];
size_t
n_blocks
=
GET_BLOCKS
(
num_kernels
,
threads_per_block
);
err
=
ave_pool2d_grad_kernel_call
(
1
,
&
n_blocks
,
&
threads_per_block
,
0
,
num_kernels
,
x_dims
[
0
],
x_dims
[
1
],
x_dims
[
2
],
x_dims
[
3
],
z_dims
[
2
],
z_dims
[
3
],
x
->
ga
.
data
,
gz
->
ga
.
data
,
w
[
0
],
w
[
1
],
s
[
0
],
s
[
1
],
p
[
0
],
p
[
1
],
INC_PAD
,
SUM_MODE
,
(
*
gx
)
->
ga
.
data
);
if
(
err
!=
GA_NO_ERROR
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"GpuAveragePoolGrad: ave_pool2d_grad_kernel %s."
,
GpuKernel_error
(
&
k_ave_pool2d_grad_kernel
,
err
));
return
1
;
}
}
else
if
(
ndims
==
3
)
{
size_t
num_kernels
=
x_dims
[
0
]
*
x_dims
[
1
]
*
x_dims
[
2
]
*
x_dims
[
3
]
*
x_dims
[
4
];
size_t
n_blocks
=
GET_BLOCKS
(
num_kernels
,
threads_per_block
);
err
=
ave_pool3d_grad_kernel_call
(
1
,
&
n_blocks
,
&
threads_per_block
,
0
,
num_kernels
,
x_dims
[
0
],
x_dims
[
1
],
x_dims
[
2
],
x_dims
[
3
],
x_dims
[
4
],
z_dims
[
2
],
z_dims
[
3
],
z_dims
[
4
],
x
->
ga
.
data
,
gz
->
ga
.
data
,
w
[
0
],
w
[
1
],
w
[
2
],
s
[
0
],
s
[
1
],
s
[
2
],
p
[
0
],
p
[
1
],
p
[
2
],
INC_PAD
,
SUM_MODE
,
(
*
gx
)
->
ga
.
data
);
if
(
err
!=
GA_NO_ERROR
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"GpuAveragePoolGrad: ave_pool3d_grad_kernel %s."
,
GpuKernel_error
(
&
k_ave_pool3d_grad_kernel
,
err
));
return
1
;
}
}
}
return
0
;
}
theano/gpuarray/pool_max_grad.c
0 → 100644
浏览文件 @
9cedf22c
#section kernels
#kernel max_pool2d_grad_kernel : size, size, size, size, size, size, size, *, *, *, size, size, size, size, size, size, * :
// (borrowed from Caffe: https://github.com/BVLC/caffe/blob/master/src/caffe/layers/pooling_layer.cu)
KERNEL
void
max_pool2d_grad_kernel
(
const
ga_size
nthreads
,
const
ga_size
num
,
const
ga_size
channels
,
const
ga_size
height
,
const
ga_size
width
,
const
ga_size
pooled_height
,
const
ga_size
pooled_width
,
GLOBAL_MEM
const
DTYPE_i0
*
x
,
GLOBAL_MEM
const
DTYPE_i1
*
z
,
GLOBAL_MEM
const
DTYPE_i2
*
gz
,
const
ga_size
kernel_h
,
const
ga_size
kernel_w
,
const
ga_size
stride_h
,
const
ga_size
stride_w
,
const
ga_size
pad_h
,
const
ga_size
pad_w
,
GLOBAL_MEM
DTYPE_o0
*
gx
)
{
// grid stride looping
for
(
ga_size
index
=
GID_0
*
LDIM_0
+
LID_0
;
index
<
nthreads
;
index
+=
LDIM_0
*
GDIM_0
)
{
const
ga_size
w
=
index
%
width
;
const
ga_size
h
=
(
index
/
width
)
%
height
;
const
ga_size
c
=
(
index
/
width
/
height
)
%
channels
;
const
ga_size
n
=
(
index
/
width
/
height
/
channels
);
const
ga_size
phstart
=
(
h
+
pad_h
<
kernel_h
)
?
0
:
(
h
+
pad_h
-
kernel_h
)
/
stride_h
+
1
;
const
ga_size
phend
=
min
((
h
+
pad_h
)
/
stride_h
+
1
,
pooled_height
);
const
ga_size
pwstart
=
(
w
+
pad_w
<
kernel_w
)
?
0
:
(
w
+
pad_w
-
kernel_w
)
/
stride_w
+
1
;
const
ga_size
pwend
=
min
((
w
+
pad_w
)
/
stride_w
+
1
,
pooled_width
);
const
ga_size
offset
=
(
n
*
channels
+
c
)
*
pooled_height
*
pooled_width
;
const
DTYPE_i1
*
z_slice
=
z
+
offset
;
const
DTYPE_i2
*
gz_slice
=
gz
+
offset
;
DTYPE_o0
gradient
=
0
;
for
(
ga_size
ph
=
phstart
;
ph
<
phend
;
++
ph
)
{
for
(
ga_size
pw
=
pwstart
;
pw
<
pwend
;
++
pw
)
{
if
(
x
[
index
]
==
z_slice
[
ph
*
pooled_width
+
pw
])
{
gradient
+=
gz_slice
[
ph
*
pooled_width
+
pw
];
}
}
}
gx
[
index
]
=
gradient
;
}
}
#kernel max_pool3d_grad_kernel : size, size, size, size, size, size, size, size, size, *, *, *, size, size, size, size, size, size, size, size, size, * :
// (adopted from Caffe: https://github.com/BVLC/caffe/blob/master/src/caffe/layers/pooling_layer.cu)
KERNEL
void
max_pool3d_grad_kernel
(
const
ga_size
nthreads
,
const
ga_size
num
,
const
ga_size
channels
,
const
ga_size
depth
,
const
ga_size
height
,
const
ga_size
width
,
const
ga_size
pooled_depth
,
const
ga_size
pooled_height
,
const
ga_size
pooled_width
,
GLOBAL_MEM
const
DTYPE_i0
*
x
,
GLOBAL_MEM
const
DTYPE_i1
*
z
,
GLOBAL_MEM
const
DTYPE_i2
*
gz
,
const
ga_size
kernel_d
,
const
ga_size
kernel_h
,
const
ga_size
kernel_w
,
const
ga_size
stride_d
,
const
ga_size
stride_h
,
const
ga_size
stride_w
,
const
ga_size
pad_d
,
const
ga_size
pad_h
,
const
ga_size
pad_w
,
GLOBAL_MEM
DTYPE_o0
*
gx
)
{
// grid stride looping
for
(
ga_size
index
=
GID_0
*
LDIM_0
+
LID_0
;
index
<
nthreads
;
index
+=
LDIM_0
*
GDIM_0
)
{
const
ga_size
w
=
index
%
width
;
const
ga_size
h
=
(
index
/
width
)
%
height
;
const
ga_size
d
=
(
index
/
width
/
height
)
%
depth
;
const
ga_size
c
=
(
index
/
width
/
height
/
depth
)
%
channels
;
const
ga_size
n
=
(
index
/
width
/
height
/
depth
/
channels
);
const
ga_size
pdstart
=
(
d
+
pad_d
<
kernel_d
)
?
0
:
(
d
+
pad_d
-
kernel_d
)
/
stride_d
+
1
;
const
ga_size
pdend
=
min
((
d
+
pad_d
)
/
stride_d
+
1
,
pooled_depth
);
const
ga_size
phstart
=
(
h
+
pad_h
<
kernel_h
)
?
0
:
(
h
+
pad_h
-
kernel_h
)
/
stride_h
+
1
;
const
ga_size
phend
=
min
((
h
+
pad_h
)
/
stride_h
+
1
,
pooled_height
);
const
ga_size
pwstart
=
(
w
+
pad_w
<
kernel_w
)
?
0
:
(
w
+
pad_w
-
kernel_w
)
/
stride_w
+
1
;
const
ga_size
pwend
=
min
((
w
+
pad_w
)
/
stride_w
+
1
,
pooled_width
);
const
ga_size
offset
=
(
n
*
channels
+
c
)
*
pooled_depth
*
pooled_height
*
pooled_width
;
const
DTYPE_i1
*
z_slice
=
z
+
offset
;
const
DTYPE_i2
*
gz_slice
=
gz
+
offset
;
DTYPE_o0
gradient
=
0
;
for
(
ga_size
pd
=
pdstart
;
pd
<
pdend
;
++
pd
)
{
for
(
ga_size
ph
=
phstart
;
ph
<
phend
;
++
ph
)
{
for
(
ga_size
pw
=
pwstart
;
pw
<
pwend
;
++
pw
)
{
if
(
x
[
index
]
==
z_slice
[(
pd
*
pooled_height
+
ph
)
*
pooled_width
+
pw
])
{
gradient
+=
gz_slice
[(
pd
*
pooled_height
+
ph
)
*
pooled_width
+
pw
];
}
}
}
}
gx
[
index
]
=
gradient
;
}
}
#section support_code
// CUDA: number of blocks for threads.
inline
int
GET_BLOCKS
(
const
int
nkernels
,
const
int
nthreads
)
{
return
(
nkernels
+
nthreads
-
1
)
/
nthreads
;
}
#section support_code_struct
int
APPLY_SPECIFIC
(
max_pool_grad
)(
PyGpuArrayObject
*
x
,
PyGpuArrayObject
*
z
,
PyGpuArrayObject
*
gz
,
PyArrayObject
*
ws
,
PyArrayObject
*
stride
,
PyArrayObject
*
pad
,
PyGpuArrayObject
**
gx
,
PyGpuContextObject
*
ctx
)
{
if
(
!
GpuArray_IS_C_CONTIGUOUS
(
&
x
->
ga
)
||
!
GpuArray_IS_C_CONTIGUOUS
(
&
z
->
ga
)
||
!
GpuArray_IS_C_CONTIGUOUS
(
&
gz
->
ga
))
{
PyErr_Format
(
PyExc_ValueError
,
"GpuMaxPoolGrad: requires data to be C-contiguous"
);
return
1
;
}
size_t
ndims
=
PyArray_DIM
(
ws
,
0
);
if
(
PyGpuArray_NDIM
(
x
)
!=
ndims
+
2
||
PyGpuArray_NDIM
(
z
)
!=
ndims
+
2
||
PyGpuArray_NDIM
(
gz
)
!=
ndims
+
2
)
{
PyErr_SetString
(
PyExc_ValueError
,
"GpuMaxPoolGrad: rank error"
);
return
1
;
}
if
(
theano_prep_output
(
gx
,
PyGpuArray_NDIM
(
x
),
PyGpuArray_DIMS
(
x
),
x
->
ga
.
typecode
,
GA_C_ORDER
,
ctx
)
!=
0
)
{
PyErr_SetString
(
PyExc_RuntimeError
,
"GpuMaxPoolGrad: failed to allocate memory"
);
return
1
;
}
{
// scope for running kernel
size_t
w
[
3
];
size_t
s
[
3
];
size_t
p
[
3
];
for
(
int
i
=
0
;
i
<
ndims
;
i
++
)
{
w
[
i
]
=
*
((
npy_intp
*
)
PyArray_GETPTR1
(
ws
,
i
));
s
[
i
]
=
*
((
npy_intp
*
)
PyArray_GETPTR1
(
stride
,
i
));
p
[
i
]
=
*
((
npy_intp
*
)
PyArray_GETPTR1
(
pad
,
i
));
}
size_t
max_threads_dim
;
int
err
;
const
size_t
*
z_dims
=
PyGpuArray_DIMS
(
z
);
const
size_t
*
x_dims
=
PyGpuArray_DIMS
(
x
);
// Get the max threads per blocks
err
=
gpucontext_property
(
ctx
->
ctx
,
GA_CTX_PROP_MAXLSIZE0
,
&
max_threads_dim
);
if
(
err
!=
GA_NO_ERROR
){
PyErr_SetString
(
PyExc_RuntimeError
,
"Could not fetch max_threads_dims"
);
return
1
;
}
size_t
threads_per_block
=
max_threads_dim
;
if
(
ndims
==
2
)
{
size_t
num_kernels
=
x_dims
[
0
]
*
x_dims
[
1
]
*
x_dims
[
2
]
*
x_dims
[
3
];
size_t
n_blocks
=
GET_BLOCKS
(
num_kernels
,
threads_per_block
);
err
=
max_pool2d_grad_kernel_call
(
1
,
&
n_blocks
,
&
threads_per_block
,
0
,
num_kernels
,
x_dims
[
0
],
x_dims
[
1
],
x_dims
[
2
],
x_dims
[
3
],
z_dims
[
2
],
z_dims
[
3
],
x
->
ga
.
data
,
z
->
ga
.
data
,
gz
->
ga
.
data
,
w
[
0
],
w
[
1
],
s
[
0
],
s
[
1
],
p
[
0
],
p
[
1
],
(
*
gx
)
->
ga
.
data
);
if
(
err
!=
GA_NO_ERROR
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"GpuMaxPoolGrad: max_pool2d_grad_kernel %s."
,
GpuKernel_error
(
&
k_max_pool2d_grad_kernel
,
err
));
return
1
;
}
}
else
if
(
ndims
==
3
)
{
size_t
num_kernels
=
x_dims
[
0
]
*
x_dims
[
1
]
*
x_dims
[
2
]
*
x_dims
[
3
]
*
x_dims
[
4
];
size_t
n_blocks
=
GET_BLOCKS
(
num_kernels
,
threads_per_block
);
err
=
max_pool3d_grad_kernel_call
(
1
,
&
n_blocks
,
&
threads_per_block
,
0
,
num_kernels
,
x_dims
[
0
],
x_dims
[
1
],
x_dims
[
2
],
x_dims
[
3
],
x_dims
[
4
],
z_dims
[
2
],
z_dims
[
3
],
z_dims
[
4
],
x
->
ga
.
data
,
z
->
ga
.
data
,
gz
->
ga
.
data
,
w
[
0
],
w
[
1
],
w
[
2
],
s
[
0
],
s
[
1
],
s
[
2
],
p
[
0
],
p
[
1
],
p
[
2
],
(
*
gx
)
->
ga
.
data
);
if
(
err
!=
GA_NO_ERROR
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"GpuMaxPoolGrad: max_pool3d_grad_kernel %s."
,
GpuKernel_error
(
&
k_max_pool3d_grad_kernel
,
err
));
return
1
;
}
}
}
return
0
;
}
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论