testgroup / pytensor · Commits

Commit e5ba1b08, authored Jun 13, 2017 by Frédéric Bastien; committed via GitHub, Jun 13, 2017

Merge pull request #6012 from abergeron/fix_offset

Fix offset problems in the new backend.

Parents: 8dcc5fc6, a762b617
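Every hunk below applies the same fix: a kernel that used to receive only a device pointer now also receives the array's byte offset and rebases the pointer before touching any element, because in the new (libgpuarray) backend `ga.data` points at the base of the underlying buffer while a view's first element sits `ga.offset` bytes further in. Dropping the offset silently reads and writes the wrong elements for any view with a non-zero offset, which is the class of bug this merge fixes. A minimal CPU-side sketch of the pattern, using plain C stand-ins (`size_t` for `ga_size`, `float` for the generated `DTYPE_*` types) rather than the real cluda types:

#include <stdio.h>
#include <stdlib.h>

static void fill_diagonal(float *a, size_t a_off, size_t n, size_t m)
{
    /* Rebase: offsets are in *bytes*, hence the char* round-trip. */
    a = (float *)(((char *)a) + a_off);
    size_t nb = n < m ? n : m;
    for (size_t i = 0; i < nb; i++)
        a[i * m + i] = 1.0f;
}

int main(void)
{
    /* A 4x4 "view" starting two floats into an 18-float buffer. */
    float *buf = calloc(18, sizeof(float));
    fill_diagonal(buf, 2 * sizeof(float), 4, 4);
    printf("first diagonal element lives at buf[2]: %g\n", buf[2]);
    free(buf);
    return 0;
}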
Showing 13 changed files with 173 additions and 173 deletions (+173 −173).
theano/gpuarray/basic_ops.py          +8   -4
theano/gpuarray/blas.py               +2   -2
theano/gpuarray/corr3d_gemm.c         +0   -0
theano/gpuarray/corr_gemm.c           +0   -0
theano/gpuarray/extra_ops.py          +11  -53
theano/gpuarray/multinomial.py        +26  -48
theano/gpuarray/pool.c                +32  -20
theano/gpuarray/pool_ave_grad.c       +19  -10
theano/gpuarray/pool_grad_grad.c      +22  -10
theano/gpuarray/pool_max_grad.c       +22  -10
theano/gpuarray/pool_max_rop.c        +19  -10
theano/gpuarray/rng_mrg.py            +8   -3
theano/gpuarray/tests/tstgpueye.c     +4   -3
theano/gpuarray/basic_ops.py

@@ -1630,7 +1630,9 @@ class GpuEye(GpuKernelBase, Op):
     def gpu_kernels(self, node, name):
         code = """
-KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m, ga_ssize k) {
+KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size a_off,
+                ga_size n, ga_size m, ga_ssize k) {
+    a = (GLOBAL_MEM %(ctype)s *)(((char *)a) + a_off);
     ga_ssize coff = max(k, (ga_ssize) 0);
     ga_ssize roff = -min(k, (ga_ssize) 0);
     ga_size nb = (ga_size) min(n - roff, m - coff);
@@ -1641,7 +1643,8 @@ KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m, ga_ssize k) {
                            name=name, write_a=write_w(self.dtype))
         return [Kernel(code=code, name="eye",
-                       params=[gpuarray.GpuArray, gpuarray.SIZE, gpuarray.SIZE,
-                               gpuarray.SSIZE],
+                       params=[gpuarray.GpuArray, gpuarray.SIZE, gpuarray.SIZE,
+                               gpuarray.SIZE, gpuarray.SSIZE],
                        flags=Kernel.get_flags(self.dtype),
                        objvar='k_eye_' + name)]
@@ -1685,7 +1688,8 @@ KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m, ga_ssize k) {
         col_off = (size_t) (k > 0?k:0);
         row_off = (size_t) (k < 0?-k:0);
         if (row_off < dims[0] && col_off < dims[1]) {
-          err = eye_call(1, &gs, &ls, 0, %(z)s->ga.data, dims[0], dims[1], k);
+          err = eye_call(1, &gs, &ls, 0, %(z)s->ga.data, %(z)s->ga.offset,
+                         dims[0], dims[1], k);
           if (err != GA_NO_ERROR) {
             PyErr_Format(PyExc_RuntimeError,
                          "gpuarray error: kEye: %%s. n%%lu, m=%%lu.",
@@ -1702,4 +1706,4 @@ KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m, ga_ssize k) {
         return s

     def c_code_cache_version(self):
-        return (7,)
+        return (8,)
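Host side, the generated `eye_call` wrapper gains the matching `ga_size` offset argument between the data pointer and the shape parameters. A hedged sketch of the new call shape; `eye_call` below is a stub standing in for the wrapper gpuarray generates from the kernel's `params` list, not the real launcher:

#include <stdio.h>
#include <stddef.h>

/* Stub with the same shape as the generated wrapper; it only logs. */
static int eye_call(unsigned int nd, const size_t *gs, const size_t *ls,
                    size_t shmem, void *data, size_t offset,
                    size_t n, size_t m, ptrdiff_t k)
{
    printf("launch eye: nd=%u grid=%zu block=%zu shmem=%zu "
           "data=%p offset=%zu n=%zu m=%zu k=%td\n",
           nd, gs[0], ls[0], shmem, data, offset, n, m, k);
    return 0;  /* GA_NO_ERROR */
}

int main(void)
{
    float buf[64];
    size_t gs = 256, ls = 1;
    size_t dims[2] = {4, 4};
    /* An offset of 8 bytes: the output view starts two floats in. */
    int err = eye_call(1, &gs, &ls, 0, buf, 8, dims[0], dims[1], 1);
    return err;
}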
theano/gpuarray/blas.py

@@ -552,8 +552,8 @@ class BaseGpuCorrMM(CGpuKernelBase):
         return [os.path.dirname(__file__)]

     def c_code_cache_version(self):
-        # Raise this whenever modifying the code below.
-        return (7,)
+        # Raise this whenever modifying the C code (including the file).
+        return (8,)

     def c_code_helper(self, bottom, weights, top, direction, sub,
                       height=None, width=None):
         """
theano/gpuarray/corr3d_gemm.c

(Diff collapsed.)

theano/gpuarray/corr_gemm.c

(Diff collapsed.)
theano/gpuarray/extra_ops.py

@@ -35,7 +35,7 @@ class GpuCumOp(GpuKernelBase, Op):
         return hash(self.axis) ^ hash(self.mode)

     def c_code_cache_version(self):
-        return (5,)
+        return (6,)

     def c_headers(self):
         return ['<numpy_compat.h>', '<gpuarray/types.h>', '<gpuarray_helper.h>']
@@ -69,11 +69,9 @@ class GpuCumOp(GpuKernelBase, Op):
         code = """
             KERNEL void %(kname)s(float* input, ga_size input_offset,
                                   float* output, ga_size output_offset,
-                                  ga_ssize inputStrides_x,
-                                  ga_ssize inputStrides_y,
-                                  ga_ssize inputStrides_z,
-                                  ga_ssize outputStrides_x, ga_ssize outputStrides_y,
-                                  ga_ssize outputStrides_z, const int offsetY, const int offsetZ,
+                                  ga_ssize inputStrides_x, ga_ssize inputStrides_y, ga_ssize inputStrides_z,
+                                  ga_ssize outputStrides_x, ga_ssize outputStrides_y, ga_ssize outputStrides_z,
+                                  const int offsetY, const int offsetZ,
                                   const int beforeLastElementIdx, const int lastElementIdx){
                 input = (float *)(((char *)input) + input_offset);
                 output = (float *)(((char *)output) + output_offset);
@@ -216,6 +214,7 @@ class GpuCumOp(GpuKernelBase, Op):
                 output = (float *)(((char *)output) + output_offset);
                 blockSum = (float *)(((char *)blockSum) + blockSum_offset);
                 int globalThreadID = (blockIdx.x + 1) * blockDim.x + threadIdx.x;
                 // Check if current has data to process.
@@ -397,23 +396,8 @@ class GpuCumOp(GpuKernelBase, Op):
                 size_t dimGrid[3] = {dimGridX, localDimGridY, localDimGridZ};
                 size_t dimBlock[3] = {dimBlockX, 1, 1};  // One cum op per block.
                 size_t sharedBytes = (2*dimBlockX) * sizeof(float);
-                void* kernel_params[] = {(void*) input->ga.data,
-                                         (void*) &(input->ga.offset),
-                                         (void*) output->ga.data,
-                                         (void*) &(output->ga.offset),
-                                         (void*) &nbElementsPerCumOp,
-                                         (void*) &inputStrides_x,
-                                         (void*) &inputStrides_y,
-                                         (void*) &inputStrides_z,
-                                         (void*) &outputStrides_x,
-                                         (void*) &outputStrides_y,
-                                         (void*) &outputStrides_z,
-                                         (void*) &offsetY,
-                                         (void*) &offsetZ,
-                                         (void*) deviceBlockSum->ga.data,
-                                         (void*) &(deviceBlockSum->ga.offset)};
-                int err = GpuKernel_call(&k_blockCumOp_%(nodename)s, 3, dimGrid, dimBlock, sharedBytes, kernel_params);
+                int err = k_blockCumOp_call(3, dimGrid, dimBlock, sharedBytes, input->ga.data, input->ga.offset, output->ga.data, output->ga.offset, nbElementsPerCumOp, inputStrides_x, inputStrides_y, inputStrides_z, outputStrides_x, outputStrides_y, outputStrides_z, offsetY, offsetZ, deviceBlockSum->ga.data, deviceBlockSum->ga.offset);
                 if (err != GA_NO_ERROR){
                     PyErr_SetString(PyExc_RuntimeError, "blockCumOp call failed");
                     return -1;
@@ -429,18 +413,8 @@ class GpuCumOp(GpuKernelBase, Op):
                 // report partial cum ops of previous blocks to subsequents ones.
                 size_t dimGrid[3] = {dimGridX, localDimGridY, localDimGridZ};
                 size_t dimBlock[3] = {dimBlockX, 1, 1};
-                void* kernel_params[] = {(void*) output->ga.data,
-                                         (void*) &(output->ga.offset),
-                                         (void*) deviceBlockSum->ga.data,
-                                         (void*) &(deviceBlockSum->ga.offset),
-                                         (void*) &nbElementsPerCumOp,
-                                         (void*) &outputStrides_x,
-                                         (void*) &outputStrides_y,
-                                         (void*) &outputStrides_z,
-                                         (void*) &offsetY,
-                                         (void*) &offsetZ};
-                int err = GpuKernel_call(&k_finalCumOp_%(nodename)s, 3, dimGrid, dimBlock, sharedBytes, kernel_params);
+                int err = k_finalCumOp_call(3, dimGrid, dimBlock, sharedBytes, output->ga.data, output->ga.offset, deviceBlockSum->ga.data, deviceBlockSum->ga.offset, nbElementsPerCumOp, outputStrides_x, outputStrides_y, outputStrides_z, offsetY, offsetZ);
                 if (err != GA_NO_ERROR){
                     PyErr_SetString(PyExc_RuntimeError, "finalCumOp call failed");
                     return -1;
@@ -450,24 +424,8 @@ class GpuCumOp(GpuKernelBase, Op):
             if (shape[axis] != nbElementsPerCumOp){
                 size_t dimGrid[3] = {1, localDimGridY, localDimGridZ};
                 size_t dimBlock[3] = {1, 1, 1};
-                size_t tmp0 = shape[axis]-2;
-                size_t tmp1 = shape[axis]-1;
-                void* kernel_params[] = {(void*) input->ga.data,
-                                         (void*) &(input->ga.offset),
-                                         (void*) output->ga.data,
-                                         (void*) &(output->ga.offset),
-                                         (void*) &inputStrides_x,
-                                         (void*) &inputStrides_y,
-                                         (void*) &inputStrides_z,
-                                         (void*) &outputStrides_x,
-                                         (void*) &outputStrides_y,
-                                         (void*) &outputStrides_z,
-                                         (void*) &offsetY,
-                                         (void*) &offsetZ,
-                                         (void*) &(tmp0),
-                                         (void*) &(tmp1)};
-                int err = GpuKernel_call(&k_cumadd_%(nodename)s, 3, dimGrid, dimBlock, sharedBytes, kernel_params);
+                int err = k_cumadd_call(3, dimGrid, dimBlock, sharedBytes, input->ga.data, input->ga.offset, output->ga.data, output->ga.offset, inputStrides_x, inputStrides_y, inputStrides_z, outputStrides_x, outputStrides_y, outputStrides_z, offsetY, offsetZ, shape[axis] - 2, shape[axis] - 1);
                 if (err != GA_NO_ERROR){
                     PyErr_SetString(PyExc_RuntimeError, "cumadd call failed");
                     return -1;
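The extra_ops hunks also change call style: instead of packing every argument into a void* array for the generic GpuKernel_call, the code now goes through the typed `*_call` wrappers generated from the kernel declarations. A small sketch of why that matters; `generic_launch` and `k_demo_call` are hypothetical stand-ins for illustration, not gpuarray API:

#include <stdio.h>
#include <stddef.h>

static int generic_launch(size_t nargs, void **args)
{
    /* Old style: every argument arrives type-erased, so a forgotten or
     * reordered entry is only detectable at runtime, if at all. */
    float *data = (float *)args[0];
    size_t off = *(size_t *)args[1];
    printf("generic: data=%p off=%zu (nargs=%zu)\n", (void *)data, off, nargs);
    return 0;
}

static int k_demo_call(float *data, size_t data_offset)
{
    /* New style: a typed prototype mirroring the kernel signature, so
     * adding the offset parameters turns stale call sites into compile
     * errors instead of silent argument mismatches. */
    printf("typed:   data=%p off=%zu\n", (void *)data, data_offset);
    return 0;
}

int main(void)
{
    float buf[4] = {0};
    size_t off = sizeof(float);
    void *args[2] = { buf, &off };
    generic_launch(2, args);
    k_demo_call(buf, off);
    return 0;
}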
theano/gpuarray/multinomial.py

@@ -71,16 +71,22 @@ class GPUAMultinomialFromUniform(GpuKernelBase, Op):
        KERNEL void k_multi_warp_multinomial(
            const ga_size nb_multi,
            const ga_size nb_outcomes,
-           GLOBAL_MEM %(in_ctype)s * global_pvals,
+           GLOBAL_MEM %(in_ctype)s *global_pvals,
+           const ga_size global_pvals_offset,
            const ga_ssize pvals_row_stride,
            const ga_ssize pvals_col_stride,
-           GLOBAL_MEM %(in_ctype)s * global_unis,
+           GLOBAL_MEM %(in_ctype)s *global_unis,
+           const ga_size global_unis_offset,
            const ga_ssize unis_stride,
-           GLOBAL_MEM %(out_ctype)s * global_outs,
+           GLOBAL_MEM %(out_ctype)s *global_outs,
+           const ga_size global_outs_offset,
            const ga_ssize outs_row_stride,
            const ga_ssize outs_col_stride
        )
        {
+           global_pvals = (GLOBAL_MEM %(in_ctype)s *)(((char *)global_pvals) + global_pvals_offset);
+           global_unis = (GLOBAL_MEM %(in_ctype)s *)(((char *)global_unis) + global_unis_offset);
+           global_outs = (GLOBAL_MEM %(out_ctype)s *)(((char *)global_outs) + global_outs_offset);
            // each thread takes care of one multinomial draw
            int n = LDIM_0*GID_0 + LID_0;
            if (n < nb_multi)
@@ -113,11 +119,14 @@ KERNEL void k_multi_warp_multinomial(
                params=[pygpu.gpuarray.SIZE,
                        pygpu.gpuarray.SIZE,
                        pygpu.gpuarray.GpuArray,
+                       pygpu.gpuarray.SIZE,
                        pygpu.gpuarray.SSIZE,
                        pygpu.gpuarray.SSIZE,
                        pygpu.gpuarray.GpuArray,
+                       pygpu.gpuarray.SIZE,
                        pygpu.gpuarray.SSIZE,
                        pygpu.gpuarray.GpuArray,
+                       pygpu.gpuarray.SIZE,
                        pygpu.gpuarray.SSIZE,
                        pygpu.gpuarray.SSIZE],
                flags=Kernel.get_flags(node.outputs[0].dtype),
@@ -193,27 +202,8 @@ KERNEL void k_multi_warp_multinomial(
        assert(nb_blocks*nb_threads >= nb_multi);
-       void *args[10];
-       ssize_t strides[5] = {
-           PyGpuArray_STRIDES(pvals)[0]/gpuarray_get_elsize(%(in_typecode)s),
-           PyGpuArray_STRIDES(pvals)[1]/gpuarray_get_elsize(%(in_typecode)s),
-           PyGpuArray_STRIDES(unis)[0]/gpuarray_get_elsize(%(in_typecode)s),
-           PyGpuArray_STRIDES(out)[0]/gpuarray_get_elsize(%(out_typecode)s),
-           PyGpuArray_STRIDES(out)[1]/gpuarray_get_elsize(%(out_typecode)s)
-       };
-       int err;
-       args[0] = (void*)&PyGpuArray_DIMS(out)[1];
-       args[1] = (void*)&PyGpuArray_DIMS(out)[0];
-       args[2] = pvals->ga.data; //PyGpuArray_DEV_DATA(pvals);
-       args[3] = (void*)&strides[0];
-       args[4] = (void*)&strides[1];
-       args[5] = unis->ga.data; //PyGpuArray_DEV_DATA(unis);
-       args[6] = (void*)&strides[2];
-       args[7] = out->ga.data; //PyGpuArray_DEV_DATA(out);
-       args[8] = (void*)&strides[3];
-       args[9] = (void*)&strides[4];
-       err = GpuKernel_call(&%(kname)s, 1, &nb_blocks, &nb_threads, 0, args);
+       int err = k_multi_warp_multinomial_call(1, &nb_blocks, &nb_threads, 0, PyGpuArray_DIMS(out)[1], PyGpuArray_DIMS(out)[0], pvals->ga.data, pvals->ga.offset, PyGpuArray_STRIDES(pvals)[0]/gpuarray_get_elsize(%(in_typecode)s), PyGpuArray_STRIDES(pvals)[1]/gpuarray_get_elsize(%(in_typecode)s), unis->ga.data, unis->ga.offset, PyGpuArray_STRIDES(unis)[0]/gpuarray_get_elsize(%(in_typecode)s), out->ga.data, out->ga.offset, PyGpuArray_STRIDES(out)[0]/gpuarray_get_elsize(%(out_typecode)s), PyGpuArray_STRIDES(out)[1]/gpuarray_get_elsize(%(out_typecode)s));
        if (err != GA_NO_ERROR) {
            PyErr_Format(
                PyExc_RuntimeError,
@@ -230,7 +220,7 @@ KERNEL void k_multi_warp_multinomial(
        return s

    def c_code_cache_version(self):
-       return (3,)
+       return (4,)


class GPUAChoiceFromUniform(GpuKernelBase, Op):
@@ -295,15 +285,21 @@ KERNEL void k_multi_warp_multinomial_wor(
            const ga_size nb_outcomes,
            const ga_size n_samples,
            GLOBAL_MEM float * global_pvals_copy,
+           const ga_size global_pvals_offset,
            const ga_ssize pvals_row_stride,
            const ga_ssize pvals_col_stride,
            GLOBAL_MEM float * global_unis,
+           const ga_size global_unis_offset,
            const ga_ssize unis_stride,
            GLOBAL_MEM ga_long * global_outs,
+           const ga_size global_outs_offset,
            const ga_ssize outs_row_stride,
            const ga_ssize outs_col_stride
        )
        {
+           global_pvals_copy = (GLOBAL_MEM float *)(((char *)global_pvals_copy) + global_pvals_offset);
+           global_unis = (GLOBAL_MEM float *)(((char *)global_unis) + global_unis_offset);
+           global_outs = (GLOBAL_MEM ga_long *)(((char *)global_outs) + global_outs_offset);
            // each thread takes care of one multinomial-wor n_samples-draw
            int n = LDIM_0*GID_0 + LID_0;
@@ -344,11 +340,14 @@ KERNEL void k_multi_warp_multinomial_wor(
                        pygpu.gpuarray.SIZE,
                        pygpu.gpuarray.SIZE,
                        pygpu.gpuarray.GpuArray,
+                       pygpu.gpuarray.SIZE,
                        pygpu.gpuarray.SSIZE,
                        pygpu.gpuarray.SSIZE,
                        pygpu.gpuarray.GpuArray,
+                       pygpu.gpuarray.SIZE,
                        pygpu.gpuarray.SSIZE,
                        pygpu.gpuarray.GpuArray,
+                       pygpu.gpuarray.SIZE,
                        pygpu.gpuarray.SSIZE,
                        pygpu.gpuarray.SSIZE],
@@ -438,28 +437,7 @@ KERNEL void k_multi_warp_multinomial_wor(
        assert(nb_blocks*nb_threads >= nb_multi);
-       void *args[11];
-       ssize_t strides[5] = {
-           PyGpuArray_STRIDES(pvals)[0]/sizeof(float),
-           PyGpuArray_STRIDES(pvals)[1]/sizeof(float),
-           PyGpuArray_STRIDES(unis)[0]/sizeof(float),
-           PyGpuArray_STRIDES(out)[0]/8,
-           PyGpuArray_STRIDES(out)[1]/8
-       };
-       int err;
-       args[0] = (void*)&PyGpuArray_DIMS(pvals)[0];
-       args[1] = (void*)&PyGpuArray_DIMS(pvals)[1];
-       args[2] = (void*)&n_samples;
-       args[3] = pvals_copy->ga.data; //PyGpuArray_DEV_DATA(pvals);
-       args[4] = (void*)&strides[0];
-       args[5] = (void*)&strides[1];
-       args[6] = unis->ga.data; //PyGpuArray_DEV_DATA(unis);
-       args[7] = (void*)&strides[2];
-       args[8] = out->ga.data; //PyGpuArray_DEV_DATA(out);
-       args[9] = (void*)&strides[3];
-       args[10] = (void*)&strides[4];
-       err = GpuKernel_call(&%(kname)s, 1, &nb_blocks, &nb_threads, 0, args);
+       int err = k_multi_warp_multinomial_wor_call(1, &nb_blocks, &nb_threads, 0, PyGpuArray_DIMS(pvals)[0], PyGpuArray_DIMS(pvals)[1], n_samples, pvals_copy->ga.data, pvals_copy->ga.offset, PyGpuArray_STRIDES(pvals)[0]/sizeof(float), PyGpuArray_STRIDES(pvals)[1]/sizeof(float), unis->ga.data, unis->ga.offset, PyGpuArray_STRIDES(unis)[0]/sizeof(float), out->ga.data, out->ga.offset, PyGpuArray_STRIDES(out)[0]/8, PyGpuArray_STRIDES(out)[1]/8);
        if (err != GA_NO_ERROR) {
            PyErr_Format(
                PyExc_RuntimeError,
@@ -477,7 +455,7 @@ KERNEL void k_multi_warp_multinomial_wor(
        return s

    def c_code_cache_version(self):
-       return (7,)
+       return (8,)


@register_opt('fast_compile')
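Note how the multinomial call sites divide PyGpuArray_STRIDES (byte strides) by the element size (gpuarray_get_elsize(...), sizeof(float), or 8 for the 64-bit outputs) before passing them: the kernels index typed pointers in elements, so they need element strides, while the new offsets stay in bytes because they are applied through a char* cast. A minimal sketch of the two conventions side by side, with plain C types as stand-ins for the gpuarray ones:

#include <stdio.h>
#include <stdlib.h>

int main(void)
{
    /* Row-major 3x4 float matrix, viewed starting at row 1:
     * byte offset = 1 * row_stride_bytes. */
    float *base = calloc(3 * 4, sizeof(float));
    long row_stride_bytes = 4 * (long)sizeof(float);
    long col_stride_bytes = (long)sizeof(float);
    long offset_bytes = 1 * row_stride_bytes;

    /* Offsets are applied in bytes, through char*... */
    float *view = (float *)(((char *)base) + offset_bytes);
    /* ...but kernels index in elements, so strides are divided by the
     * element size before being passed. */
    long row_stride = row_stride_bytes / (long)sizeof(float);
    long col_stride = col_stride_bytes / (long)sizeof(float);

    view[1 * row_stride + 2 * col_stride] = 7.0f;  /* view[1][2]      */
    printf("base[2][2] = %g\n", base[2 * 4 + 2]);  /* prints 7        */
    free(base);
    return 0;
}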
theano/gpuarray/pool.c

#section kernels

-#kernel max_pool2d_kernel : size, size, size, size, size, size, size, *, size, size, size, size, size, size, * :
+#kernel max_pool2d_kernel : size, size, size, size, size, size, size, *, size, size, size, size, size, size, size, *, size :

// (borrowed from Caffe: https://github.com/BVLC/caffe/blob/master/src/caffe/layers/pooling_layer.cu)
KERNEL void max_pool2d_kernel(const ga_size nthreads,
                              const ga_size num, const ga_size channels,
                              const ga_size pooled_height,
                              const ga_size pooled_width,
                              const ga_size height, const ga_size width,
-                             GLOBAL_MEM const DTYPE_INPUT_0 *x,
-                             const ga_size kernel_h, const ga_size kernel_w,
+                             GLOBAL_MEM const DTYPE_INPUT_0 *x, const ga_size x_off,
+                             const ga_size kernel_h, const ga_size kernel_w,
                              const ga_size stride_h, const ga_size stride_w,
                              const ga_size pad_h, const ga_size pad_w,
-                             GLOBAL_MEM DTYPE_OUTPUT_0 *z)
+                             GLOBAL_MEM DTYPE_OUTPUT_0 *z, const ga_size z_off)
{
+  x = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)x) + x_off);
+  z = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((char *)z) + z_off);
  // grid stride looping
  for (ga_size index = GID_0 * LDIM_0 + LID_0; index < nthreads;
@@ -41,18 +43,20 @@ KERNEL void max_pool2d_kernel(const ga_size nthreads,
  }
}

-#kernel max_pool3d_kernel : size, size, size, size, size, size, size, size, size, *, size, size, size, size, size, size, size, size, size, * :
+#kernel max_pool3d_kernel : size, size, size, size, size, size, size, size, size, *, size, size, size, size, size, size, size, size, size, size, *, size :

// (adopted from Caffe: https://github.com/BVLC/caffe/blob/master/src/caffe/layers/pooling_layer.cu)
KERNEL void max_pool3d_kernel(const ga_size nthreads,
                              const ga_size num, const ga_size channels,
                              const ga_size pooled_depth,
                              const ga_size pooled_height,
                              const ga_size pooled_width,
                              const ga_size depth, const ga_size height,
                              const ga_size width,
-                             GLOBAL_MEM const DTYPE_INPUT_0 *x,
-                             const ga_size kernel_d, const ga_size kernel_h,
+                             GLOBAL_MEM const DTYPE_INPUT_0 *x, const ga_size x_off,
+                             const ga_size kernel_d, const ga_size kernel_h,
                              const ga_size kernel_w, const ga_size stride_d,
                              const ga_size stride_h, const ga_size stride_w,
                              const ga_size pad_d, const ga_size pad_h,
                              const ga_size pad_w,
-                             GLOBAL_MEM DTYPE_OUTPUT_0 *z)
+                             GLOBAL_MEM DTYPE_OUTPUT_0 *z, const ga_size z_off)
{
+  x = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)x) + x_off);
+  z = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((char *)z) + z_off);
  // grid stride looping
  for (ga_size index = GID_0 * LDIM_0 + LID_0; index < nthreads;
@@ -90,17 +94,19 @@ KERNEL void max_pool3d_kernel(const ga_size nthreads,
  }
}

-#kernel ave_pool2d_kernel : size, size, size, size, size, size, size, *, size, size, size, size, size, size, size, size, * :
+#kernel ave_pool2d_kernel : size, size, size, size, size, size, size, *, size, size, size, size, size, size, size, size, size, *, size :

// (adopted from Caffe: https://github.com/BVLC/caffe/blob/master/src/caffe/layers/pooling_layer.cu)
KERNEL void ave_pool2d_kernel(const ga_size nthreads,
                              const ga_size num, const ga_size channels,
                              const ga_size pooled_height,
                              const ga_size pooled_width,
                              const ga_size height, const ga_size width,
-                             GLOBAL_MEM const DTYPE_INPUT_0 *x,
-                             const ga_size kernel_h, const ga_size kernel_w,
+                             GLOBAL_MEM const DTYPE_INPUT_0 *x, const ga_size x_off,
+                             const ga_size kernel_h, const ga_size kernel_w,
                              const ga_size stride_h, const ga_size stride_w,
                              const ga_size pad_h, const ga_size pad_w,
                              const ga_bool inc_pad, const ga_bool sum_mode,
-                             GLOBAL_MEM DTYPE_OUTPUT_0 *z)
+                             GLOBAL_MEM DTYPE_OUTPUT_0 *z, const ga_size z_off)
{
+  x = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)x) + x_off);
+  z = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((char *)z) + z_off);
  // grid stride looping
  for (ga_size index = GID_0 * LDIM_0 + LID_0; index < nthreads;
@@ -143,20 +149,22 @@ KERNEL void ave_pool2d_kernel(const ga_size nthreads,
  }
}

-#kernel ave_pool3d_kernel : size, size, size, size, size, size, size, size, size, *, size, size, size, size, size, size, size, size, size, size, size, * :
+#kernel ave_pool3d_kernel : size, size, size, size, size, size, size, size, size, *, size, size, size, size, size, size, size, size, size, size, size, size, *, size :

// (adopted from Caffe: https://github.com/BVLC/caffe/blob/master/src/caffe/layers/pooling_layer.cu)
KERNEL void ave_pool3d_kernel(const ga_size nthreads,
                              const ga_size num, const ga_size channels,
                              const ga_size pooled_depth,
                              const ga_size pooled_height,
                              const ga_size pooled_width,
                              const ga_size depth, const ga_size height,
                              const ga_size width,
-                             GLOBAL_MEM const DTYPE_INPUT_0 *x,
-                             const ga_size kernel_d, const ga_size kernel_h,
+                             GLOBAL_MEM const DTYPE_INPUT_0 *x, const ga_size x_off,
+                             const ga_size kernel_d, const ga_size kernel_h,
                              const ga_size kernel_w, const ga_size stride_d,
                              const ga_size stride_h, const ga_size stride_w,
                              const ga_size pad_d, const ga_size pad_h,
                              const ga_size pad_w,
                              const ga_bool inc_pad, const ga_bool sum_mode,
-                             GLOBAL_MEM DTYPE_OUTPUT_0 *z)
+                             GLOBAL_MEM DTYPE_OUTPUT_0 *z, const ga_size z_off)
{
  // grid stride looping
+  x = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)x) + x_off);
+  z = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((char *)z) + z_off);
  for (ga_size index = GID_0 * LDIM_0 + LID_0; index < nthreads;
       index += LDIM_0 * GDIM_0) {
@@ -273,8 +281,8 @@ int APPLY_SPECIFIC(pool)(PyGpuArrayObject *x,
      err = max_pool2d_kernel_scall(1, &num_kernels, 0, num_kernels,
                                    z_dims[0], z_dims[1], z_dims[2], z_dims[3],
                                    x_dims[2], x_dims[3],
-                                   x->ga.data, w[0], w[1], s[0], s[1], p[0], p[1],
-                                   (*z)->ga.data);
+                                   x->ga.data, x->ga.offset, w[0], w[1],
+                                   s[0], s[1], p[0], p[1], (*z)->ga.data, (*z)->ga.offset);
      if (err != GA_NO_ERROR) {
        PyErr_Format(PyExc_RuntimeError,
                     "GpuPool: max_pool2d_kernel %s.",
@@ -285,8 +293,10 @@ int APPLY_SPECIFIC(pool)(PyGpuArrayObject *x,
      err = ave_pool2d_kernel_scall(1, &num_kernels, 0, num_kernels,
                                    z_dims[0], z_dims[1], z_dims[2], z_dims[3],
                                    x_dims[2], x_dims[3],
-                                   x->ga.data, w[0], w[1], s[0], s[1], p[0], p[1],
-                                   INC_PAD, SUM_MODE, (*z)->ga.data);
+                                   x->ga.data, x->ga.offset,
+                                   w[0], w[1], s[0], s[1], p[0], p[1],
+                                   INC_PAD, SUM_MODE,
+                                   (*z)->ga.data, (*z)->ga.offset);
      if (err != GA_NO_ERROR) {
        PyErr_Format(PyExc_RuntimeError,
                     "GpuPool: ave_pool2d_kernel %s.",
@@ -301,8 +311,8 @@ int APPLY_SPECIFIC(pool)(PyGpuArrayObject *x,
      err = max_pool3d_kernel_scall(1, &num_kernels, 0, num_kernels,
                                    z_dims[0], z_dims[1], z_dims[2],
                                    z_dims[3], z_dims[4],
                                    x_dims[2], x_dims[3], x_dims[4],
-                                   x->ga.data, w[0], w[1], w[2],
-                                   s[0], s[1], s[2], p[0], p[1], p[2],
-                                   (*z)->ga.data);
+                                   x->ga.data, x->ga.offset, w[0], w[1], w[2],
+                                   s[0], s[1], s[2], p[0], p[1], p[2],
+                                   (*z)->ga.data, (*z)->ga.offset);
      if (err != GA_NO_ERROR) {
        PyErr_Format(PyExc_RuntimeError,
                     "GpuPool: max_pool3d_kernel %s.",
@@ -313,9 +323,11 @@ int APPLY_SPECIFIC(pool)(PyGpuArrayObject *x,
      err = ave_pool3d_kernel_scall(1, &num_kernels, 0, num_kernels,
                                    z_dims[0], z_dims[1], z_dims[2],
                                    z_dims[3], z_dims[4],
                                    x_dims[2], x_dims[3], x_dims[4],
-                                   x->ga.data, w[0], w[1], w[2],
-                                   s[0], s[1], s[2],
+                                   x->ga.data, x->ga.offset,
+                                   w[0], w[1], w[2], s[0], s[1], s[2],
                                    p[0], p[1], p[2],
-                                   INC_PAD, SUM_MODE, (*z)->ga.data);
+                                   INC_PAD, SUM_MODE,
+                                   (*z)->ga.data, (*z)->ga.offset);
      if (err != GA_NO_ERROR) {
        PyErr_Format(PyExc_RuntimeError,
                     "GpuPool: ave_pool3d_kernel %s.",
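All the pool kernels use the "grid stride loop" idiom visible above: each thread starts at its global id and strides by the total thread count, so any launch size covers any nthreads. A CPU analogue with the device id macros replaced by loop variables (GID_0/LID_0/LDIM_0/GDIM_0 become gid/lid/LDIM/GDIM here, purely for illustration):

#include <stdio.h>

#define GDIM 4   /* number of blocks  */
#define LDIM 8   /* threads per block */

int main(void)
{
    const size_t nthreads = 100;  /* total work items */
    int hits[100] = {0};
    for (size_t gid = 0; gid < GDIM; gid++)          /* blocks  */
        for (size_t lid = 0; lid < LDIM; lid++)      /* threads */
            /* grid stride looping, as in the kernels above */
            for (size_t index = gid * LDIM + lid; index < nthreads;
                 index += LDIM * GDIM)
                hits[index]++;                       /* kernel body */
    for (size_t i = 0; i < nthreads; i++)
        if (hits[i] != 1) { printf("missed %zu\n", i); return 1; }
    printf("every index covered exactly once\n");
    return 0;
}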
theano/gpuarray/pool_ave_grad.c

#section kernels

-#kernel ave_pool2d_grad_kernel : size, size, size, size, size, size, size, *, *, size, size, size, size, size, size, size, size, * :
+#kernel ave_pool2d_grad_kernel : size, size, size, size, size, size, size, *, size, *, size, size, size, size, size, size, size, size, size, *, size :

// (adopted from Caffe: https://github.com/BVLC/caffe/blob/master/src/caffe/layers/pooling_layer.cu)
KERNEL void ave_pool2d_grad_kernel(const ga_size nthreads,
                                   const ga_size num, const ga_size channels,
                                   const ga_size height, const ga_size width,
                                   const ga_size pooled_height,
                                   const ga_size pooled_width,
-                                  GLOBAL_MEM const DTYPE_INPUT_0 *x,
-                                  GLOBAL_MEM const DTYPE_INPUT_1 *gz,
+                                  GLOBAL_MEM const DTYPE_INPUT_0 *x, const ga_size x_off,
+                                  GLOBAL_MEM const DTYPE_INPUT_1 *gz, const ga_size gz_off,
                                   const ga_size kernel_h, const ga_size kernel_w,
                                   const ga_size stride_h, const ga_size stride_w,
                                   const ga_size pad_h, const ga_size pad_w,
                                   const ga_bool inc_pad, const ga_bool sum_mode,
-                                  GLOBAL_MEM DTYPE_OUTPUT_0 *gx)
+                                  GLOBAL_MEM DTYPE_OUTPUT_0 *gx, const ga_size gx_off)
{
+  x = (GLOBAL_MEM const DTYPE_INPUT_0 *)(((char *)x) + x_off);
+  gz = (GLOBAL_MEM const DTYPE_INPUT_1 *)(((char *)gz) + gz_off);
+  gx = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((char *)gx) + gx_off);
  // grid stride looping
  for (ga_size index = GID_0 * LDIM_0 + LID_0; index < nthreads;
       index += LDIM_0 * GDIM_0) {
@@ -46,19 +49,22 @@ KERNEL void ave_pool2d_grad_kernel(const ga_size nthreads,
  }
}

-#kernel ave_pool3d_grad_kernel : size, size, size, size, size, size, size, size, size, *, *, size, size, size, size, size, size, size, size, size, size, size, * :
+#kernel ave_pool3d_grad_kernel : size, size, size, size, size, size, size, size, size, *, size, *, size, size, size, size, size, size, size, size, size, size, size, size, *, size :

// (adopted from Caffe: https://github.com/BVLC/caffe/blob/master/src/caffe/layers/pooling_layer.cu)
KERNEL void ave_pool3d_grad_kernel(const ga_size nthreads,
                                   const ga_size num, const ga_size channels,
                                   const ga_size depth, const ga_size height,
                                   const ga_size width,
                                   const ga_size pooled_depth,
                                   const ga_size pooled_height,
                                   const ga_size pooled_width,
-                                  GLOBAL_MEM const DTYPE_INPUT_0 *x,
-                                  GLOBAL_MEM const DTYPE_INPUT_1 *gz,
+                                  GLOBAL_MEM const DTYPE_INPUT_0 *x, const ga_size x_off,
+                                  GLOBAL_MEM const DTYPE_INPUT_1 *gz, const ga_size gz_off,
                                   const ga_size kernel_d, const ga_size kernel_h,
                                   const ga_size kernel_w, const ga_size stride_d,
                                   const ga_size stride_h, const ga_size stride_w,
                                   const ga_size pad_d, const ga_size pad_h,
                                   const ga_size pad_w,
-                                  const ga_bool inc_pad, const ga_bool sum_mode,
-                                  GLOBAL_MEM DTYPE_OUTPUT_0 *gx)
+                                  const ga_bool inc_pad, const ga_bool sum_mode,
+                                  GLOBAL_MEM DTYPE_OUTPUT_0 *gx, const ga_size gx_off)
{
+  x = (GLOBAL_MEM const DTYPE_INPUT_0 *)(((char *)x) + x_off);
+  gz = (GLOBAL_MEM const DTYPE_INPUT_1 *)(((char *)gz) + gz_off);
+  gx = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((char *)gx) + gx_off);
  // grid stride looping
  for (ga_size index = GID_0 * LDIM_0 + LID_0; index < nthreads;
       index += LDIM_0 * GDIM_0) {
@@ -152,9 +158,11 @@ int APPLY_SPECIFIC(ave_pool_grad)(PyGpuArrayObject *x,
      err = ave_pool2d_grad_kernel_scall(1, &num_kernels, 0, num_kernels,
                                         x_dims[0], x_dims[1], x_dims[2], x_dims[3],
                                         z_dims[2], z_dims[3],
-                                        x->ga.data, gz->ga.data,
+                                        x->ga.data, x->ga.offset,
+                                        gz->ga.data, gz->ga.offset,
                                         w[0], w[1], s[0], s[1], p[0], p[1],
-                                        INC_PAD, SUM_MODE, (*gx)->ga.data);
+                                        INC_PAD, SUM_MODE,
+                                        (*gx)->ga.data, (*gx)->ga.offset);
      if (err != GA_NO_ERROR) {
        PyErr_Format(PyExc_RuntimeError,
                     "GpuAveragePoolGrad: ave_pool2d_grad_kernel %s.",
@@ -166,10 +174,11 @@ int APPLY_SPECIFIC(ave_pool_grad)(PyGpuArrayObject *x,
      err = ave_pool3d_grad_kernel_scall(1, &num_kernels, 0, num_kernels,
                                         x_dims[0], x_dims[1], x_dims[2],
                                         x_dims[3], x_dims[4],
                                         z_dims[2], z_dims[3], z_dims[4],
-                                        x->ga.data, gz->ga.data,
+                                        x->ga.data, x->ga.offset,
+                                        gz->ga.data, gz->ga.offset,
                                         w[0], w[1], w[2], s[0], s[1], s[2],
                                         p[0], p[1], p[2], INC_PAD, SUM_MODE,
-                                        (*gx)->ga.data);
+                                        (*gx)->ga.data, (*gx)->ga.offset);
      if (err != GA_NO_ERROR) {
        PyErr_Format(PyExc_RuntimeError,
                     "GpuAveragePoolGrad: ave_pool3d_grad_kernel %s.",
theano/gpuarray/pool_grad_grad.c

#section kernels

-#kernel max_pool2d_grad_grad_kernel : size, size, size, size, size, size, size, *, *, *, size, size, size, size, size, size, * :
+#kernel max_pool2d_grad_grad_kernel : size, size, size, size, size, size, size, *, size, *, size, *, size, size, size, size, size, size, size, *, size :

KERNEL void max_pool2d_grad_grad_kernel(const ga_size nthreads,
                                        const ga_size num, const ga_size channels,
                                        const ga_size pooled_height,
                                        const ga_size pooled_width,
                                        const ga_size height, const ga_size width,
-                                       GLOBAL_MEM const DTYPE_INPUT_0 *x,
-                                       GLOBAL_MEM const DTYPE_INPUT_1 *z,
-                                       GLOBAL_MEM const DTYPE_INPUT_2 *gx,
+                                       GLOBAL_MEM const DTYPE_INPUT_0 *x, const ga_size x_off,
+                                       GLOBAL_MEM const DTYPE_INPUT_1 *z, const ga_size z_off,
+                                       GLOBAL_MEM const DTYPE_INPUT_2 *gx, const ga_size gx_off,
                                        const ga_size kernel_h, const ga_size kernel_w,
                                        const ga_size stride_h, const ga_size stride_w,
                                        const ga_size pad_h, const ga_size pad_w,
-                                       GLOBAL_MEM DTYPE_OUTPUT_0 *gz)
+                                       GLOBAL_MEM DTYPE_OUTPUT_0 *gz, const ga_size gz_off)
{
+  x = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)x) + x_off);
+  z = (GLOBAL_MEM DTYPE_INPUT_1 *)(((char *)z) + z_off);
+  gx = (GLOBAL_MEM DTYPE_INPUT_2 *)(((char *)gx) + gx_off);
+  gz = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((char *)gz) + gz_off);
  // grid stride looping
  for (ga_size index = GID_0 * LDIM_0 + LID_0; index < nthreads;
       index += LDIM_0 * GDIM_0) {
@@ -42,18 +46,22 @@ KERNEL void max_pool2d_grad_grad_kernel(const ga_size nthreads,
  }
}

-#kernel max_pool3d_grad_grad_kernel : size, size, size, size, size, size, size, size, size, *, *, *, size, size, size, size, size, size, size, size, size, * :
+#kernel max_pool3d_grad_grad_kernel : size, size, size, size, size, size, size, size, size, *, size, *, size, *, size, size, size, size, size, size, size, size, size, size, *, size :

KERNEL void max_pool3d_grad_grad_kernel(const ga_size nthreads,
                                        const ga_size num, const ga_size channels,
                                        const ga_size pooled_depth,
                                        const ga_size pooled_height,
                                        const ga_size pooled_width,
                                        const ga_size depth, const ga_size height,
                                        const ga_size width,
-                                       GLOBAL_MEM const DTYPE_INPUT_0 *x,
-                                       GLOBAL_MEM const DTYPE_INPUT_1 *z,
-                                       GLOBAL_MEM const DTYPE_INPUT_2 *gx,
+                                       GLOBAL_MEM const DTYPE_INPUT_0 *x, const ga_size x_off,
+                                       GLOBAL_MEM const DTYPE_INPUT_1 *z, const ga_size z_off,
+                                       GLOBAL_MEM const DTYPE_INPUT_2 *gx, const ga_size gx_off,
                                        const ga_size kernel_d, const ga_size kernel_h,
                                        const ga_size kernel_w, const ga_size stride_d,
                                        const ga_size stride_h, const ga_size stride_w,
                                        const ga_size pad_d, const ga_size pad_h,
                                        const ga_size pad_w,
-                                       GLOBAL_MEM DTYPE_OUTPUT_0 *gz)
+                                       GLOBAL_MEM DTYPE_OUTPUT_0 *gz, const ga_size gz_off)
{
+  x = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)x) + x_off);
+  z = (GLOBAL_MEM DTYPE_INPUT_1 *)(((char *)z) + z_off);
+  gx = (GLOBAL_MEM DTYPE_INPUT_2 *)(((char *)gx) + gx_off);
+  gz = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((char *)gz) + gz_off);
  // grid stride looping
  for (ga_size index = GID_0 * LDIM_0 + LID_0; index < nthreads;
       index += LDIM_0 * GDIM_0) {
@@ -146,9 +154,11 @@ int APPLY_SPECIFIC(pool_grad_grad)(PyGpuArrayObject *x,
      err = max_pool2d_grad_grad_kernel_scall(1, &num_kernels, 0, num_kernels,
                                              z_dims[0], z_dims[1], z_dims[2], z_dims[3],
                                              x_dims[2], x_dims[3],
-                                             x->ga.data, z->ga.data, gx->ga.data,
+                                             x->ga.data, x->ga.offset,
+                                             z->ga.data, z->ga.offset,
+                                             gx->ga.data, gx->ga.offset,
                                              w[0], w[1], s[0], s[1], p[0], p[1],
-                                             (*gz)->ga.data);
+                                             (*gz)->ga.data, (*gz)->ga.offset);
      if (err != GA_NO_ERROR) {
        PyErr_Format(PyExc_RuntimeError,
                     "GpuPoolingGradGrad: max_pool2d_grad_grad_kernel %s.",
@@ -161,9 +171,11 @@ int APPLY_SPECIFIC(pool_grad_grad)(PyGpuArrayObject *x,
      err = max_pool3d_grad_grad_kernel_scall(1, &num_kernels, 0, num_kernels,
                                              z_dims[0], z_dims[1], z_dims[2],
                                              z_dims[3], z_dims[4],
                                              x_dims[2], x_dims[3], x_dims[4],
-                                             x->ga.data, z->ga.data, gx->ga.data,
+                                             x->ga.data, x->ga.offset,
+                                             z->ga.data, z->ga.offset,
+                                             gx->ga.data, gx->ga.offset,
                                              w[0], w[1], w[2], s[0], s[1], s[2],
                                              p[0], p[1], p[2],
-                                             (*gz)->ga.data);
+                                             (*gz)->ga.data, (*gz)->ga.offset);
      if (err != GA_NO_ERROR) {
        PyErr_Format(PyExc_RuntimeError,
                     "GpuPoolingGradGrad: max_pool3d_grad_grad_kernel %s.",
theano/gpuarray/pool_max_grad.c

#section kernels

-#kernel max_pool2d_grad_kernel : size, size, size, size, size, size, size, *, *, *, size, size, size, size, size, size, * :
+#kernel max_pool2d_grad_kernel : size, size, size, size, size, size, size, *, size, *, size, *, size, size, size, size, size, size, size, *, size :

// (borrowed from Caffe: https://github.com/BVLC/caffe/blob/master/src/caffe/layers/pooling_layer.cu)
KERNEL void max_pool2d_grad_kernel(const ga_size nthreads,
                                   const ga_size num, const ga_size channels,
                                   const ga_size height, const ga_size width,
                                   const ga_size pooled_height,
                                   const ga_size pooled_width,
-                                  GLOBAL_MEM const DTYPE_INPUT_0 *x,
-                                  GLOBAL_MEM const DTYPE_INPUT_1 *z,
-                                  GLOBAL_MEM const DTYPE_INPUT_2 *gz,
+                                  GLOBAL_MEM const DTYPE_INPUT_0 *x, const ga_size x_off,
+                                  GLOBAL_MEM const DTYPE_INPUT_1 *z, const ga_size z_off,
+                                  GLOBAL_MEM const DTYPE_INPUT_2 *gz, const ga_size gz_off,
                                   const ga_size kernel_h, const ga_size kernel_w,
                                   const ga_size stride_h, const ga_size stride_w,
-                                  const ga_size pad_h, const ga_size pad_w,
-                                  GLOBAL_MEM DTYPE_OUTPUT_0 *gx)
+                                  const ga_size pad_h, const ga_size pad_w,
+                                  GLOBAL_MEM DTYPE_OUTPUT_0 *gx, const ga_size gx_off)
{
+  x = (GLOBAL_MEM const DTYPE_INPUT_0 *)(((char *)x) + x_off);
+  z = (GLOBAL_MEM const DTYPE_INPUT_1 *)(((char *)z) + z_off);
+  gz = (GLOBAL_MEM const DTYPE_INPUT_2 *)(((char *)gz) + gz_off);
+  gx = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((char *)gx) + gx_off);
  // grid stride looping
  for (ga_size index = GID_0 * LDIM_0 + LID_0; index < nthreads;
       index += LDIM_0 * GDIM_0) {
@@ -38,19 +42,23 @@ KERNEL void max_pool2d_grad_kernel(const ga_size nthreads,
  }
}

-#kernel max_pool3d_grad_kernel : size, size, size, size, size, size, size, size, size, *, *, *, size, size, size, size, size, size, size, size, size, * :
+#kernel max_pool3d_grad_kernel : size, size, size, size, size, size, size, size, size, *, size, *, size, *, size, size, size, size, size, size, size, size, size, size, *, size :

// (adopted from Caffe: https://github.com/BVLC/caffe/blob/master/src/caffe/layers/pooling_layer.cu)
KERNEL void max_pool3d_grad_kernel(const ga_size nthreads,
                                   const ga_size num, const ga_size channels,
                                   const ga_size depth, const ga_size height,
                                   const ga_size width,
                                   const ga_size pooled_depth,
                                   const ga_size pooled_height,
                                   const ga_size pooled_width,
-                                  GLOBAL_MEM const DTYPE_INPUT_0 *x,
-                                  GLOBAL_MEM const DTYPE_INPUT_1 *z,
-                                  GLOBAL_MEM const DTYPE_INPUT_2 *gz,
+                                  GLOBAL_MEM const DTYPE_INPUT_0 *x, const ga_size x_off,
+                                  GLOBAL_MEM const DTYPE_INPUT_1 *z, const ga_size z_off,
+                                  GLOBAL_MEM const DTYPE_INPUT_2 *gz, const ga_size gz_off,
                                   const ga_size kernel_d, const ga_size kernel_h,
                                   const ga_size kernel_w, const ga_size stride_d,
                                   const ga_size stride_h, const ga_size stride_w,
                                   const ga_size pad_d, const ga_size pad_h,
                                   const ga_size pad_w,
-                                  GLOBAL_MEM DTYPE_OUTPUT_0 *gx)
+                                  GLOBAL_MEM DTYPE_OUTPUT_0 *gx, const ga_size gx_off)
{
+  x = (GLOBAL_MEM const DTYPE_INPUT_0 *)(((char *)x) + x_off);
+  z = (GLOBAL_MEM const DTYPE_INPUT_1 *)(((char *)z) + z_off);
+  gz = (GLOBAL_MEM const DTYPE_INPUT_2 *)(((char *)gz) + gz_off);
+  gx = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((char *)gx) + gx_off);
  // grid stride looping
  for (ga_size index = GID_0 * LDIM_0 + LID_0; index < nthreads;
       index += LDIM_0 * GDIM_0) {
@@ -138,9 +146,11 @@ int APPLY_SPECIFIC(max_pool_grad)(PyGpuArrayObject *x,
      err = max_pool2d_grad_kernel_scall(1, &num_kernels, 0, num_kernels,
                                         x_dims[0], x_dims[1], x_dims[2], x_dims[3],
                                         z_dims[2], z_dims[3],
-                                        x->ga.data, z->ga.data, gz->ga.data,
+                                        x->ga.data, x->ga.offset,
+                                        z->ga.data, z->ga.offset,
+                                        gz->ga.data, gz->ga.offset,
                                         w[0], w[1], s[0], s[1], p[0], p[1],
-                                        (*gx)->ga.data);
+                                        (*gx)->ga.data, (*gx)->ga.offset);
      if (err != GA_NO_ERROR) {
        PyErr_Format(PyExc_RuntimeError,
                     "GpuMaxPoolGrad: max_pool2d_grad_kernel %s.",
@@ -152,9 +162,11 @@ int APPLY_SPECIFIC(max_pool_grad)(PyGpuArrayObject *x,
      err = max_pool3d_grad_kernel_scall(1, &num_kernels, 0, num_kernels,
                                         x_dims[0], x_dims[1], x_dims[2],
                                         x_dims[3], x_dims[4],
                                         z_dims[2], z_dims[3], z_dims[4],
-                                        x->ga.data, z->ga.data, gz->ga.data,
+                                        x->ga.data, x->ga.offset,
+                                        z->ga.data, z->ga.offset,
+                                        gz->ga.data, gz->ga.offset,
                                         w[0], w[1], w[2], s[0], s[1], s[2],
-                                        p[0], p[1], p[2],
-                                        (*gx)->ga.data);
+                                        p[0], p[1], p[2],
+                                        (*gx)->ga.data, (*gx)->ga.offset);
      if (err != GA_NO_ERROR) {
        PyErr_Format(PyExc_RuntimeError,
                     "GpuMaxPoolGrad: max_pool3d_grad_kernel %s.",
theano/gpuarray/pool_max_rop.c

#section kernels

-#kernel max_pool2d_rop_kernel : size, size, size, size, size, size, size, *, *, size, size, size, size, size, size, * :
+#kernel max_pool2d_rop_kernel : size, size, size, size, size, size, size, *, size, *, size, size, size, size, size, size, size, *, size :

// (borrowed from Caffe: https://github.com/BVLC/caffe/blob/master/src/caffe/layers/pooling_layer.cu)
KERNEL void max_pool2d_rop_kernel(const ga_size nthreads,
                                  const ga_size num, const ga_size channels,
                                  const ga_size pooled_height,
                                  const ga_size pooled_width,
                                  const ga_size height, const ga_size width,
-                                 GLOBAL_MEM const DTYPE_INPUT_0 *x,
-                                 GLOBAL_MEM const DTYPE_INPUT_1 *ex,
+                                 GLOBAL_MEM const DTYPE_INPUT_0 *x, const ga_size x_off,
+                                 GLOBAL_MEM const DTYPE_INPUT_1 *ex, const ga_size ex_off,
                                  const ga_size kernel_h, const ga_size kernel_w,
                                  const ga_size stride_h, const ga_size stride_w,
                                  const ga_size pad_h, const ga_size pad_w,
-                                 GLOBAL_MEM DTYPE_OUTPUT_0 *z)
+                                 GLOBAL_MEM DTYPE_OUTPUT_0 *z, const ga_size z_off)
{
+  x = (GLOBAL_MEM DTYPE_INPUT_0 *x)(((char *)x) + x_off);
+  ex = (GLOBAL_MEM DTYPE_INPUT_1 *x)(((char *)ex) + ex_off);
+  z = (GLOBAL_MEM DTYPE_OUTPUT_0 *x)(((char *)z) + z_off);
  // grid stride looping
  for (ga_size index = GID_0 * LDIM_0 + LID_0; index < nthreads;
@@ -46,19 +49,22 @@ KERNEL void max_pool2d_rop_kernel(const ga_size nthreads,
  }
}

-#kernel max_pool3d_rop_kernel : size, size, size, size, size, size, size, size, size, *, *, size, size, size, size, size, size, size, size, size, * :
+#kernel max_pool3d_rop_kernel : size, size, size, size, size, size, size, size, size, *, size, *, size, size, size, size, size, size, size, size, size, size, *, size :

// (adopted from Caffe: https://github.com/BVLC/caffe/blob/master/src/caffe/layers/pooling_layer.cu)
KERNEL void max_pool3d_rop_kernel(const ga_size nthreads,
                                  const ga_size num, const ga_size channels,
                                  const ga_size pooled_depth,
                                  const ga_size pooled_height,
                                  const ga_size pooled_width,
                                  const ga_size depth, const ga_size height,
                                  const ga_size width,
-                                 GLOBAL_MEM const DTYPE_INPUT_0 *x,
-                                 GLOBAL_MEM const DTYPE_INPUT_1 *ex,
+                                 GLOBAL_MEM const DTYPE_INPUT_0 *x, const ga_size x_off,
+                                 GLOBAL_MEM const DTYPE_INPUT_1 *ex, const ga_size ex_off,
                                  const ga_size kernel_d, const ga_size kernel_h,
                                  const ga_size kernel_w, const ga_size stride_d,
                                  const ga_size stride_h, const ga_size stride_w,
                                  const ga_size pad_d, const ga_size pad_h,
                                  const ga_size pad_w,
-                                 GLOBAL_MEM DTYPE_OUTPUT_0 *z)
+                                 GLOBAL_MEM DTYPE_OUTPUT_0 *z, const ga_size x_off)
{
+  x = (GLOBAL_MEM DTYPE_INPUT_0 *x)(((char *)x) + x_off);
+  ex = (GLOBAL_MEM DTYPE_INPUT_1 *x)(((char *)ex) + ex_off);
+  z = (GLOBAL_MEM DTYPE_OUTPUT_0 *x)(((char *)z) + z_off);
  // grid stride looping
  for (ga_size index = GID_0 * LDIM_0 + LID_0; index < nthreads;
@@ -167,9 +173,10 @@ int APPLY_SPECIFIC(max_pool_rop)(PyGpuArrayObject *x,
      err = max_pool2d_rop_kernel_scall(1, &num_kernels, 0, num_kernels,
                                        z_dims[0], z_dims[1], z_dims[2], z_dims[3],
                                        x_dims[2], x_dims[3],
-                                       x->ga.data, ex->ga.data,
+                                       x->ga.data, x->ga.offset,
+                                       ex->ga.data, ex->ga.offset,
                                        w[0], w[1], s[0], s[1], p[0], p[1],
-                                       (*z)->ga.data);
+                                       (*z)->ga.data, (*z)->ga.offset);
      if (err != GA_NO_ERROR) {
        PyErr_Format(PyExc_RuntimeError,
                     "GpuMaxPoolRop: max_pool2d_rop_kernel %s.",
@@ -182,9 +189,11 @@ int APPLY_SPECIFIC(max_pool_rop)(PyGpuArrayObject *x,
      err = max_pool3d_rop_kernel_scall(1, &num_kernels, 0, num_kernels,
                                        z_dims[0], z_dims[1], z_dims[2],
                                        z_dims[3], z_dims[4],
                                        x_dims[2], x_dims[3], x_dims[4],
-                                       x->ga.data, ex->ga.data,
+                                       x->ga.data, x->ga.offset,
+                                       ex->ga.data, ex->ga.offset,
                                        w[0], w[1], w[2], s[0], s[1], s[2],
-                                       p[0], p[1], p[2],
-                                       (*z)->ga.data);
+                                       p[0], p[1], p[2],
+                                       (*z)->ga.data, (*z)->ga.offset);
      if (err != GA_NO_ERROR) {
        PyErr_Format(PyExc_RuntimeError,
                     "GpuMaxPoolRop: max_pool3d_rop_kernel %s.",
theano/gpuarray/rng_mrg.py

@@ -75,10 +75,14 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base):
        code = """
        KERNEL void mrg_uniform(
                GLOBAL_MEM %(otype)s *sample_data,
+               ga_size sample_offset,
                GLOBAL_MEM ga_int *state_data,
+               ga_size state_offset,
                const ga_uint Nsamples,
                const ga_uint Nstreams_used)
        {
+           sample_data = (GLOBAL_MEM %(otype)s *)(((char *)sample_data) + sample_offset);
+           state_data = (GLOBAL_MEM ga_int *)(((char *)state_data) + state_offset);
            /*
             * The cluda backend makes sure that ga_int corresponds to
             * a 32 bit signed type on the target device. It is not a
@@ -157,7 +161,8 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base):
        from pygpu import gpuarray
        return [Kernel(code=code, name="mrg_uniform",
-                      params=[gpuarray.GpuArray, gpuarray.GpuArray,
+                      params=[gpuarray.GpuArray, gpuarray.SIZE,
+                              gpuarray.GpuArray, gpuarray.SIZE,
                               'uint32', 'uint32'],
                       flags=Kernel.get_flags(self.output_type.dtype, 'int32'))]
@@ -273,7 +278,7 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base):
        }
        // Make sure we run as many blocks as we need to cover the whole n_streams
        gs = (n_streams + ls - 1)/ls;
-       err = mrg_uniform_call(1, &ls, &gs, 0, %(o_sample)s->ga.data, %(o_rstate)s->ga.data, n_elements, n_streams);
+       err = mrg_uniform_call(1, &ls, &gs, 0, %(o_sample)s->ga.data, %(o_sample)s->ga.offset, %(o_rstate)s->ga.data, %(o_rstate)s->ga.offset, n_elements, n_streams);
        if (err != GA_NO_ERROR) {
            PyErr_Format(PyExc_RuntimeError, "mrg_uniform_call: %%s\\n",
                         GpuKernel_error(&%(kname)s, err));
@@ -283,7 +288,7 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base):
        """ % locals()

    def c_code_cache_version(self):
-       return (12,)
+       return (13,)


@register_opt2([mrg_uniform], 'fast_compile')
theano/gpuarray/tests/tstgpueye.c

#section kernels

-#kernel eye : *, size, size :
+#kernel eye : *, size, size, size :

/* The eye name will be used to generate supporting objects. The only
   thing you probably need to care about is the kernel object which will be
   named 'k_' + <the name above> (k_eye in this case). This name also
   has to match the kernel function name below.
*/

-KERNEL void eye(GLOBAL_MEM DTYPE_OUTPUT_0 *a, ga_size n, ga_size m) {
+KERNEL void eye(GLOBAL_MEM DTYPE_OUTPUT_0 *a, ga_size a_off, ga_size n, ga_size m) {
+  a = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((char *)a) + a_off);
  ga_size nb = n < m ? n : m;
  for (ga_size i = LID_0; i < nb; i += LDIM_0) {
    a[i*m + i] = 1;
@@ -37,7 +38,7 @@ int APPLY_SPECIFIC(tstgpueye)(PyArrayObject *n, PyArrayObject *m,
  ls = 1;
  gs = 256;
  /* The eye_call name comes from the kernel declaration above. */
-  err = eye_call(1, &gs, &ls, 0, (*z)->ga.data, dims[0], dims[1]);
+  err = eye_call(1, &gs, &ls, 0, (*z)->ga.data, (*z)->ga.offset, dims[0], dims[1]);
  if (err != GA_NO_ERROR) {
    PyErr_Format(PyExc_RuntimeError,
                 "gpuarray error: kEye: %s. n%lu, m=%lu.",