Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
f441295b
提交
f441295b
authored
9月 20, 2012
作者:
Frederic
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
postpone the creating of the device structure to when we need it.
This is a significant speed up with the gc as most of the time, we don't need it and allocating on the GPU is slow.
上级
c50a2db1
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
29 行增加
和
16 行删除
+29
-16
cuda_ndarray.cu
theano/sandbox/cuda/cuda_ndarray.cu
+24
-1
cuda_ndarray.cuh
theano/sandbox/cuda/cuda_ndarray.cuh
+5
-15
没有找到文件。
theano/sandbox/cuda/cuda_ndarray.cu
浏览文件 @
f441295b
...
...
@@ -4385,7 +4385,30 @@ CudaNdarray_Equal(CudaNdarray *cnda1, CudaNdarray *cnda2)
int
cnda_copy_structure_to_device
(
const
CudaNdarray
*
self
)
{
cublasSetVector
(
cnda_structure_size
(
self
->
nd
),
sizeof
(
int
),
self
->
host_structure
,
1
,
self
->
dev_structure
,
1
);
//If the device structure do not exists, create it.
//We allocate it here as we do not need it often.
//In fact, we need it so infrequently that we expect
//that most object won't need it. Not allocating it
//save a significant when creating object.
//This speed up a benchmark by 8% with the gc.
if
(
!
self
->
dev_structure
)
{
int
struct_size
=
cnda_structure_size
(
self
->
nd
);
if
(
struct_size
)
{
self
->
dev_structure
=
(
int
*
)
device_malloc
(
struct_size
*
sizeof
(
int
));
if
(
NULL
==
self
->
dev_structure
)
{
return
-
1
;
}
}
}
cublasSetVector
(
cnda_structure_size
(
self
->
nd
),
sizeof
(
int
),
self
->
host_structure
,
1
,
self
->
dev_structure
,
1
);
CNDA_THREAD_SYNC
;
if
(
CUBLAS_STATUS_SUCCESS
!=
cublasGetError
())
{
...
...
theano/sandbox/cuda/cuda_ndarray.cuh
浏览文件 @
f441295b
...
...
@@ -82,8 +82,9 @@ struct CudaNdarray
//device pointers (allocated by cudaMalloc)
mutable
int
dev_structure_fresh
;
//dev_structure should be accessed via macros, otherwise may not be synchronized
int
*
dev_structure
;
//dim0, dim1, ..., stride0, stride1, ...
//dev_structure should be accessed via macros, otherwise may not be
//synchronized. The macro will allocate it when needed.
mutable
int
*
dev_structure
;
//dim0, dim1, ..., stride0, stride1, ...
real
*
devdata
;
//pointer to data element [0,..,0].
};
...
...
@@ -251,19 +252,8 @@ CudaNdarray_set_nd(CudaNdarray * self, const int nd)
{
self
->
host_structure
[
i
]
=
0
;
}
int
struct_size
=
cnda_structure_size
(
nd
);
if
(
struct_size
)
{
self
->
dev_structure
=
(
int
*
)
device_malloc
(
struct_size
*
sizeof
(
int
));
if
(
NULL
==
self
->
dev_structure
)
{
free
(
self
->
host_structure
);
self
->
host_structure
=
NULL
;
self
->
dev_structure
=
NULL
;
return
-
1
;
}
}
//The device structure will be created in cnda_copy_structure_to_device
//if needed.
self
->
nd
=
nd
;
self
->
dev_structure_fresh
=
0
;
}
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论