Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
45f4b7a9
提交
45f4b7a9
authored
4月 10, 2014
作者:
f0k
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Added gpu.release_gil config flag.
上级
ce1eeab9
隐藏空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
39 行增加
和
13 行删除
+39
-13
cuda_ndarray.cu
theano/sandbox/cuda/cuda_ndarray.cu
+23
-13
cuda_ndarray.cuh
theano/sandbox/cuda/cuda_ndarray.cuh
+8
-0
nvcc_compiler.py
theano/sandbox/cuda/nvcc_compiler.py
+8
-0
没有找到文件。
theano/sandbox/cuda/cuda_ndarray.cu
浏览文件 @
45f4b7a9
...
...
@@ -200,7 +200,9 @@ int device_free(void *ptr)
// We need sync as the Theano's GC could remove intermediate variable that
// are still needed as the gpu kernel are running or in the queue.
CNDA_BEGIN_ALLOW_THREADS
cudaThreadSynchronize
();
CNDA_END_ALLOW_THREADS
cudaError_t
err
=
cudaFree
(
ptr
);
if
(
cudaSuccess
!=
err
)
...
...
@@ -518,10 +520,14 @@ PyObject * CudaNdarray_CreateArrayObj(CudaNdarray * self, PyObject *args)
assert
(
PyArray_ITEMSIZE
(
rval
)
==
sizeof
(
real
));
cublasGetVector
(
PyArray_SIZE
(
rval
),
sizeof
(
real
),
npy_intp
rval_size
=
PyArray_SIZE
(
rval
);
void
*
rval_data
=
PyArray_DATA
(
rval
);
CNDA_BEGIN_ALLOW_THREADS
cublasGetVector
(
rval_size
,
sizeof
(
real
),
contiguous_self
->
devdata
,
1
,
PyArray_DATA
(
rval
),
1
);
CNDA_THREAD_SYNC
;
rval_data
,
1
);
//CNDA_THREAD_SYNC; // unneeded because cublasGetVector is blocking anyway
CNDA_END_ALLOW_THREADS
if
(
CUBLAS_STATUS_SUCCESS
!=
cublasGetError
())
{
...
...
@@ -1217,14 +1223,12 @@ CudaNdarray_TakeFrom(CudaNdarray * self, PyObject *args){
//-10 could be any value different then 0.
int
cpu_err_var
=-
10
;
// We are not 100% sure that cudaMemcpy wait that the async gpu kernel are
// finished before doing the transfer. So we add this explicit sync as it
// is pretty fast. In a python loop, I ran 1 000 000 call in 1 second.
// It is better to be safe and not significatively slower than unsafe.
cudaThreadSynchronize
();
CNDA_BEGIN_ALLOW_THREADS
// As we execute cudaMemcpy on the default stream, it waits for all
// kernels (on all streams) to be finished before starting to copy
err
=
cudaMemcpy
(
&
cpu_err_var
,
err_var
,
sizeof
(
int
),
cudaMemcpyDeviceToHost
);
CNDA_END_ALLOW_THREADS
if
(
cudaSuccess
!=
err
)
{
PyErr_Format
(
PyExc_RuntimeError
,
...
...
@@ -2838,7 +2842,9 @@ GetDeviceMemInfo(PyObject* _unused, PyObject* dummy)
PyObject
*
CudaNdarray_synchronize
(
PyObject
*
_unused
,
PyObject
*
dummy
)
{
CNDA_BEGIN_ALLOW_THREADS
cudaThreadSynchronize
();
CNDA_END_ALLOW_THREADS
Py_INCREF
(
Py_None
);
return
Py_None
;
}
...
...
@@ -3554,11 +3560,15 @@ CudaNdarray_CopyFromArray(CudaNdarray * self, PyArrayObject*obj)
if
(
!
py_src
)
{
return
-
1
;
}
cublasSetVector
(
PyArray_SIZE
(
py_src
),
npy_intp
py_src_size
=
PyArray_SIZE
(
py_src
);
void
*
py_src_data
=
PyArray_DATA
(
py_src
);
CNDA_BEGIN_ALLOW_THREADS
cublasSetVector
(
py_src_size
,
sizeof
(
real
),
PyArray_DATA
(
py_src
)
,
1
,
py_src_data
,
1
,
self
->
devdata
,
1
);
CNDA_THREAD_SYNC
;
//CNDA_THREAD_SYNC; // unneeded because cublasSetVector is blocking anyway
CNDA_END_ALLOW_THREADS
if
(
CUBLAS_STATUS_SUCCESS
!=
cublasGetError
())
{
PyErr_SetString
(
PyExc_RuntimeError
,
"error copying data to device memory"
);
...
...
@@ -4952,7 +4962,7 @@ cnda_copy_structure_to_device(const CudaNdarray * self)
1
,
self
->
dev_structure
,
1
);
CNDA_THREAD_SYNC
;
//CNDA_THREAD_SYNC; // unneeded because cublasSetVector is blocking anyway
if
(
CUBLAS_STATUS_SUCCESS
!=
cublasGetError
())
{
PyErr_SetString
(
PyExc_RuntimeError
,
"error copying structure to device memory"
);
...
...
theano/sandbox/cuda/cuda_ndarray.cuh
浏览文件 @
45f4b7a9
...
...
@@ -68,6 +68,14 @@ typedef float real;
#define CNDA_THREAD_SYNC cudaThreadSynchronize();
#endif
// Define shortcuts to implement the config.gpu.release_gil flag
#ifdef RELEASE_GIL
#define CNDA_BEGIN_ALLOW_THREADS Py_BEGIN_ALLOW_THREADS
#define CNDA_END_ALLOW_THREADS Py_END_ALLOW_THREADS
#else
#define CNDA_BEGIN_ALLOW_THREADS
#define CNDA_END_ALLOW_THREADS
#endif
#ifndef SHARED_SIZE
#define SHARED_SIZE (16*1024)
...
...
theano/sandbox/cuda/nvcc_compiler.py
浏览文件 @
45f4b7a9
...
...
@@ -30,6 +30,12 @@ AddConfigVar('nvcc.compiler_bindir',
StrParam
(
""
),
in_c_key
=
False
)
AddConfigVar
(
'gpu.release_gil'
,
"If True, theano will release the GIL when waiting for "
"GPU operations, allowing other Python threads to run"
,
BoolParam
(
False
),
in_c_key
=
True
)
user_provided_cuda_root
=
True
...
...
@@ -153,6 +159,8 @@ class NVCC_compiler(object):
flags
=
[
flag
for
flag
in
config
.
nvcc
.
flags
.
split
(
' '
)
if
flag
]
if
config
.
nvcc
.
fastmath
:
flags
.
append
(
'-use_fast_math'
)
if
config
.
gpu
.
release_gil
:
flags
.
append
(
'-DRELEASE_GIL'
)
cuda_ndarray_cuh_hash
=
hash_from_file
(
os
.
path
.
join
(
os
.
path
.
split
(
__file__
)[
0
],
'cuda_ndarray.cuh'
))
flags
.
append
(
'-DCUDA_NDARRAY_CUH='
+
cuda_ndarray_cuh_hash
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论