Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
54efc47a
提交
54efc47a
authored
7月 14, 2014
作者:
Arnaud Bergeron
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Switch over to the v2 API.
上级
fbd201b0
显示空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
34 行增加
和
39 行删除
+34
-39
cuda_ndarray.cu
theano/sandbox/cuda/cuda_ndarray.cu
+30
-38
cuda_ndarray.cuh
theano/sandbox/cuda/cuda_ndarray.cuh
+4
-1
没有找到文件。
theano/sandbox/cuda/cuda_ndarray.cu
浏览文件 @
54efc47a
...
@@ -41,6 +41,8 @@
...
@@ -41,6 +41,8 @@
#define CNDA_END_ALLOW_THREADS
#define CNDA_END_ALLOW_THREADS
#endif
#endif
cublasHandle_t
handle
;
/////////////////////////
/////////////////////////
// Alloc and Free
// Alloc and Free
/////////////////////////
/////////////////////////
...
@@ -3538,19 +3540,25 @@ CudaNdarray_New(int nd)
...
@@ -3538,19 +3540,25 @@ CudaNdarray_New(int nd)
int
int
cublas_init
()
cublas_init
()
{
{
cublasInit
();
if
(
CUBLAS_STATUS_SUCCESS
!=
cublasCreate
(
&
handle
))
if
(
CUBLAS_STATUS_SUCCESS
!=
cublasGetError
())
{
{
PyErr_SetString
(
PyExc_RuntimeError
,
"error initializing device"
);
PyErr_SetString
(
PyExc_RuntimeError
,
"error initializing device"
);
return
-
1
;
return
-
1
;
}
}
// Set the default stream as the one to execute on (default)
cublasSetStream
(
handle
,
NULL
);
// Pointer to scalars are on the host (also default)
cublasSetPointerMode
(
handle
,
CUBLAS_POINTER_MODE_HOST
);
// atomics can be used in kernels to speed up operations (not default)
// This may lead to a slight variance from run to run in some operations
cublasSetAtomicsMode
(
handle
,
CUBLAS_ATOMICS_ALLOWED
);
return
0
;
return
0
;
}
}
int
int
cublas_shutdown
()
cublas_shutdown
()
{
{
cublasShutdown
();
if
(
CUBLAS_STATUS_SUCCESS
!=
cublasDestroy
(
handle
))
if
(
CUBLAS_STATUS_SUCCESS
!=
cublasGetError
())
{
{
PyErr_SetString
(
PyExc_RuntimeError
,
"error shutting down device"
);
PyErr_SetString
(
PyExc_RuntimeError
,
"error shutting down device"
);
return
-
1
;
return
-
1
;
...
@@ -3579,14 +3587,15 @@ CudaNdarray_CopyFromArray(CudaNdarray * self, PyArrayObject*obj)
...
@@ -3579,14 +3587,15 @@ CudaNdarray_CopyFromArray(CudaNdarray * self, PyArrayObject*obj)
}
}
npy_intp
py_src_size
=
PyArray_SIZE
(
py_src
);
npy_intp
py_src_size
=
PyArray_SIZE
(
py_src
);
void
*
py_src_data
=
PyArray_DATA
(
py_src
);
void
*
py_src_data
=
PyArray_DATA
(
py_src
);
cublasStatus_t
err
;
CNDA_BEGIN_ALLOW_THREADS
CNDA_BEGIN_ALLOW_THREADS
cublasSetVector
(
py_src_size
,
err
=
cublasSetVector
(
py_src_size
,
sizeof
(
real
),
sizeof
(
real
),
py_src_data
,
1
,
py_src_data
,
1
,
self
->
devdata
,
1
);
self
->
devdata
,
1
);
//CNDA_THREAD_SYNC; // unneeded because cublasSetVector is blocking anyway
//CNDA_THREAD_SYNC; // unneeded because cublasSetVector is blocking anyway
CNDA_END_ALLOW_THREADS
CNDA_END_ALLOW_THREADS
if
(
CUBLAS_STATUS_SUCCESS
!=
cublasGetError
()
)
if
(
CUBLAS_STATUS_SUCCESS
!=
err
)
{
{
PyErr_SetString
(
PyExc_RuntimeError
,
"error copying data to device memory"
);
PyErr_SetString
(
PyExc_RuntimeError
,
"error copying data to device memory"
);
Py_DECREF
(
py_src
);
Py_DECREF
(
py_src
);
...
@@ -3750,11 +3759,12 @@ int CudaNdarray_CopyFromCudaNdarray(CudaNdarray * self,
...
@@ -3750,11 +3759,12 @@ int CudaNdarray_CopyFromCudaNdarray(CudaNdarray * self,
if
(
verbose
)
if
(
verbose
)
fprintf
(
stderr
,
"Copying contiguous vector with cublasScopy
\n
"
);
fprintf
(
stderr
,
"Copying contiguous vector with cublasScopy
\n
"
);
cublasScopy
(
size
,
CudaNdarray_DEV_DATA
(
other
),
1
,
cublasStatus_t
err
;
err
=
cublasScopy
(
handle
,
size
,
CudaNdarray_DEV_DATA
(
other
),
1
,
CudaNdarray_DEV_DATA
(
self
),
1
);
CudaNdarray_DEV_DATA
(
self
),
1
);
CNDA_THREAD_SYNC
;
CNDA_THREAD_SYNC
;
Py_XDECREF
(
new_other
);
Py_XDECREF
(
new_other
);
if
(
CUBLAS_STATUS_SUCCESS
!=
cublasGetError
()
)
if
(
CUBLAS_STATUS_SUCCESS
!=
err
)
{
{
PyErr_SetString
(
PyExc_RuntimeError
,
"Error copying memory"
);
PyErr_SetString
(
PyExc_RuntimeError
,
"Error copying memory"
);
return
-
1
;
return
-
1
;
...
@@ -3920,22 +3930,6 @@ int CudaNdarray_gemm(float alpha, const CudaNdarray * A, const CudaNdarray * B,
...
@@ -3920,22 +3930,6 @@ int CudaNdarray_gemm(float alpha, const CudaNdarray * A, const CudaNdarray * B,
return
-
1
;
return
-
1
;
}
}
#if PRECHECK_ERROR
cublasStatus
prevErr
=
cublasGetError
();
if
(
CUBLAS_STATUS_SUCCESS
!=
prevErr
)
{
//I don't know why, but I need to remove the cuda error too.
//Otherwise, the clean up before raising the Python error cause error too!
//So we don't see this python error.
fprintf
(
stderr
,
"CudaNdarray_sgemm: Prev cublas error %s"
,
cublasGetErrorString
(
prevErr
));
PyErr_Format
(
PyExc_RuntimeError
,
"CudaNdarray_sgemm: Prev cublas error %s"
,
cublasGetErrorString
(
prevErr
));
return
-
1
;
}
#endif
// We must allow dimensions to be zeros.
// We must allow dimensions to be zeros.
if
((
CudaNdarray_HOST_DIMS
(
A
)[
1
]
!=
CudaNdarray_HOST_DIMS
(
B
)[
0
])
if
((
CudaNdarray_HOST_DIMS
(
A
)[
1
]
!=
CudaNdarray_HOST_DIMS
(
B
)[
0
])
||
(
CudaNdarray_HOST_DIMS
(
A
)[
0
]
!=
CudaNdarray_HOST_DIMS
(
C
)[
0
])
||
(
CudaNdarray_HOST_DIMS
(
A
)[
0
]
!=
CudaNdarray_HOST_DIMS
(
C
)[
0
])
...
@@ -4055,8 +4049,8 @@ int CudaNdarray_gemm(float alpha, const CudaNdarray * A, const CudaNdarray * B,
...
@@ -4055,8 +4049,8 @@ int CudaNdarray_gemm(float alpha, const CudaNdarray * A, const CudaNdarray * B,
float
*
a
=
CudaNdarray_DEV_DATA
(
A
);
float
*
a
=
CudaNdarray_DEV_DATA
(
A
);
float
*
b
=
CudaNdarray_DEV_DATA
(
B
);
float
*
b
=
CudaNdarray_DEV_DATA
(
B
);
float
*
c
=
CudaNdarray_DEV_DATA
(
C
);
float
*
c
=
CudaNdarray_DEV_DATA
(
C
);
c
har
N
=
'N'
;
c
ublasOperation_t
N
=
CUBLAS_OP_N
;
c
har
T
=
'T'
;
c
ublasOperation_t
T
=
CUBLAS_OP_T
;
//std::cerr << (unit/256) MOD 16 << (unit / 16) MOD 16 << unit MOD 16<< '\\n';
//std::cerr << (unit/256) MOD 16 << (unit / 16) MOD 16 << unit MOD 16<< '\\n';
// There should be no negative stride at that point
// There should be no negative stride at that point
#define CHK_STRIDE_SGEMM(T0, T1, D0, D1, D2, a, x, sx, y, sy, b, z, sz) \
#define CHK_STRIDE_SGEMM(T0, T1, D0, D1, D2, a, x, sx, y, sy, b, z, sz) \
...
@@ -4064,7 +4058,7 @@ int CudaNdarray_gemm(float alpha, const CudaNdarray * A, const CudaNdarray * B,
...
@@ -4064,7 +4058,7 @@ int CudaNdarray_gemm(float alpha, const CudaNdarray * A, const CudaNdarray * B,
if (sy == 0){sy = 1;}\
if (sy == 0){sy = 1;}\
if (sz == 0){sz = 1;}\
if (sz == 0){sz = 1;}\
if ((sx > 0) && (sy > 0) && (sz > 0)) { \
if ((sx > 0) && (sy > 0) && (sz > 0)) { \
cublasSgemm(
T0, T1, D0, D1, D2, a, x, sx, y, sy, b, z, sz); \
err = cublasSgemm(handle,
T0, T1, D0, D1, D2, a, x, sx, y, sy, b, z, sz); \
} else { \
} else { \
PyErr_SetString(PyExc_AssertionError, "negative stride to sGemm");\
PyErr_SetString(PyExc_AssertionError, "negative stride to sGemm");\
Py_XDECREF(A_new);\
Py_XDECREF(A_new);\
...
@@ -4072,6 +4066,7 @@ int CudaNdarray_gemm(float alpha, const CudaNdarray * A, const CudaNdarray * B,
...
@@ -4072,6 +4066,7 @@ int CudaNdarray_gemm(float alpha, const CudaNdarray * A, const CudaNdarray * B,
return -1; \
return -1; \
}
}
cublasStatus_t
err
;
switch
(
unit
)
switch
(
unit
)
{
{
case
0x000
:
CHK_STRIDE_SGEMM
(
N
,
N
,
CudaNdarray_HOST_DIMS
(
C
)[
1
],
CudaNdarray_HOST_DIMS
(
C
)[
0
],
CudaNdarray_HOST_DIMS
(
A
)[
1
],
alpha
,
b
,
sb_0
,
a
,
sa_0
,
beta
,
c
,
sc_0
);
break
;
case
0x000
:
CHK_STRIDE_SGEMM
(
N
,
N
,
CudaNdarray_HOST_DIMS
(
C
)[
1
],
CudaNdarray_HOST_DIMS
(
C
)[
0
],
CudaNdarray_HOST_DIMS
(
A
)[
1
],
alpha
,
b
,
sb_0
,
a
,
sa_0
,
beta
,
c
,
sc_0
);
break
;
...
@@ -4089,7 +4084,6 @@ int CudaNdarray_gemm(float alpha, const CudaNdarray * A, const CudaNdarray * B,
...
@@ -4089,7 +4084,6 @@ int CudaNdarray_gemm(float alpha, const CudaNdarray * A, const CudaNdarray * B,
Py_XDECREF
(
A_new
);
Py_XDECREF
(
A_new
);
Py_XDECREF
(
B_new
);
Py_XDECREF
(
B_new
);
cublasStatus
err
=
cublasGetError
();
if
(
CUBLAS_STATUS_SUCCESS
!=
err
)
if
(
CUBLAS_STATUS_SUCCESS
!=
err
)
{
{
PyErr_Format
(
PyExc_RuntimeError
,
PyErr_Format
(
PyExc_RuntimeError
,
...
@@ -4187,12 +4181,13 @@ int CudaNdarray_sgemv(float alpha, const CudaNdarray * A, const CudaNdarray * B,
...
@@ -4187,12 +4181,13 @@ int CudaNdarray_sgemv(float alpha, const CudaNdarray * A, const CudaNdarray * B,
if
(
sa_1
==
0
)
if
(
sa_1
==
0
)
sa_1
=
1
;
sa_1
=
1
;
cublasStatus_t
err
;
if
(
CudaNdarray_SIZE
(
C
))
{
if
(
CudaNdarray_SIZE
(
C
))
{
if
((
CudaNdarray_HOST_DIMS
(
A
)[
0
]
<=
1
)
if
((
CudaNdarray_HOST_DIMS
(
A
)[
0
]
<=
1
)
||
((
CudaNdarray_HOST_STRIDES
(
A
)[
0
]
==
1
)
||
((
CudaNdarray_HOST_STRIDES
(
A
)[
0
]
==
1
)
&&
(
CudaNdarray_HOST_STRIDES
(
A
)[
1
]
>
0
)))
&&
(
CudaNdarray_HOST_STRIDES
(
A
)[
1
]
>
0
)))
{
{
cublasSgemv
(
'N'
,
err
=
cublasSgemv
(
handle
,
CUBLAS_OP_N
,
CudaNdarray_HOST_DIMS
(
A
)[
0
],
CudaNdarray_HOST_DIMS
(
A
)[
1
],
CudaNdarray_HOST_DIMS
(
A
)[
0
],
CudaNdarray_HOST_DIMS
(
A
)[
1
],
alpha
,
alpha
,
CudaNdarray_DEV_DATA
(
A
),
sa_1
,
CudaNdarray_DEV_DATA
(
A
),
sa_1
,
...
@@ -4204,7 +4199,7 @@ int CudaNdarray_sgemv(float alpha, const CudaNdarray * A, const CudaNdarray * B,
...
@@ -4204,7 +4199,7 @@ int CudaNdarray_sgemv(float alpha, const CudaNdarray * A, const CudaNdarray * B,
||
((
CudaNdarray_HOST_STRIDES
(
A
)[
1
]
==
1
)
||
((
CudaNdarray_HOST_STRIDES
(
A
)[
1
]
==
1
)
&&
(
CudaNdarray_HOST_STRIDES
(
A
)[
0
]
>
0
)))
&&
(
CudaNdarray_HOST_STRIDES
(
A
)[
0
]
>
0
)))
{
{
cublasSgemv
(
'T'
,
err
=
cublasSgemv
(
handle
,
CUBLAS_OP_T
,
CudaNdarray_HOST_DIMS
(
A
)[
1
],
CudaNdarray_HOST_DIMS
(
A
)[
0
],
CudaNdarray_HOST_DIMS
(
A
)[
1
],
CudaNdarray_HOST_DIMS
(
A
)[
0
],
alpha
,
alpha
,
CudaNdarray_DEV_DATA
(
A
),
sa_0
,
CudaNdarray_DEV_DATA
(
A
),
sa_0
,
...
@@ -4235,7 +4230,6 @@ int CudaNdarray_sgemv(float alpha, const CudaNdarray * A, const CudaNdarray * B,
...
@@ -4235,7 +4230,6 @@ int CudaNdarray_sgemv(float alpha, const CudaNdarray * A, const CudaNdarray * B,
Py_XDECREF
(
A_new
);
Py_XDECREF
(
A_new
);
Py_XDECREF
(
B_new
);
Py_XDECREF
(
B_new
);
cublasStatus
err
=
cublasGetError
();
if
(
CUBLAS_STATUS_SUCCESS
!=
err
)
if
(
CUBLAS_STATUS_SUCCESS
!=
err
)
{
{
PyErr_Format
(
PyExc_RuntimeError
,
PyErr_Format
(
PyExc_RuntimeError
,
...
@@ -4303,13 +4297,14 @@ int CudaNdarray_sger(float alpha, const CudaNdarray * x, const CudaNdarray * y,
...
@@ -4303,13 +4297,14 @@ int CudaNdarray_sger(float alpha, const CudaNdarray * x, const CudaNdarray * y,
int
sa_1
=
(
CudaNdarray_HOST_DIMS
(
A
)[
1
]
>
1
)
?
CudaNdarray_HOST_STRIDES
(
A
)[
1
]
int
sa_1
=
(
CudaNdarray_HOST_DIMS
(
A
)[
1
]
>
1
)
?
CudaNdarray_HOST_STRIDES
(
A
)[
1
]
:
CudaNdarray_HOST_DIMS
(
A
)[
0
];
:
CudaNdarray_HOST_DIMS
(
A
)[
0
];
cublasStatus
err
;
if
(
CudaNdarray_SIZE
(
A
)){
if
(
CudaNdarray_SIZE
(
A
)){
// If A is in col-major
// If A is in col-major
if
((
CudaNdarray_HOST_DIMS
(
A
)[
0
]
<=
1
)
if
((
CudaNdarray_HOST_DIMS
(
A
)[
0
]
<=
1
)
||
((
CudaNdarray_HOST_STRIDES
(
A
)[
0
]
==
1
)
||
((
CudaNdarray_HOST_STRIDES
(
A
)[
0
]
==
1
)
&&
(
CudaNdarray_HOST_STRIDES
(
A
)[
1
]
>
0
)))
&&
(
CudaNdarray_HOST_STRIDES
(
A
)[
1
]
>
0
)))
{
{
cublasSger
(
CudaNdarray_HOST_DIMS
(
x
)[
0
],
CudaNdarray_HOST_DIMS
(
y
)[
0
],
alpha
,
err
=
cublasSger
(
handle
,
CudaNdarray_HOST_DIMS
(
x
)[
0
],
CudaNdarray_HOST_DIMS
(
y
)[
0
],
alpha
,
CudaNdarray_DEV_DATA
(
x
),
x_strides
,
CudaNdarray_DEV_DATA
(
x
),
x_strides
,
CudaNdarray_DEV_DATA
(
y
),
y_strides
,
CudaNdarray_DEV_DATA
(
y
),
y_strides
,
CudaNdarray_DEV_DATA
(
A
),
sa_1
);
CudaNdarray_DEV_DATA
(
A
),
sa_1
);
...
@@ -4319,7 +4314,7 @@ int CudaNdarray_sger(float alpha, const CudaNdarray * x, const CudaNdarray * y,
...
@@ -4319,7 +4314,7 @@ int CudaNdarray_sger(float alpha, const CudaNdarray * x, const CudaNdarray * y,
||
((
CudaNdarray_HOST_STRIDES
(
A
)[
1
]
==
1
)
||
((
CudaNdarray_HOST_STRIDES
(
A
)[
1
]
==
1
)
&&
(
CudaNdarray_HOST_STRIDES
(
A
)[
0
]
>
0
)))
&&
(
CudaNdarray_HOST_STRIDES
(
A
)[
0
]
>
0
)))
{
{
cublasSger
(
CudaNdarray_HOST_DIMS
(
y
)[
0
],
CudaNdarray_HOST_DIMS
(
x
)[
0
],
alpha
,
err
=
cublasSger
(
handle
,
CudaNdarray_HOST_DIMS
(
y
)[
0
],
CudaNdarray_HOST_DIMS
(
x
)[
0
],
alpha
,
CudaNdarray_DEV_DATA
(
y
),
y_strides
,
CudaNdarray_DEV_DATA
(
y
),
y_strides
,
CudaNdarray_DEV_DATA
(
x
),
x_strides
,
CudaNdarray_DEV_DATA
(
x
),
x_strides
,
CudaNdarray_DEV_DATA
(
A
),
sa_0
);
CudaNdarray_DEV_DATA
(
A
),
sa_0
);
...
@@ -4338,7 +4333,6 @@ int CudaNdarray_sger(float alpha, const CudaNdarray * x, const CudaNdarray * y,
...
@@ -4338,7 +4333,6 @@ int CudaNdarray_sger(float alpha, const CudaNdarray * x, const CudaNdarray * y,
Py_XDECREF
(
x_new
);
Py_XDECREF
(
x_new
);
Py_XDECREF
(
y_new
);
Py_XDECREF
(
y_new
);
cublasStatus
err
=
cublasGetError
();
if
(
CUBLAS_STATUS_SUCCESS
!=
err
)
if
(
CUBLAS_STATUS_SUCCESS
!=
err
)
{
{
PyErr_Format
(
PyExc_RuntimeError
,
PyErr_Format
(
PyExc_RuntimeError
,
...
@@ -4973,14 +4967,12 @@ cnda_copy_structure_to_device(const CudaNdarray * self)
...
@@ -4973,14 +4967,12 @@ cnda_copy_structure_to_device(const CudaNdarray * self)
}
}
}
}
}
}
cublasSetVector
(
cnda_structure_size
(
self
->
nd
),
if
(
cublasSetVector
(
cnda_structure_size
(
self
->
nd
),
sizeof
(
int
),
sizeof
(
int
),
self
->
host_structure
,
self
->
host_structure
,
1
,
1
,
self
->
dev_structure
,
self
->
dev_structure
,
1
);
1
)
!=
CUBLAS_STATUS_SUCCESS
)
//CNDA_THREAD_SYNC; // unneeded because cublasSetVector is blocking anyway
if
(
CUBLAS_STATUS_SUCCESS
!=
cublasGetError
())
{
{
PyErr_SetString
(
PyExc_RuntimeError
,
"error copying structure to device memory"
);
PyErr_SetString
(
PyExc_RuntimeError
,
"error copying structure to device memory"
);
return
-
1
;
return
-
1
;
...
...
theano/sandbox/cuda/cuda_ndarray.cuh
浏览文件 @
54efc47a
...
@@ -40,7 +40,7 @@
...
@@ -40,7 +40,7 @@
#endif
#endif
#include <cublas.h>
#include <cublas
_v2
.h>
#ifdef _WIN32
#ifdef _WIN32
#ifdef _CUDA_NDARRAY_C
#ifdef _CUDA_NDARRAY_C
...
@@ -81,6 +81,9 @@ typedef float real;
...
@@ -81,6 +81,9 @@ typedef float real;
#define VERBOSE_DEVICE_MALLOC 1
#define VERBOSE_DEVICE_MALLOC 1
#define NO_VERBOSE_DEVICE_MALLOC 0
#define NO_VERBOSE_DEVICE_MALLOC 0
/* Use this handle to make cublas calls */
extern
cublasHandle_t
handle
;
/**
/**
* Allocation and freeing of device memory should go through these functions so that the lib can track memory usage.
* Allocation and freeing of device memory should go through these functions so that the lib can track memory usage.
*
*
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论