Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
f2693c57
提交
f2693c57
authored
2月 09, 2012
作者:
Pascal Lamblin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Make dev_structure_fresh mutable, add a few const
This enables calling CudaNdarray_Copy on a const CudaNdarray*, which was previously impossible.
上级
22f11d45
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
18 行增加
和
18 行删除
+18
-18
cuda_ndarray.cu
theano/sandbox/cuda/cuda_ndarray.cu
+10
-10
cuda_ndarray.cuh
theano/sandbox/cuda/cuda_ndarray.cuh
+8
-8
没有找到文件。
theano/sandbox/cuda/cuda_ndarray.cu
浏览文件 @
f2693c57
...
@@ -453,7 +453,7 @@ PyObject* CudaNdarray_Zeros(PyObject* dummy, PyObject* shape)
...
@@ -453,7 +453,7 @@ PyObject* CudaNdarray_Zeros(PyObject* dummy, PyObject* shape)
PyObject
*
CudaNdarray_Copy
(
CudaNdarray
*
self
)
PyObject
*
CudaNdarray_Copy
(
const
CudaNdarray
*
self
)
{
{
PyObject
*
rval
=
CudaNdarray_New
();
PyObject
*
rval
=
CudaNdarray_New
();
if
((
!
rval
)
||
(
-
1
==
self
->
nd
))
if
((
!
rval
)
||
(
-
1
==
self
->
nd
))
...
@@ -2777,7 +2777,7 @@ static __global__ void k_copy_1d(const int N, const float * x, const int sx, flo
...
@@ -2777,7 +2777,7 @@ static __global__ void k_copy_1d(const int N, const float * x, const int sx, flo
}
}
//copy from other into self
//copy from other into self
int
CudaNdarray_CopyFromCudaNdarray
(
CudaNdarray
*
self
,
CudaNdarray
*
other
,
bool
unbroadcast
)
int
CudaNdarray_CopyFromCudaNdarray
(
CudaNdarray
*
self
,
const
CudaNdarray
*
other
,
bool
unbroadcast
)
{
{
int
verbose
=
0
;
int
verbose
=
0
;
if
(
verbose
>
1
)
fprintf
(
stderr
,
"CudaNdarray_CopyFromCudaNdarray
\n
"
);
if
(
verbose
>
1
)
fprintf
(
stderr
,
"CudaNdarray_CopyFromCudaNdarray
\n
"
);
...
@@ -2856,7 +2856,7 @@ int CudaNdarray_CopyFromCudaNdarray(CudaNdarray * self, CudaNdarray * other, boo
...
@@ -2856,7 +2856,7 @@ int CudaNdarray_CopyFromCudaNdarray(CudaNdarray * self, CudaNdarray * other, boo
// call worker routine
// call worker routine
unsigned
int
n_blocks
=
std
::
min
(
size
,
(
unsigned
int
)
NUM_VECTOR_OP_BLOCKS
);
unsigned
int
n_blocks
=
std
::
min
(
size
,
(
unsigned
int
)
NUM_VECTOR_OP_BLOCKS
);
unsigned
int
threads_per_block
=
std
::
min
(
ceil_intdiv
(
size
,
n_blocks
),
(
unsigned
int
)
NUM_VECTOR_OP_THREADS_PER_BLOCK
);
unsigned
int
threads_per_block
=
std
::
min
(
ceil_intdiv
(
size
,
n_blocks
),
(
unsigned
int
)
NUM_VECTOR_OP_THREADS_PER_BLOCK
);
CudaNdarray
*
cuda_dims
=
other
;
const
CudaNdarray
*
cuda_dims
=
other
;
if
(
unbroadcast
)
if
(
unbroadcast
)
cuda_dims
=
self
;
cuda_dims
=
self
;
//copy from other into self
//copy from other into self
...
@@ -3099,7 +3099,7 @@ int CudaNdarray_sgemv(float alpha, const CudaNdarray * A, const CudaNdarray * B,
...
@@ -3099,7 +3099,7 @@ int CudaNdarray_sgemv(float alpha, const CudaNdarray * A, const CudaNdarray * B,
return
0
;
return
0
;
}
}
int
CudaNdarray_sger
(
float
alpha
,
CudaNdarray
*
x
,
CudaNdarray
*
y
,
CudaNdarray
*
A
)
{
int
CudaNdarray_sger
(
float
alpha
,
const
CudaNdarray
*
x
,
const
CudaNdarray
*
y
,
CudaNdarray
*
A
)
{
if
(
x
->
nd
!=
1
)
{
PyErr_SetString
(
PyExc_ValueError
,
"non-vector arg x to sger"
);
return
-
1
;
}
if
(
x
->
nd
!=
1
)
{
PyErr_SetString
(
PyExc_ValueError
,
"non-vector arg x to sger"
);
return
-
1
;
}
if
(
y
->
nd
!=
1
)
{
PyErr_SetString
(
PyExc_ValueError
,
"non-vector arg y to sger"
);
return
-
1
;
}
if
(
y
->
nd
!=
1
)
{
PyErr_SetString
(
PyExc_ValueError
,
"non-vector arg y to sger"
);
return
-
1
;
}
if
(
A
->
nd
!=
2
)
{
PyErr_SetString
(
PyExc_ValueError
,
"non-matrix arg A to sger"
);
return
-
1
;
}
if
(
A
->
nd
!=
2
)
{
PyErr_SetString
(
PyExc_ValueError
,
"non-matrix arg A to sger"
);
return
-
1
;
}
...
@@ -3122,7 +3122,7 @@ int CudaNdarray_sger(float alpha, CudaNdarray * x, CudaNdarray * y, CudaNdarray
...
@@ -3122,7 +3122,7 @@ int CudaNdarray_sger(float alpha, CudaNdarray * x, CudaNdarray * y, CudaNdarray
}
}
// Since Sger expects A in col-major, we invert x and y to fake this.
// Since Sger expects A in col-major, we invert x and y to fake this.
int
x_strides
=
CudaNdarray_HOST_STRIDES
(
x
)[
0
];
int
x_strides
=
CudaNdarray_HOST_STRIDES
(
x
)[
0
];
CudaNdarray
*
x_
=
x
;
const
CudaNdarray
*
x_
=
x
;
if
(
x_strides
==
0
){
if
(
x_strides
==
0
){
if
(
CudaNdarray_HOST_DIMS
(
x
)[
0
]
!=
1
){
if
(
CudaNdarray_HOST_DIMS
(
x
)[
0
]
!=
1
){
PyErr_Format
(
PyExc_RuntimeError
,
PyErr_Format
(
PyExc_RuntimeError
,
...
@@ -3138,7 +3138,7 @@ int CudaNdarray_sger(float alpha, CudaNdarray * x, CudaNdarray * y, CudaNdarray
...
@@ -3138,7 +3138,7 @@ int CudaNdarray_sger(float alpha, CudaNdarray * x, CudaNdarray * y, CudaNdarray
}
}
int
y_strides
=
CudaNdarray_HOST_STRIDES
(
y
)[
0
];
int
y_strides
=
CudaNdarray_HOST_STRIDES
(
y
)[
0
];
CudaNdarray
*
y_
=
y
;
const
CudaNdarray
*
y_
=
y
;
if
(
y_strides
==
0
){
if
(
y_strides
==
0
){
if
(
CudaNdarray_HOST_DIMS
(
y
)[
0
]
!=
1
){
if
(
CudaNdarray_HOST_DIMS
(
y
)[
0
]
!=
1
){
PyErr_Format
(
PyExc_RuntimeError
,
PyErr_Format
(
PyExc_RuntimeError
,
...
@@ -3816,7 +3816,7 @@ CudaNdarray_set_stride(CudaNdarray * self, int idx, int s)
...
@@ -3816,7 +3816,7 @@ CudaNdarray_set_stride(CudaNdarray * self, int idx, int s)
int
int
cnda_copy_structure_to_device
(
CudaNdarray
*
self
)
cnda_copy_structure_to_device
(
const
CudaNdarray
*
self
)
{
{
cublasSetVector
(
cnda_structure_size
(
self
->
nd
),
sizeof
(
int
),
self
->
host_structure
,
1
,
self
->
dev_structure
,
1
);
cublasSetVector
(
cnda_structure_size
(
self
->
nd
),
sizeof
(
int
),
self
->
host_structure
,
1
,
self
->
dev_structure
,
1
);
CNDA_THREAD_SYNC
;
CNDA_THREAD_SYNC
;
...
@@ -3830,7 +3830,7 @@ cnda_copy_structure_to_device(CudaNdarray * self)
...
@@ -3830,7 +3830,7 @@ cnda_copy_structure_to_device(CudaNdarray * self)
}
}
const
int
*
const
int
*
CudaNdarray_DEV_DIMS
(
CudaNdarray
*
self
)
CudaNdarray_DEV_DIMS
(
const
CudaNdarray
*
self
)
{
{
if
(
!
self
->
dev_structure_fresh
)
if
(
!
self
->
dev_structure_fresh
)
{
{
...
@@ -3840,7 +3840,7 @@ CudaNdarray_DEV_DIMS(CudaNdarray * self)
...
@@ -3840,7 +3840,7 @@ CudaNdarray_DEV_DIMS(CudaNdarray * self)
return
self
->
dev_structure
;
return
self
->
dev_structure
;
}
}
const
int
*
const
int
*
CudaNdarray_DEV_STRIDES
(
CudaNdarray
*
self
)
CudaNdarray_DEV_STRIDES
(
const
CudaNdarray
*
self
)
{
{
if
(
!
self
->
dev_structure_fresh
)
if
(
!
self
->
dev_structure_fresh
)
{
{
...
@@ -3850,7 +3850,7 @@ CudaNdarray_DEV_STRIDES(CudaNdarray * self)
...
@@ -3850,7 +3850,7 @@ CudaNdarray_DEV_STRIDES(CudaNdarray * self)
return
self
->
dev_structure
+
self
->
nd
;
return
self
->
dev_structure
+
self
->
nd
;
}
}
const
int
*
const
int
*
CudaNdarray_DEV_LOG2DIMS
(
CudaNdarray
*
self
)
CudaNdarray_DEV_LOG2DIMS
(
const
CudaNdarray
*
self
)
{
{
if
(
!
self
->
dev_structure_fresh
)
if
(
!
self
->
dev_structure_fresh
)
{
{
...
...
theano/sandbox/cuda/cuda_ndarray.cuh
浏览文件 @
f2693c57
...
@@ -81,7 +81,7 @@ struct CudaNdarray
...
@@ -81,7 +81,7 @@ struct CudaNdarray
//device pointers (allocated by cudaMalloc)
//device pointers (allocated by cudaMalloc)
int
dev_structure_fresh
;
mutable
int
dev_structure_fresh
;
//dev_structure should be accessed via macros, otherwise may not be synchronized
//dev_structure should be accessed via macros, otherwise may not be synchronized
int
*
dev_structure
;
//dim0, dim1, ..., stride0, stride1, ...
int
*
dev_structure
;
//dim0, dim1, ..., stride0, stride1, ...
real
*
devdata
;
//pointer to data element [0,..,0].
real
*
devdata
;
//pointer to data element [0,..,0].
...
@@ -154,11 +154,11 @@ CudaNdarray_set_stride(CudaNdarray * self, int idx, int s);
...
@@ -154,11 +154,11 @@ CudaNdarray_set_stride(CudaNdarray * self, int idx, int s);
*
*
* This means: recalculate the log2dims and transfer structure to the card
* This means: recalculate the log2dims and transfer structure to the card
*/
*/
DllExport
int
cnda_copy_structure_to_device
(
CudaNdarray
*
self
);
DllExport
int
cnda_copy_structure_to_device
(
const
CudaNdarray
*
self
);
DllExport
const
int
*
CudaNdarray_DEV_DIMS
(
CudaNdarray
*
self
);
DllExport
const
int
*
CudaNdarray_DEV_DIMS
(
const
CudaNdarray
*
self
);
DllExport
const
int
*
CudaNdarray_DEV_STRIDES
(
CudaNdarray
*
self
);
DllExport
const
int
*
CudaNdarray_DEV_STRIDES
(
const
CudaNdarray
*
self
);
DllExport
const
int
*
CudaNdarray_DEV_LOG2DIMS
(
CudaNdarray
*
self
);
DllExport
const
int
*
CudaNdarray_DEV_LOG2DIMS
(
const
CudaNdarray
*
self
);
DllExport
float
*
CudaNdarray_DEV_DATA
(
const
CudaNdarray
*
self
);
DllExport
float
*
CudaNdarray_DEV_DATA
(
const
CudaNdarray
*
self
);
/**
/**
...
@@ -283,7 +283,7 @@ DllExport PyObject * CudaNdarray_DeepCopy(CudaNdarray * self, PyObject * memo);
...
@@ -283,7 +283,7 @@ DllExport PyObject * CudaNdarray_DeepCopy(CudaNdarray * self, PyObject * memo);
/**
/**
* Return an independent copy of self
* Return an independent copy of self
*/
*/
DllExport
PyObject
*
CudaNdarray_Copy
(
CudaNdarray
*
self
);
DllExport
PyObject
*
CudaNdarray_Copy
(
const
CudaNdarray
*
self
);
/**
/**
* Return a new object obtained by summing over the dimensions for which there is a 1 in the mask.
* Return a new object obtained by summing over the dimensions for which there is a 1 in the mask.
...
@@ -302,7 +302,7 @@ DllExport int CudaNdarray_CopyFromArray(CudaNdarray * self, PyArrayObject*obj);
...
@@ -302,7 +302,7 @@ DllExport int CudaNdarray_CopyFromArray(CudaNdarray * self, PyArrayObject*obj);
*
*
* self is reallocated to have the correct dimensions if necessary.
* self is reallocated to have the correct dimensions if necessary.
*/
*/
DllExport
int
CudaNdarray_CopyFromCudaNdarray
(
CudaNdarray
*
self
,
CudaNdarray
*
other
,
bool
unbroadcast
=
false
);
DllExport
int
CudaNdarray_CopyFromCudaNdarray
(
CudaNdarray
*
self
,
const
CudaNdarray
*
other
,
bool
unbroadcast
=
false
);
/**
/**
* Transfer the contents of CudaNdarray `self` to a new numpy ndarray.
* Transfer the contents of CudaNdarray `self` to a new numpy ndarray.
...
@@ -321,7 +321,7 @@ DllExport PyObject * CudaNdarray_IS_C_Contiguous(CudaNdarray * self);
...
@@ -321,7 +321,7 @@ DllExport PyObject * CudaNdarray_IS_C_Contiguous(CudaNdarray * self);
DllExport
int
CudaNdarray_gemm
(
float
alpha
,
const
CudaNdarray
*
A
,
const
CudaNdarray
*
B
,
float
beta
,
CudaNdarray
*
C
);
DllExport
int
CudaNdarray_gemm
(
float
alpha
,
const
CudaNdarray
*
A
,
const
CudaNdarray
*
B
,
float
beta
,
CudaNdarray
*
C
);
DllExport
int
CudaNdarray_sgemv
(
float
alpha
,
const
CudaNdarray
*
A
,
const
CudaNdarray
*
B
,
float
beta
,
CudaNdarray
*
C
);
DllExport
int
CudaNdarray_sgemv
(
float
alpha
,
const
CudaNdarray
*
A
,
const
CudaNdarray
*
B
,
float
beta
,
CudaNdarray
*
C
);
DllExport
int
CudaNdarray_sger
(
float
alpha
,
CudaNdarray
*
x
,
CudaNdarray
*
y
,
CudaNdarray
*
A
);
DllExport
int
CudaNdarray_sger
(
float
alpha
,
const
CudaNdarray
*
x
,
const
CudaNdarray
*
y
,
CudaNdarray
*
A
);
DllExport
int
CudaNdarray_reduce_sum
(
CudaNdarray
*
self
,
CudaNdarray
*
A
);
DllExport
int
CudaNdarray_reduce_sum
(
CudaNdarray
*
self
,
CudaNdarray
*
A
);
DllExport
int
CudaNdarray_reduce_prod
(
CudaNdarray
*
self
,
CudaNdarray
*
A
);
DllExport
int
CudaNdarray_reduce_prod
(
CudaNdarray
*
self
,
CudaNdarray
*
A
);
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论