Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
460e9561
提交
460e9561
authored
9月 20, 2012
作者:
Frederic
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
move fct and force them to be inlined.
上级
d0e0274f
显示空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
106 行增加
和
115 行删除
+106
-115
cuda_ndarray.cu
theano/sandbox/cuda/cuda_ndarray.cu
+0
-107
cuda_ndarray.cuh
theano/sandbox/cuda/cuda_ndarray.cuh
+106
-8
没有找到文件。
theano/sandbox/cuda/cuda_ndarray.cu
浏览文件 @
460e9561
...
@@ -3133,24 +3133,6 @@ CudaNdarray_CopyFromArray(CudaNdarray * self, PyArrayObject*obj)
...
@@ -3133,24 +3133,6 @@ CudaNdarray_CopyFromArray(CudaNdarray * self, PyArrayObject*obj)
return
0
;
return
0
;
}
}
bool
CudaNdarray_is_c_contiguous
(
const
CudaNdarray
*
self
)
{
bool
c_contiguous
=
true
;
int
size
=
1
;
for
(
int
i
=
self
->
nd
-
1
;
(
i
>=
0
)
&&
c_contiguous
;
--
i
)
{
if
(
CudaNdarray_HOST_DIMS
(
self
)[
i
]
==
1
)
continue
;
if
(
CudaNdarray_HOST_STRIDES
(
self
)[
i
]
!=
size
)
{
c_contiguous
=
false
;
}
size
=
size
*
CudaNdarray_HOST_DIMS
(
self
)[
i
];
}
return
c_contiguous
;
}
PyObject
*
PyObject
*
CudaNdarray_new_nd
(
int
nd
)
CudaNdarray_new_nd
(
int
nd
)
{
{
...
@@ -4346,12 +4328,6 @@ CudaNdarray_HOST_LOG2DIMS(const CudaNdarray * self)
...
@@ -4346,12 +4328,6 @@ CudaNdarray_HOST_LOG2DIMS(const CudaNdarray * self)
return
self
->
host_structure
+
2
*
self
->
nd
;
return
self
->
host_structure
+
2
*
self
->
nd
;
}
}
void
cnda_mark_dev_structure_dirty
(
CudaNdarray
*
self
)
{
self
->
dev_structure_fresh
=
0
;
}
int
int
CudaNdarray_EqualAndIgnore
(
CudaNdarray
*
cnda1
,
CudaNdarray
*
cnda2
,
int
ignoreSync
,
int
ignoreBase
)
CudaNdarray_EqualAndIgnore
(
CudaNdarray
*
cnda1
,
CudaNdarray
*
cnda2
,
int
ignoreSync
,
int
ignoreBase
)
{
{
...
@@ -4406,39 +4382,6 @@ CudaNdarray_Equal(CudaNdarray *cnda1, CudaNdarray *cnda2)
...
@@ -4406,39 +4382,6 @@ CudaNdarray_Equal(CudaNdarray *cnda1, CudaNdarray *cnda2)
return
CudaNdarray_EqualAndIgnore
(
cnda1
,
cnda2
,
0
,
0
);
return
CudaNdarray_EqualAndIgnore
(
cnda1
,
cnda2
,
0
,
0
);
}
}
void
CudaNdarray_set_dim
(
CudaNdarray
*
self
,
int
idx
,
int
d
)
{
if
((
idx
>=
self
->
nd
)
||
(
idx
<
0
)
||
(
d
<
0
))
{
fprintf
(
stderr
,
"WARNING: probably bad CudaNdarray_set_dim arguments: %i %i
\n
"
,
idx
,
d
);
}
if
(
d
!=
self
->
host_structure
[
idx
])
{
self
->
host_structure
[
idx
]
=
d
;
int
log2d
=
(
int
)
log2
((
double
)
d
);
self
->
host_structure
[
idx
+
2
*
self
->
nd
]
=
(
d
==
(
1
<<
log2d
))
?
log2d
:
-
1
;
cnda_mark_dev_structure_dirty
(
self
);
}
}
void
CudaNdarray_set_stride
(
CudaNdarray
*
self
,
int
idx
,
int
s
)
{
if
((
idx
>=
self
->
nd
)
||
(
idx
<
0
))
{
fprintf
(
stderr
,
"WARNING: probably bad CudaNdarray_set_stride arguments: %i %i
\n
"
,
idx
,
s
);
}
if
(
s
!=
CudaNdarray_HOST_STRIDES
(
self
)[
idx
])
{
self
->
host_structure
[
idx
+
self
->
nd
]
=
s
;
cnda_mark_dev_structure_dirty
(
self
);
}
}
int
int
cnda_copy_structure_to_device
(
const
CudaNdarray
*
self
)
cnda_copy_structure_to_device
(
const
CudaNdarray
*
self
)
{
{
...
@@ -4510,56 +4453,6 @@ CudaNdarray_SIZE_Object(const CudaNdarray *self, void *closure)
...
@@ -4510,56 +4453,6 @@ CudaNdarray_SIZE_Object(const CudaNdarray *self, void *closure)
return
PyInt_FromLong
(
CudaNdarray_SIZE
(
self
));
return
PyInt_FromLong
(
CudaNdarray_SIZE
(
self
));
}
}
int
CudaNdarray_set_nd
(
CudaNdarray
*
self
,
const
int
nd
)
{
if
(
nd
!=
self
->
nd
)
{
if
(
self
->
dev_structure
)
{
if
(
device_free
(
self
->
dev_structure
))
{
return
-
1
;
}
self
->
dev_structure
=
NULL
;
}
if
(
self
->
host_structure
)
{
free
(
self
->
host_structure
);
self
->
host_structure
=
NULL
;
self
->
nd
=
-
1
;
}
if
(
nd
==
-
1
)
return
0
;
self
->
host_structure
=
(
int
*
)
malloc
(
cnda_structure_size
(
nd
)
*
sizeof
(
int
));
if
(
NULL
==
self
->
host_structure
)
{
PyErr_SetString
(
PyExc_MemoryError
,
"Failed to allocate dim or str"
);
return
-
1
;
}
//initialize all dimensions and strides to 0
for
(
int
i
=
0
;
i
<
cnda_structure_size
(
nd
);
++
i
)
{
self
->
host_structure
[
i
]
=
0
;
}
int
struct_size
=
cnda_structure_size
(
nd
);
if
(
struct_size
)
{
self
->
dev_structure
=
(
int
*
)
device_malloc
(
struct_size
*
sizeof
(
int
));
if
(
NULL
==
self
->
dev_structure
)
{
free
(
self
->
host_structure
);
self
->
host_structure
=
NULL
;
self
->
dev_structure
=
NULL
;
return
-
1
;
}
}
self
->
nd
=
nd
;
self
->
dev_structure_fresh
=
0
;
}
return
0
;
}
int
CudaNdarray_set_device_data
(
CudaNdarray
*
self
,
float
*
data
,
const
CudaNdarray
*
base
)
int
CudaNdarray_set_device_data
(
CudaNdarray
*
self
,
float
*
data
,
const
CudaNdarray
*
base
)
{
{
return
CudaNdarray_set_device_data
(
self
,
data
,
(
PyObject
*
)
base
);
return
CudaNdarray_set_device_data
(
self
,
data
,
(
PyObject
*
)
base
);
...
...
theano/sandbox/cuda/cuda_ndarray.cuh
浏览文件 @
460e9561
...
@@ -126,8 +126,12 @@ CudaNdarray_HOST_STRIDES(const CudaNdarray * self);
...
@@ -126,8 +126,12 @@ CudaNdarray_HOST_STRIDES(const CudaNdarray * self);
DllExport
const
int
*
DllExport
const
int
*
CudaNdarray_HOST_LOG2DIMS
(
const
CudaNdarray
*
self
);
CudaNdarray_HOST_LOG2DIMS
(
const
CudaNdarray
*
self
);
DllExport
void
DllExport
inline
void
__attribute__
((
always_inline
))
cnda_mark_dev_structure_dirty
(
CudaNdarray
*
self
);
cnda_mark_dev_structure_dirty
(
CudaNdarray
*
self
)
{
self
->
dev_structure_fresh
=
0
;
}
DllExport
int
DllExport
int
CudaNdarray_EqualAndIgnore
(
CudaNdarray
*
cnda1
,
CudaNdarray
*
cnda2
,
int
ignoreSync
,
int
ignoreBase
);
CudaNdarray_EqualAndIgnore
(
CudaNdarray
*
cnda1
,
CudaNdarray
*
cnda2
,
int
ignoreSync
,
int
ignoreBase
);
...
@@ -143,11 +147,38 @@ CudaNdarray_Equal(CudaNdarray *cnda1, CudaNdarray *cnda2);
...
@@ -143,11 +147,38 @@ CudaNdarray_Equal(CudaNdarray *cnda1, CudaNdarray *cnda2);
*
*
* Does not sync structure to host.
* Does not sync structure to host.
*/
*/
DllExport
void
DllExport
inline
void
__attribute__
((
always_inline
))
CudaNdarray_set_dim
(
CudaNdarray
*
self
,
int
idx
,
int
d
);
CudaNdarray_set_dim
(
CudaNdarray
*
self
,
int
idx
,
int
d
)
{
if
((
idx
>=
self
->
nd
)
||
(
idx
<
0
)
||
(
d
<
0
))
{
fprintf
(
stderr
,
"WARNING: probably bad CudaNdarray_set_dim arguments: %i %i
\n
"
,
idx
,
d
);
}
if
(
d
!=
self
->
host_structure
[
idx
])
{
self
->
host_structure
[
idx
]
=
d
;
int
log2d
=
(
int
)
log2
((
double
)
d
);
self
->
host_structure
[
idx
+
2
*
self
->
nd
]
=
(
d
==
(
1
<<
log2d
))
?
log2d
:
-
1
;
cnda_mark_dev_structure_dirty
(
self
);
}
}
DllExport
void
CudaNdarray_set_stride
(
CudaNdarray
*
self
,
int
idx
,
int
s
);
DllExport
inline
void
__attribute__
((
always_inline
))
CudaNdarray_set_stride
(
CudaNdarray
*
self
,
int
idx
,
int
s
)
{
if
((
idx
>=
self
->
nd
)
||
(
idx
<
0
))
{
fprintf
(
stderr
,
"WARNING: probably bad CudaNdarray_set_stride arguments: %i %i
\n
"
,
idx
,
s
);
}
if
(
s
!=
CudaNdarray_HOST_STRIDES
(
self
)[
idx
])
{
self
->
host_structure
[
idx
+
self
->
nd
]
=
s
;
cnda_mark_dev_structure_dirty
(
self
);
}
}
/***
/***
* Update dependent variables from the contents of CudaNdarray_HOST_DIMS(self) and CudaNdarray_HOST_STRIDES(self)
* Update dependent variables from the contents of CudaNdarray_HOST_DIMS(self) and CudaNdarray_HOST_STRIDES(self)
...
@@ -188,7 +219,57 @@ DllExport PyObject * CudaNdarray_new_nd(const int nd);
...
@@ -188,7 +219,57 @@ DllExport PyObject * CudaNdarray_new_nd(const int nd);
*
*
* Note: This does not allocate storage for data.
* Note: This does not allocate storage for data.
*/
*/
DllExport
int
CudaNdarray_set_nd
(
CudaNdarray
*
self
,
const
int
nd
);
DllExport
inline
int
__attribute__
((
always_inline
))
CudaNdarray_set_nd
(
CudaNdarray
*
self
,
const
int
nd
)
{
if
(
nd
!=
self
->
nd
)
{
if
(
self
->
dev_structure
)
{
if
(
device_free
(
self
->
dev_structure
))
{
return
-
1
;
}
self
->
dev_structure
=
NULL
;
}
if
(
self
->
host_structure
)
{
free
(
self
->
host_structure
);
self
->
host_structure
=
NULL
;
self
->
nd
=
-
1
;
}
if
(
nd
==
-
1
)
return
0
;
self
->
host_structure
=
(
int
*
)
malloc
(
cnda_structure_size
(
nd
)
*
sizeof
(
int
));
if
(
NULL
==
self
->
host_structure
)
{
PyErr_SetString
(
PyExc_MemoryError
,
"Failed to allocate dim or str"
);
return
-
1
;
}
//initialize all dimensions and strides to 0
for
(
int
i
=
0
;
i
<
cnda_structure_size
(
nd
);
++
i
)
{
self
->
host_structure
[
i
]
=
0
;
}
int
struct_size
=
cnda_structure_size
(
nd
);
if
(
struct_size
)
{
self
->
dev_structure
=
(
int
*
)
device_malloc
(
struct_size
*
sizeof
(
int
));
if
(
NULL
==
self
->
dev_structure
)
{
free
(
self
->
host_structure
);
self
->
host_structure
=
NULL
;
self
->
dev_structure
=
NULL
;
return
-
1
;
}
}
self
->
nd
=
nd
;
self
->
dev_structure_fresh
=
0
;
}
return
0
;
}
/**
/**
* CudaNdarray_alloc_contiguous
* CudaNdarray_alloc_contiguous
...
@@ -333,7 +414,24 @@ CudaNdarray_ZEROS(int n, int * dims);
...
@@ -333,7 +414,24 @@ CudaNdarray_ZEROS(int n, int * dims);
/**
/**
* True iff the strides look like [dim[nd-2], dim[nd-3], ... , dim[0], 1]
* True iff the strides look like [dim[nd-2], dim[nd-3], ... , dim[0], 1]
*/
*/
DllExport
bool
CudaNdarray_is_c_contiguous
(
const
CudaNdarray
*
self
);
DllExport
inline
bool
__attribute__
((
always_inline
))
CudaNdarray_is_c_contiguous
(
const
CudaNdarray
*
self
)
{
bool
c_contiguous
=
true
;
int
size
=
1
;
for
(
int
i
=
self
->
nd
-
1
;
(
i
>=
0
)
&&
c_contiguous
;
--
i
)
{
if
(
CudaNdarray_HOST_DIMS
(
self
)[
i
]
==
1
)
continue
;
if
(
CudaNdarray_HOST_STRIDES
(
self
)[
i
]
!=
size
)
{
c_contiguous
=
false
;
}
size
=
size
*
CudaNdarray_HOST_DIMS
(
self
)[
i
];
}
return
c_contiguous
;
}
DllExport
PyObject
*
CudaNdarray_IS_C_Contiguous
(
CudaNdarray
*
self
);
DllExport
PyObject
*
CudaNdarray_IS_C_Contiguous
(
CudaNdarray
*
self
);
DllExport
int
CudaNdarray_gemm
(
float
alpha
,
const
CudaNdarray
*
A
,
const
CudaNdarray
*
B
,
float
beta
,
CudaNdarray
*
C
);
DllExport
int
CudaNdarray_gemm
(
float
alpha
,
const
CudaNdarray
*
A
,
const
CudaNdarray
*
B
,
float
beta
,
CudaNdarray
*
C
);
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论