Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
2faeb62c
提交
2faeb62c
authored
6月 28, 2013
作者:
lamblin
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1437 from nouiz/gpu_iadd
Gpu iadd for 6d tensor
上级
cd50d5ef
9c43373e
全部展开
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
106 行增加
和
0 行删除
+106
-0
cuda_ndarray.cu
theano/sandbox/cuda/cuda_ndarray.cu
+106
-0
test_cuda_ndarray.py
theano/sandbox/cuda/tests/test_cuda_ndarray.py
+0
-0
没有找到文件。
theano/sandbox/cuda/cuda_ndarray.cu
浏览文件 @
2faeb62c
...
@@ -1389,6 +1389,45 @@ __global__ void k_ielem_4(const int d0, const int d1, const int d2, const int d3
...
@@ -1389,6 +1389,45 @@ __global__ void k_ielem_4(const int d0, const int d1, const int d2, const int d3
}
}
}
}
template
<
int
operator_num
>
__global__
void
k_ielem_6
(
const
int
d0
,
const
int
d1
,
const
int
d2
,
const
int
d3
,
const
int
d4
,
const
int
d5
,
float
*
a
,
const
int
sA0
,
const
int
sA1
,
const
int
sA2
,
const
int
sA3
,
const
int
sA4
,
const
int
sA5
,
const
float
*
b
,
const
int
sB0
,
const
int
sB1
,
const
int
sB2
,
const
int
sB3
,
const
int
sB4
,
const
int
sB5
){
for
(
int
i0
=
blockIdx
.
x
;
i0
<
d0
;
i0
+=
gridDim
.
x
){
for
(
int
i1
=
blockIdx
.
y
;
i1
<
d1
;
i1
+=
gridDim
.
y
){
for
(
int
i2
=
blockIdx
.
z
;
i2
<
d2
;
i2
+=
gridDim
.
z
){
for
(
int
i3
=
threadIdx
.
x
;
i3
<
d3
;
i3
+=
blockDim
.
x
){
for
(
int
i4
=
threadIdx
.
y
;
i4
<
d4
;
i4
+=
blockDim
.
y
){
for
(
int
i5
=
threadIdx
.
z
;
i5
<
d5
;
i5
+=
blockDim
.
z
){
switch
(
operator_num
)
{
case
IADD
:
a
[
i0
*
sA0
+
i1
*
sA1
+
i2
*
sA2
+
i3
*
sA3
+
i4
*
sA4
+
i5
*
sA5
]
+=
b
[
i0
*
sB0
+
i1
*
sB1
+
i2
*
sB2
+
i3
*
sB3
+
i4
*
sB4
+
i5
*
sB5
];
break
;
case
IDIV
:
a
[
i0
*
sA0
+
i1
*
sA1
+
i2
*
sA2
+
i3
*
sA3
+
i4
*
sA4
+
i5
*
sA5
]
/=
b
[
i0
*
sB0
+
i1
*
sB1
+
i2
*
sB2
+
i3
*
sB3
+
i4
*
sB4
+
i5
*
sB5
];
break
;
case
CPY
:
a
[
i0
*
sA0
+
i1
*
sA1
+
i2
*
sA2
+
i3
*
sA3
+
i4
*
sA4
+
i5
*
sA5
]
=
b
[
i0
*
sB0
+
i1
*
sB1
+
i2
*
sB2
+
i3
*
sB3
+
i4
*
sB4
+
i5
*
sB5
];
break
;
}
}
}
}
}
}
}
}
/*
/*
CudaNdarray_inplace_elemwise
CudaNdarray_inplace_elemwise
Compute elemwise, working inplace on A.
Compute elemwise, working inplace on A.
...
@@ -1415,19 +1454,31 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t
...
@@ -1415,19 +1454,31 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t
const
int
,
const
int
,
const
int
,
const
int
,
const
float
*
,
const
int
,
const
int
,
const
float
*
,
const
int
,
const
int
,
const
int
,
const
int
);
const
int
,
const
int
);
void
(
*
k6
)(
const
int
,
const
int
,
const
int
,
const
int
,
const
int
,
const
int
,
float
*
,
const
int
,
const
int
,
const
int
,
const
int
,
const
int
,
const
int
,
const
float
*
,
const
int
,
const
int
,
const
int
,
const
int
,
const
int
,
const
int
);
switch
(
fct_nb
)
switch
(
fct_nb
)
{
{
case
IADD
:
case
IADD
:
k3
=
k_ielem_3
<
IADD
>
;
k3
=
k_ielem_3
<
IADD
>
;
k4
=
k_ielem_4
<
IADD
>
;
k4
=
k_ielem_4
<
IADD
>
;
k6
=
k_ielem_6
<
IADD
>
;
break
;
break
;
case
IDIV
:
case
IDIV
:
k3
=
k_ielem_3
<
IDIV
>
;
k3
=
k_ielem_3
<
IDIV
>
;
k4
=
k_ielem_4
<
IDIV
>
;
k4
=
k_ielem_4
<
IDIV
>
;
k6
=
k_ielem_6
<
IDIV
>
;
break
;
break
;
case
CPY
:
case
CPY
:
k3
=
k_ielem_3
<
CPY
>
;
k3
=
k_ielem_3
<
CPY
>
;
k4
=
k_ielem_4
<
CPY
>
;
k4
=
k_ielem_4
<
CPY
>
;
k6
=
k_ielem_6
<
CPY
>
;
break
;
break
;
default
:
default
:
assert
(
0
);
assert
(
0
);
...
@@ -1769,6 +1820,61 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t
...
@@ -1769,6 +1820,61 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t
}
}
}
}
break
;
break
;
case
6
:
{
dim3
n_blocks
(
std
::
min
(
CudaNdarray_HOST_DIMS
(
self
)[
0
],
NUM_VECTOR_OP_BLOCKS
),
CudaNdarray_HOST_DIMS
(
self
)[
1
],
CudaNdarray_HOST_DIMS
(
self
)[
2
]
);
while
(
n_blocks
.
x
*
n_blocks
.
y
>
NUM_VECTOR_OP_BLOCKS
)
n_blocks
.
y
/=
2
;
while
(
n_blocks
.
x
*
n_blocks
.
y
*
n_blocks
.
z
>
NUM_VECTOR_OP_BLOCKS
)
n_blocks
.
z
/=
2
;
dim3
n_threads
(
std
::
min
(
CudaNdarray_HOST_DIMS
(
self
)[
3
],
NUM_VECTOR_OP_THREADS_PER_BLOCK
)
//TODO: DON"T YOU NEED OT PUT DIMS[4] in here???
//TODO: DON"T YOU NEED OT PUT DIMS[5] in here???
);
k6
<<<
n_blocks
,
n_threads
>>>
(
CudaNdarray_HOST_DIMS
(
self
)[
0
],
CudaNdarray_HOST_DIMS
(
self
)[
1
],
CudaNdarray_HOST_DIMS
(
self
)[
2
],
CudaNdarray_HOST_DIMS
(
self
)[
3
],
CudaNdarray_HOST_DIMS
(
self
)[
4
],
CudaNdarray_HOST_DIMS
(
self
)[
5
],
CudaNdarray_DEV_DATA
(
self
),
CudaNdarray_HOST_STRIDES
(
self
)[
0
],
CudaNdarray_HOST_STRIDES
(
self
)[
1
],
CudaNdarray_HOST_STRIDES
(
self
)[
2
],
CudaNdarray_HOST_STRIDES
(
self
)[
3
],
CudaNdarray_HOST_STRIDES
(
self
)[
4
],
CudaNdarray_HOST_STRIDES
(
self
)[
5
],
CudaNdarray_DEV_DATA
(
other
),
other_strides
[
0
],
other_strides
[
1
],
other_strides
[
2
],
other_strides
[
3
],
other_strides
[
4
],
other_strides
[
5
]);
CNDA_THREAD_SYNC
;
cudaError_t
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"Cuda error: %s: %s.
\n
"
,
"k4"
,
cudaGetErrorString
(
err
));
Py_XDECREF
(
new_other
);
return
-
1
;
}
}
break
;
default
:
default
:
{
{
PyErr_Format
(
PyErr_Format
(
...
...
theano/sandbox/cuda/tests/test_cuda_ndarray.py
浏览文件 @
2faeb62c
差异被折叠。
点击展开。
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论