Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
5f0a8939
提交
5f0a8939
authored
5月 27, 2016
作者:
Arnaud Bergeron
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add support for set_instead_of_inc to GpuAdvancedIncSubtensor1_dev20
上级
224af930
隐藏空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
36 行增加
和
8 行删除
+36
-8
subtensor.py
theano/gpuarray/subtensor.py
+36
-8
没有找到文件。
theano/gpuarray/subtensor.py
浏览文件 @
5f0a8939
...
@@ -593,7 +593,7 @@ class GpuAdvancedIncSubtensor1_dev20(GpuKernelBase, GpuAdvancedIncSubtensor1):
...
@@ -593,7 +593,7 @@ class GpuAdvancedIncSubtensor1_dev20(GpuKernelBase, GpuAdvancedIncSubtensor1):
return
super
(
GpuAdvancedIncSubtensor1_dev20
,
self
)
.
perform
(
node
,
inp
,
out
)
return
super
(
GpuAdvancedIncSubtensor1_dev20
,
self
)
.
perform
(
node
,
inp
,
out
)
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
return
(
8
,)
return
(
9
,)
def
c_headers
(
self
):
def
c_headers
(
self
):
return
[
'<numpy_compat.h>'
,
'<gpuarray_helper.h>'
,
return
[
'<numpy_compat.h>'
,
'<gpuarray_helper.h>'
,
...
@@ -606,8 +606,7 @@ class GpuAdvancedIncSubtensor1_dev20(GpuKernelBase, GpuAdvancedIncSubtensor1):
...
@@ -606,8 +606,7 @@ class GpuAdvancedIncSubtensor1_dev20(GpuKernelBase, GpuAdvancedIncSubtensor1):
ctx
=
self
.
get_params
(
node
)
ctx
=
self
.
get_params
(
node
)
if
ctx
.
kind
!=
b
'cuda'
:
if
ctx
.
kind
!=
b
'cuda'
:
raise
NotImplementedError
(
"cuda only"
)
raise
NotImplementedError
(
"cuda only"
)
if
(
self
.
set_instead_of_inc
or
if
(
node
.
inputs
[
0
]
.
ndim
!=
node
.
inputs
[
1
]
.
ndim
or
node
.
inputs
[
0
]
.
ndim
!=
node
.
inputs
[
1
]
.
ndim
or
node
.
inputs
[
0
]
.
ndim
!=
2
or
node
.
inputs
[
0
]
.
ndim
!=
2
or
ctx
.
bin_id
[
-
2
]
<
b
'2'
):
ctx
.
bin_id
[
-
2
]
<
b
'2'
):
raise
NotImplementedError
(
"This case does not have C code yet."
)
raise
NotImplementedError
(
"This case does not have C code yet."
)
...
@@ -617,6 +616,7 @@ class GpuAdvancedIncSubtensor1_dev20(GpuKernelBase, GpuAdvancedIncSubtensor1):
...
@@ -617,6 +616,7 @@ class GpuAdvancedIncSubtensor1_dev20(GpuKernelBase, GpuAdvancedIncSubtensor1):
ind
=
inputs
[
2
]
ind
=
inputs
[
2
]
out
=
outputs
[
0
]
out
=
outputs
[
0
]
fail
=
sub
[
'fail'
]
fail
=
sub
[
'fail'
]
set_instead_of_inc
=
int
(
self
.
set_instead_of_inc
)
inplace
=
int
(
self
.
inplace
)
inplace
=
int
(
self
.
inplace
)
return
"""
return
"""
int err;
int err;
...
@@ -630,7 +630,7 @@ if (%(inplace)s) {
...
@@ -630,7 +630,7 @@ if (%(inplace)s) {
if (!
%(out)
s) {
if (!
%(out)
s) {
%(fail)
s
%(fail)
s
}
}
if (GpuArray_vector_add_fast(
%(out)
s,
%(y)
s,
%(ind)
s)) {
if (GpuArray_vector_add_fast(
%(out)
s,
%(y)
s,
%(ind)
s
,
%(set_instead_of_inc)
s
)) {
%(fail)
s
%(fail)
s
}
}
"""
%
locals
()
"""
%
locals
()
...
@@ -656,7 +656,7 @@ if (GpuArray_vector_add_fast(%(out)s, %(y)s, %(ind)s)) {
...
@@ -656,7 +656,7 @@ if (GpuArray_vector_add_fast(%(out)s, %(y)s, %(ind)s)) {
* This is an atomicAdd that works for doubles since that is not provided
* This is an atomicAdd that works for doubles since that is not provided
* natively by cuda.
* natively by cuda.
*/
*/
__device__ double atomicAdd(ga_double* address, ga_double val) {
__device__
ga_
double atomicAdd(ga_double* address, ga_double val) {
unsigned long long int* address_as_ull =
unsigned long long int* address_as_ull =
(unsigned long long int*)address;
(unsigned long long int*)address;
unsigned long long int old = *address_as_ull, assumed;
unsigned long long int old = *address_as_ull, assumed;
...
@@ -669,6 +669,11 @@ __device__ double atomicAdd(ga_double* address, ga_double val) {
...
@@ -669,6 +669,11 @@ __device__ double atomicAdd(ga_double* address, ga_double val) {
return __longlong_as_double(old);
return __longlong_as_double(old);
}
}
__device__ ga_double atomicExch(ga_double *address, ga_double val) {
return atomicExch((unsigned long long int *)address,
__double_as_longlong(val));
}
/*
/*
* This is a version of atomicAdd that works for half-floats. It may
* This is a version of atomicAdd that works for half-floats. It may
* read and write 2 bytes more than the size of the array if the array
* read and write 2 bytes more than the size of the array if the array
...
@@ -693,6 +698,19 @@ __device__ ga_half atomicAdd(ga_half *addr, ga_half val) {
...
@@ -693,6 +698,19 @@ __device__ ga_half atomicAdd(ga_half *addr, ga_half val) {
((ga_size)addr & 2) ? 0x4432 : 0x4410);
((ga_size)addr & 2) ? 0x4432 : 0x4410);
}
}
__device__ ga_half atomicExch(ga_half *addr, ga_half val) {
ga_uint *base = (ga_uint *)((ga_size)addr & ~2);
ga_uint old, assumed, new_;
old = *base;
do {
assumed = old;
new_ = __byte_perm(old, val, ((ga_size)addr & 2) ? 0x5410 : 0x3254);
old = atomicCAS(base, assumed, new_);
} while (assumed != old);
return (ga_half)__byte_perm(old, 0,
((ga_size)addr & 2) ? 0x4432 : 0x4410);
}
KERNEL void k_vector_add_fast(const ga_size numRowsX,
KERNEL void k_vector_add_fast(const ga_size numRowsX,
const ga_size numColsX,
const ga_size numColsX,
const ga_ssize stridesX0,
const ga_ssize stridesX0,
...
@@ -709,6 +727,7 @@ __device__ ga_half atomicAdd(ga_half *addr, ga_half val) {
...
@@ -709,6 +727,7 @@ __device__ ga_half atomicAdd(ga_half *addr, ga_half val) {
const ga_ssize stridesIndices,
const ga_ssize stridesIndices,
%(type_ind)
s *indices_arr,
%(type_ind)
s *indices_arr,
const ga_size offset_indices_arr,
const ga_size offset_indices_arr,
const int set_instead_of_inc,
ga_int *err)
ga_int *err)
{
{
X = (
%(type_x)
s *)(((char *)X)+offset_X);
X = (
%(type_x)
s *)(((char *)X)+offset_X);
...
@@ -723,7 +742,13 @@ __device__ ga_half atomicAdd(ga_half *addr, ga_half val) {
...
@@ -723,7 +742,13 @@ __device__ ga_half atomicAdd(ga_half *addr, ga_half val) {
x_row += numRowsX;
x_row += numRowsX;
ga_ssize y_row = i;
ga_ssize y_row = i;
if (x_row < numRowsX && x_row >= 0) {
if (x_row < numRowsX && x_row >= 0) {
atomicAdd(&X[(x_row * stridesX0) + (j * stridesX1)], Y[(y_row * stridesY0) + (j * stridesY1)]);
if (set_instead_of_inc) {
atomicExch(&X[(x_row * stridesX0) + (j * stridesX1)],
Y[(y_row * stridesY0) + (j * stridesY1)]);
} else {
atomicAdd(&X[(x_row * stridesX0) + (j * stridesX1)],
Y[(y_row * stridesY0) + (j * stridesY1)]);
}
} else {
} else {
*err = 1;
*err = 1;
}
}
...
@@ -735,7 +760,8 @@ __device__ ga_half atomicAdd(ga_half *addr, ga_half val) {
...
@@ -735,7 +760,8 @@ __device__ ga_half atomicAdd(ga_half *addr, ga_half val) {
params
=
[
params
=
[
'uintp'
,
'uintp'
,
'intp'
,
'intp'
,
gpuarray
.
GpuArray
,
'uintp'
,
'uintp'
,
'uintp'
,
'intp'
,
'intp'
,
gpuarray
.
GpuArray
,
'uintp'
,
'uintp'
,
'uintp'
,
'intp'
,
'intp'
,
gpuarray
.
GpuArray
,
'uintp'
,
'uintp'
,
'uintp'
,
'intp'
,
'intp'
,
gpuarray
.
GpuArray
,
'uintp'
,
'uintp'
,
'intp'
,
gpuarray
.
GpuArray
,
'uintp'
,
gpuarray
.
GpuArray
]
'uintp'
,
'intp'
,
gpuarray
.
GpuArray
,
'uintp'
,
'int'
,
gpuarray
.
GpuArray
]
return
[
Kernel
(
code
=
code
,
name
=
kname
,
params
=
params
,
return
[
Kernel
(
code
=
code
,
name
=
kname
,
params
=
params
,
flags
=
flags
,
objvar
=
k_var
)]
flags
=
flags
,
objvar
=
k_var
)]
...
@@ -753,7 +779,8 @@ __device__ ga_half atomicAdd(ga_half *addr, ga_half val) {
...
@@ -753,7 +779,8 @@ __device__ ga_half atomicAdd(ga_half *addr, ga_half val) {
return
super
(
GpuAdvancedIncSubtensor1_dev20
,
self
)
.
c_support_code_struct
(
node
,
nodename
)
+
"""
return
super
(
GpuAdvancedIncSubtensor1_dev20
,
self
)
.
c_support_code_struct
(
node
,
nodename
)
+
"""
int GpuArray_vector_add_fast(PyGpuArrayObject* py_self,
int GpuArray_vector_add_fast(PyGpuArrayObject* py_self,
PyGpuArrayObject* py_other,
PyGpuArrayObject* py_other,
PyGpuArrayObject *indices_arr)
PyGpuArrayObject *indices_arr,
const int set_instead_of_inc)
{
{
size_t threads_per_block[3] = {std::min(PyGpuArray_DIMS(py_self)[1], (size_t)256), 1, 1};
size_t threads_per_block[3] = {std::min(PyGpuArray_DIMS(py_self)[1], (size_t)256), 1, 1};
size_t n_blocks[3] = {std::min(PyGpuArray_SIZE(indices_arr), (size_t)4096), 1, 1};
size_t n_blocks[3] = {std::min(PyGpuArray_SIZE(indices_arr), (size_t)4096), 1, 1};
...
@@ -789,6 +816,7 @@ __device__ ga_half atomicAdd(ga_half *addr, ga_half val) {
...
@@ -789,6 +816,7 @@ __device__ ga_half atomicAdd(ga_half *addr, ga_half val) {
(void *)&stride_ind,
(void *)&stride_ind,
(void *)indices_arr->ga.data,
(void *)indices_arr->ga.data,
(void *)&indices_arr->ga.offset,
(void *)&indices_arr->ga.offset,
(void *)&set_instead_of_inc,
(void *)errbuf};
(void *)errbuf};
err = GpuKernel_call(&
%(k_var)
s, 3, threads_per_block, n_blocks, 0, kernel_params);
err = GpuKernel_call(&
%(k_var)
s, 3, threads_per_block, n_blocks, 0, kernel_params);
if (err != GA_NO_ERROR) {
if (err != GA_NO_ERROR) {
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论