Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
21d76325
提交
21d76325
authored
6月 14, 2010
作者:
Simon Lemieux
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
adding gpu code for neighbours.py
上级
946a8fa5
显示空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
208 行增加
和
3 行删除
+208
-3
neighbours.py
theano/sandbox/neighbours.py
+205
-3
test_neighbours.py
theano/sandbox/test_neighbours.py
+3
-0
没有找到文件。
theano/sandbox/neighbours.py
浏览文件 @
21d76325
...
@@ -110,18 +110,25 @@ class Images2Neibs(Op):
...
@@ -110,18 +110,25 @@ class Images2Neibs(Op):
for (int s = 0; s < nb_stack; s++) // loop over stacks
for (int s = 0; s < nb_stack; s++) // loop over stacks
for (int a = 0; a < grid_c; a++) // loop over height/c
for (int a = 0; a < grid_c; a++) // loop over height/c
for (int b = 0; b < grid_d; b++) // loop over width/d
for (int b = 0; b < grid_d; b++) // loop over width/d
{
int z_row = b + grid_d*(a + grid_c*(s + nb_stack*n));
for (int i = 0; i < c; i++) // loop over c
for (int i = 0; i < c; i++) // loop over c
for (int j = 0; j < d; j++) // loop over d
{
{
int ten4_2 = i + a * c;
int ten4_2 = i + a * c;
for (int j = 0; j < d; j++) // loop over d
{
int ten4_3 = j + b * d;
int ten4_3 = j + b * d;
int z_row = b + grid_d*(a + grid_c*(s + nb_stack*n));
int z_col = j + d * i;
int z_col = j + d * i;
//printf("
\\
n(
%%
i,
%%
i,
%%
i,
%%
i) --> (
%%
i,
%%
i)",n,s, ten4_2, ten4_3, z_row, z_col);
dtype_
%(z)
s* curr_z = (dtype_
%(z)
s*) PyArray_GETPTR2(
%(z)
s, z_row, z_col);
dtype_
%(z)
s* curr_z = (dtype_
%(z)
s*) PyArray_GETPTR2(
%(z)
s, z_row, z_col);
*curr_z = *( (dtype_
%(ten4)
s*) PyArray_GETPTR4(
%(ten4)
s, n, s, ten4_2, ten4_3));
*curr_z = *( (dtype_
%(ten4)
s*) PyArray_GETPTR4(
%(ten4)
s, n, s, ten4_2, ten4_3));
//printf("
\\
n(
%%
i,
%%
i,
%%
i,
%%
i) --> (
%%
i,
%%
i)",n,s, ten4_2, ten4_3, z_row, z_col);
//printf("
%%
f ", *curr_z);
//printf("
%%
f ", *curr_z);
}
}
}
}
} // END NESTED SCOPE
} // END NESTED SCOPE
"""
%
locals
()
"""
%
locals
()
images2neibs
=
Images2Neibs
()
images2neibs
=
Images2Neibs
()
...
@@ -142,3 +149,197 @@ def neibs2images(neibs, neib_shape, original_shape):
...
@@ -142,3 +149,197 @@ def neibs2images(neibs, neib_shape, original_shape):
new_neib_shape
=
T
.
stack
(
original_shape
[
-
1
]
/
neib_shape
[
1
],
neib_shape
[
1
]
)
new_neib_shape
=
T
.
stack
(
original_shape
[
-
1
]
/
neib_shape
[
1
],
neib_shape
[
1
]
)
return
images2neibs
(
neibs
.
dimshuffle
(
'x'
,
'x'
,
0
,
1
),
new_neib_shape
)
.
reshape
(
original_shape
)
return
images2neibs
(
neibs
.
dimshuffle
(
'x'
,
'x'
,
0
,
1
),
new_neib_shape
)
.
reshape
(
original_shape
)
# This is work in progress
class
GpuImages2Neibs
(
Images2Neibs
):
def
make_node
(
self
,
ten4
,
neib_shape
):
assert
ten4
.
dtype
==
'float32'
assert
neib_shape
.
dtype
==
'float32'
if
not
isinstance
(
ten4
.
type
,
CudaNdarrayType
):
raise
TypeError
(
'pvals must be cudandarray'
,
ten4
)
if
not
isinstance
(
neib_shape
.
type
,
CudaNdarrayType
):
raise
TypeError
(
'unis must be cudandarray'
,
neib_shape
)
return
Apply
(
self
,
[
ten4
,
neib_shape
],
[
CudaNdarrayType
(
broadcastable
=
(
false
,)
*
2
)()])
def
c_code_cache_version
(
self
):
return
()
#return (1,)
def
c_support_code_apply
(
self
,
node
,
nodename
):
return
"""
static __global__ void k_multi_warp_
%(nodename)
s(
const int nb_batch,
const int nb_stack,
const int height,
const int width,
const int c,
const int d,
const int grid_c,
const int grid_d,
float * global_ten4,
float * global_out
)
{
int n = 32*blockIdx.x + threadIdx.x;
if (n < nb_batch)
for (int s = 0; s < nb_stack; s++) // loop over stacks
for (int a = 0; a < grid_c; a++) // loop over height/c
for (int b = 0; b < grid_d; b++) // loop over width/d
{
int z_row = b + grid_d*(a + grid_c*(s + nb_stack*n));
for (int i = 0; i < c; i++) // loop over c
{
int ten4_2 = i + a * c;
for (int j = 0; j < d; j++) // loop over d
{
int ten4_3 = j + b * d;
int ten4_idx = ten4_3 + width*(ten4_2 + height*(s +nb_stack*n));
int z_col = j + d * i;
int z_idx = z_col + c*d*z_row;
global_out[z_idx] = global_ten4[ten4_idx];
}
}
}
}
"""
%
locals
()
def
c_code
(
self
,
node
,
name
,
(
ten4
,
neib_shape
),
(
z
,),
sub
):
fail
=
sub
[
'fail'
]
return
"""
{
if (
%(ten4)
s->nd != 4)
{
PyErr_Format(PyExc_TypeError, "pvals wrong rank");
%(fail)
s;
}
if (
%(neib_shape)
s->nd != 1)
{
PyErr_Format(PyExc_TypeError, "unis wrong rank");
%(fail)
s;
}
if (CudaNdarray_HOST_DIMS(
%(neib_shape)
s)[0] != 2)
{
PyErr_Format(PyExc_ValueError, "neib_shape has to contain two elements");
%(fail)
s;
}
if (!CudaNdarray_is_c_contiguous(
%(neib_shape)
s))
{
PyErr_Format(PyExc_NotImplementedError, "require unis to be contiguous");
%(fail)
s;
}
if (!CudaNdarray_is_c_contiguous(
%(ten4)
s))
{
PyErr_Format(PyExc_NotImplementedError, "require ten4 to be contiguous");
%(fail)
s;
}
const float * cd = CudaNdarray_DEV_DATA(
%(neib_shape)
s);
const int c = (int) cd[0];
const int d = (int) cd[1];
if ( CudaNdarray_HOST_DIMS(
%(ten4)
s)[2]
%%
c != 0)
{
PyErr_Format(PyExc_TypeError, "neib_shape[0] must divide ten4.shape[2]");
%(fail)
s;
}
if ( CudaNdarray_HOST_DIMS(
%(ten4)
s)[3]
%%
d != 0)
{
PyErr_Format(PyExc_TypeError, "neib_shape[1] must divide ten4.shape[3]");
%(fail)
s;
}
// new dimensions for z
const int z_dim1 = c * d;
const int z_dim0 = CudaNdarray_HOST_DIMS(
%(ten4)
s)[2] / c
* CudaNdarray_HOST_DIMS(
%(ten4)
s)[3] / d
* CudaNdarray_HOST_DIMS(
%(ten4)
s)[1]
* CudaNdarray_HOST_DIMS(
%(ten4)
s)[0];
if ((NULL ==
%(z)
s)
|| (CudaNdarray_HOST_DIMS(
%(z)
s)[0] != z_dim0)
|| (CudaNdarray_HOST_DIMS(
%(z)
s)[1] != z_dim1))
{
Py_XDECREF(
%(z)
s);
npy_intp dims[2];
dims[0] = z_dim0;
dims[1] = z_dim1;
%(z)
s = (CudaNdarray*)CudaNdarray_NewDims(2, dims);
if (!
%(z)
s)
{
PyErr_SetString(PyExc_MemoryError, "failed to alloc z output");
%(fail)
s;
}
}
}
{ // NESTED SCOPE
const int nb_batch = CudaNdarray_HOST_DIMS(
%(ten4)
s)[0];
const int nb_stack = CudaNdarray_HOST_DIMS(
%(ten4)
s)[1];
const int height = CudaNdarray_HOST_DIMS(
%(ten4)
s)[2];
const int width = CudaNdarray_HOST_DIMS(
%(ten4)
s)[3];
// (c,d) = neib_shape
const float * cd = CudaNdarray_DEV_DATA(
%(neib_shape)
s);
const int c = (int) cd[0];
const int d = (int) cd[1];
const int grid_c = height/c;
const int grid_d = width/d;
int nb_block;
if (nb_batch
%% 32
== 0)
nb_block = nb_batch/32;
else
nb_block = (int)((float)nb_batch/32. + 1.);
dim3 n_blocks(nb_block,1,1);
dim3 n_threads(32,1,1);
int n_shared = 0;
k_multi_warp_
%(name)
s<<<n_blocks, n_threads, n_shared>>>(
nb_batch,
nb_stack,
height, width,
c, d,
grid_c, grid_d,
CudaNdarray_DEV_DATA(
%(ten4)
s),
CudaNdarray_DEV_DATA(
%(z)
s)
);
CNDA_THREAD_SYNC;
cudaError_t sts = cudaGetLastError();
if (cudaSuccess != sts)
{
PyErr_Format(PyExc_RuntimeError, "Cuda error:
%%
s:
%%
s. (grid:
%%
i x
%%
i; block:
%%
i x
%%
i x
%%
i; shared:
%%
i)
\\
n",
"k_multi_warp_
%(name)
s",
cudaGetErrorString(sts),
n_blocks.x,
n_blocks.y,
n_threads.x,
n_threads.y,
n_threads.z,
n_shared);
%(fail)
s;
}
} // END NESTED SCOPE
"""
%
locals
()
gpu_images2neibs
=
GpuImages2Neibs
()
@local_optimizer
()
def
use_gpu_images2neibs
(
node
):
if
node
.
op
==
images2neibs
:
return
[
host_from_gpu
(
gpu_images2neibs
(
*
[
gpu_from_host
(
i
)
for
i
in
node
.
inputs
]))]
if
theano
.
config
.
device
.
startswith
(
'gpu'
):
register_specialize
(
use_gpu_images2neibs
)
\ No newline at end of file
theano/sandbox/test_neighbours.py
浏览文件 @
21d76325
...
@@ -18,3 +18,5 @@ def neibs_test():
...
@@ -18,3 +18,5 @@ def neibs_test():
print
g
()
print
g
()
assert
allclose
(
images
.
value
,
g
())
assert
allclose
(
images
.
value
,
g
())
neibs_test
()
\ No newline at end of file
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论