Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
1c5f42e0
提交
1c5f42e0
authored
2月 07, 2013
作者:
David Warde-Farley
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Revert changes to device_malloc calls; overload instead.
上级
629c173b
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
11 行增加
和
4 行删除
+11
-4
cuda_ndarray.cu
theano/sandbox/cuda/cuda_ndarray.cu
+9
-3
cuda_ndarray.cuh
theano/sandbox/cuda/cuda_ndarray.cuh
+2
-1
没有找到文件。
theano/sandbox/cuda/cuda_ndarray.cu
浏览文件 @
1c5f42e0
...
...
@@ -53,6 +53,12 @@ struct table_struct{
};
table_struct
_alloc_size_table
[
TABLE_SIZE
];
#endif
void
*
device_malloc
(
size_t
size
)
{
return
device_malloc
(
size
,
VERBOSE_DEVICE_MALLOC
);
}
void
*
device_malloc
(
size_t
size
,
int
verbose
)
{
void
*
rval
=
NULL
;
...
...
@@ -962,7 +968,7 @@ CudaNdarray_TakeFrom(CudaNdarray * self, PyObject *args){
// Create the memory place that will store the error information.
if
(
err_var
==
NULL
)
{
err_var
=
(
int
*
)
device_malloc
(
sizeof
(
int
)
,
VERBOSE_DEVICE_MALLOC
);
err_var
=
(
int
*
)
device_malloc
(
sizeof
(
int
));
if
(
!
err_var
)
{
// PyErr set by device_malloc
Py_DECREF
(
indices
);
Py_DECREF
(
out
);
...
...
@@ -2628,7 +2634,7 @@ static __global__ void get_gpu_ptr_size(int* dst)
PyObject
*
CudaNdarray_ptr_int_size
(
PyObject
*
_unused
,
PyObject
*
args
)
{
int
*
gpu_data
=
(
int
*
)
device_malloc
(
sizeof
(
int
)
*
2
,
VERBOSE_DEVICE_MALLOC
);
int
*
gpu_data
=
(
int
*
)
device_malloc
(
sizeof
(
int
)
*
2
);
if
(
gpu_data
==
NULL
){
return
PyErr_Format
(
PyExc_MemoryError
,
"CudaNdarray_ptr_int_size: Can't allocate memory on the gpu."
);
...
...
@@ -4524,7 +4530,7 @@ cnda_copy_structure_to_device(const CudaNdarray * self)
int
struct_size
=
cnda_structure_size
(
self
->
nd
);
if
(
struct_size
)
{
self
->
dev_structure
=
(
int
*
)
device_malloc
(
struct_size
*
sizeof
(
int
)
,
VERBOSE_DEVICE_MALLOC
);
self
->
dev_structure
=
(
int
*
)
device_malloc
(
struct_size
*
sizeof
(
int
));
if
(
NULL
==
self
->
dev_structure
)
{
return
-
1
;
...
...
theano/sandbox/cuda/cuda_ndarray.cuh
浏览文件 @
1c5f42e0
...
...
@@ -51,6 +51,7 @@ typedef float real;
* device_malloc will set the Python error message before returning None.
* device_free will return nonzero on failure (after setting the python error message)
*/
DllExport
void
*
device_malloc
(
size_t
size
);
DllExport
void
*
device_malloc
(
size_t
size
,
int
verbose
);
DllExport
int
device_free
(
void
*
ptr
);
...
...
@@ -338,7 +339,7 @@ static int CudaNdarray_alloc_contiguous(CudaNdarray *self, const int nd, const i
return
-
1
;
}
self
->
devdata
=
(
float
*
)
device_malloc
(
size
*
sizeof
(
real
)
,
VERBOSE_DEVICE_MALLOC
);
self
->
devdata
=
(
float
*
)
device_malloc
(
size
*
sizeof
(
real
));
if
(
size
&&
!
self
->
devdata
)
{
CudaNdarray_set_nd
(
self
,
-
1
);
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论