Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
fd75d2c2
提交
fd75d2c2
authored
12月 16, 2015
作者:
Balázs Hidasi
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
GPU implementation for GpuAdvancedIncSubtensor1_dev20 using atomicExch()
上级
725b7a3f
隐藏空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
16 行增加
和
8 行删除
+16
-8
basic_ops.py
theano/sandbox/cuda/basic_ops.py
+16
-8
没有找到文件。
theano/sandbox/cuda/basic_ops.py
浏览文件 @
fd75d2c2
...
@@ -3035,8 +3035,7 @@ class GpuAdvancedIncSubtensor1_dev20(GpuAdvancedIncSubtensor1):
...
@@ -3035,8 +3035,7 @@ class GpuAdvancedIncSubtensor1_dev20(GpuAdvancedIncSubtensor1):
def
c_code
(
self
,
node
,
name
,
inputs
,
outputs
,
sub
):
def
c_code
(
self
,
node
,
name
,
inputs
,
outputs
,
sub
):
active_device_no
=
theano
.
sandbox
.
cuda
.
active_device_number
()
active_device_no
=
theano
.
sandbox
.
cuda
.
active_device_number
()
compute_capability
=
device_properties
(
active_device_no
)[
'major'
]
compute_capability
=
device_properties
(
active_device_no
)[
'major'
]
if
((
self
.
set_instead_of_inc
)
or
if
((
node
.
inputs
[
0
]
.
ndim
!=
node
.
inputs
[
1
]
.
ndim
)
or
(
node
.
inputs
[
0
]
.
ndim
!=
node
.
inputs
[
1
]
.
ndim
)
or
(
node
.
inputs
[
0
]
.
ndim
!=
2
)
or
(
node
.
inputs
[
0
]
.
ndim
!=
2
)
or
(
compute_capability
<
2
)):
(
compute_capability
<
2
)):
raise
NotImplementedError
(
"This case does not have C code yet."
)
raise
NotImplementedError
(
"This case does not have C code yet."
)
...
@@ -3047,6 +3046,7 @@ class GpuAdvancedIncSubtensor1_dev20(GpuAdvancedIncSubtensor1):
...
@@ -3047,6 +3046,7 @@ class GpuAdvancedIncSubtensor1_dev20(GpuAdvancedIncSubtensor1):
out
=
outputs
[
0
]
out
=
outputs
[
0
]
fail
=
sub
[
'fail'
]
fail
=
sub
[
'fail'
]
inplace
=
int
(
self
.
inplace
)
inplace
=
int
(
self
.
inplace
)
set_instead_of_inc
=
int
(
self
.
set_instead_of_inc
)
return
"""
return
"""
Py_XDECREF(
%(out)
s);
Py_XDECREF(
%(out)
s);
if (!
%(inplace)
s) {
if (!
%(inplace)
s) {
...
@@ -3056,7 +3056,7 @@ class GpuAdvancedIncSubtensor1_dev20(GpuAdvancedIncSubtensor1):
...
@@ -3056,7 +3056,7 @@ class GpuAdvancedIncSubtensor1_dev20(GpuAdvancedIncSubtensor1):
Py_XINCREF(
%(out)
s);
Py_XINCREF(
%(out)
s);
}
}
if (CudaNdarray_vector_add_
fast(
%(out)
s,
%(y)
s,
%(ind
)
s) != 0){
if (CudaNdarray_vector_add_
or_replace_fast(
%(out)
s,
%(y)
s,
%(ind)
s,
%(set_instead_of_inc
)
s) != 0){
%(fail)
s
%(fail)
s
}
}
...
@@ -3068,7 +3068,7 @@ class GpuAdvancedIncSubtensor1_dev20(GpuAdvancedIncSubtensor1):
...
@@ -3068,7 +3068,7 @@ class GpuAdvancedIncSubtensor1_dev20(GpuAdvancedIncSubtensor1):
def
c_support_code_apply
(
self
,
node
,
nodename
):
def
c_support_code_apply
(
self
,
node
,
nodename
):
return
"""
return
"""
__global__ void k_vector_add_fast(int numRowsX,
__global__ void k_vector_add_
or_replace_
fast(int numRowsX,
int numColsX,
int numColsX,
int stridesX0,
int stridesX0,
int stridesX1,
int stridesX1,
...
@@ -3080,6 +3080,7 @@ class GpuAdvancedIncSubtensor1_dev20(GpuAdvancedIncSubtensor1):
...
@@ -3080,6 +3080,7 @@ class GpuAdvancedIncSubtensor1_dev20(GpuAdvancedIncSubtensor1):
float *Y ,
float *Y ,
long *d_indices_arr,
long *d_indices_arr,
int num,
int num,
const int set_instead_of_inc,
int* err)
int* err)
{
{
for (int i = (blockIdx.x); i < num; i += gridDim.x)
for (int i = (blockIdx.x); i < num; i += gridDim.x)
...
@@ -3091,8 +3092,13 @@ class GpuAdvancedIncSubtensor1_dev20(GpuAdvancedIncSubtensor1):
...
@@ -3091,8 +3092,13 @@ class GpuAdvancedIncSubtensor1_dev20(GpuAdvancedIncSubtensor1):
x_row += numRowsX;
x_row += numRowsX;
int y_row = i;
int y_row = i;
if(x_row < numRowsX && x_row >= 0){
if(x_row < numRowsX && x_row >= 0){
atomicAdd(&X[(x_row * stridesX0) + (j * stridesX1)],
if(set_instead_of_inc){
atomicExch(&X[(x_row * stridesX0) + (j * stridesX1)],
Y[(y_row * stridesY0) + (j * stridesY1)]);
Y[(y_row * stridesY0) + (j * stridesY1)]);
} else{
atomicAdd(&X[(x_row * stridesX0) + (j * stridesX1)],
Y[(y_row * stridesY0) + (j * stridesY1)]);
}
} else {
} else {
*err = 1;
*err = 1;
}
}
...
@@ -3101,8 +3107,9 @@ class GpuAdvancedIncSubtensor1_dev20(GpuAdvancedIncSubtensor1):
...
@@ -3101,8 +3107,9 @@ class GpuAdvancedIncSubtensor1_dev20(GpuAdvancedIncSubtensor1):
return;
return;
}
}
int CudaNdarray_vector_add_fast(CudaNdarray* py_self,
int CudaNdarray_vector_add_or_replace_fast(CudaNdarray* py_self,
CudaNdarray* py_other, PyArrayObject *indices_arr)
CudaNdarray* py_other, PyArrayObject *indices_arr,
const int set_instead_of_inc)
{
{
if(init_err_var()!= 0) return -1;
if(init_err_var()!= 0) return -1;
...
@@ -3144,7 +3151,7 @@ class GpuAdvancedIncSubtensor1_dev20(GpuAdvancedIncSubtensor1):
...
@@ -3144,7 +3151,7 @@ class GpuAdvancedIncSubtensor1_dev20(GpuAdvancedIncSubtensor1):
return -1;
return -1;
}
}
k_vector_add_fast<<<n_blocks, n_threads>>>(
k_vector_add_
or_replace_
fast<<<n_blocks, n_threads>>>(
shapeX[0],
shapeX[0],
shapeX[1],
shapeX[1],
strX[0],
strX[0],
...
@@ -3157,6 +3164,7 @@ class GpuAdvancedIncSubtensor1_dev20(GpuAdvancedIncSubtensor1):
...
@@ -3157,6 +3164,7 @@ class GpuAdvancedIncSubtensor1_dev20(GpuAdvancedIncSubtensor1):
CudaNdarray_DEV_DATA(py_other),
CudaNdarray_DEV_DATA(py_other),
d_indices_arr,
d_indices_arr,
PyArray_SIZE(indices_arr),
PyArray_SIZE(indices_arr),
set_instead_of_inc,
err_var
err_var
);
);
int index_err = check_err_var();
int index_err = check_err_var();
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论