Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
894d6665
提交
894d6665
authored
6月 29, 2012
作者:
Frederic
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Make GpuAdvancedSubtensor1 use int64 for indices to make sure we support all index number.
上级
e6b2160a
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
30 行增加
和
15 行删除
+30
-15
basic_ops.py
theano/sandbox/cuda/basic_ops.py
+15
-3
cuda_ndarray.cu
theano/sandbox/cuda/cuda_ndarray.cu
+15
-12
没有找到文件。
theano/sandbox/cuda/basic_ops.py
浏览文件 @
894d6665
...
...
@@ -1891,6 +1891,8 @@ class GpuAdvancedSubtensor1(tensor.AdvancedSubtensor1, GpuOp):
"""
Implement AdvancedSubtensor1 on the gpu.
"""
#If True or False, we assert that we use the take version or not
#If None, we choose the best one applicable
perform_using_take
=
None
def
make_node
(
self
,
x
,
ilist
):
...
...
@@ -1910,8 +1912,9 @@ class GpuAdvancedSubtensor1(tensor.AdvancedSubtensor1, GpuOp):
#super(GpuAdvancedSubtensor1, self).perform(node, inp, out_)
x
,
idx
=
inp
out
,
=
out_
#TODO: if more then 3 dims, reshape the inputs if it is contiguous.
x_orig
=
x
#TODO: if more then 3 dims, reshape the inputs even if not all
#dimensions are c contiguous
if
x
.
ndim
>
3
and
x
.
is_c_contiguous
():
x
=
x
.
reshape
((
x
.
shape
[
0
],
numpy
.
prod
(
x
.
shape
[
1
:])))
out_shape
=
(
len
(
idx
),)
+
x_orig
.
shape
[
1
:]
...
...
@@ -1920,8 +1923,17 @@ class GpuAdvancedSubtensor1(tensor.AdvancedSubtensor1, GpuOp):
if
self
.
perform_using_take
is
not
None
:
assert
self
.
perform_using_take
==
True
,
(
"GpuAdvancedSubtensor1 used the fast version"
)
o
=
x
.
take
(
cuda_ndarray
.
cuda_ndarray
.
CudaNdarray
(
idx
.
astype
(
"float32"
)),
# idx
if
idx
.
dtype
!=
numpy
.
int64
:
if
idx
.
dtype
in
[
numpy
.
int8
,
numpyt
.
int16
,
numpy
.
int32
,
numpy
.
int64
,
numpy
.
uint8
,
numpy
.
uint16
,
numpy
.
uint32
]:
idx
=
idx
.
astype
(
numpy
.
int64
)
if
not
idx
.
flags
.
c_contiguous
:
idx
=
numpy
.
ascontiguousarray
(
idx
)
idx
=
idx
.
view
(
"float32"
)
idx
=
cuda_ndarray
.
cuda_ndarray
.
CudaNdarray
(
idx
)
o
=
x
.
take
(
idx
,
0
,
# axis
out_
[
0
][
0
])
# return
if
x
is
not
x_orig
:
...
...
theano/sandbox/cuda/cuda_ndarray.cu
浏览文件 @
894d6665
...
...
@@ -701,14 +701,14 @@ enum operator_t
*/
template
<
int
operator_num
>
__global__
void
k_take_3
(
const
int
d0
,
const
int
d1
,
const
int
d2
,
const
float
*
indices
,
const
npy_int64
*
indices
,
float
*
a
,
const
int
sA0
,
const
int
sA1
,
const
int
sA2
,
const
float
*
b
,
const
int
dB0
,
const
int
sB0
,
const
int
sB1
,
const
int
sB2
,
int
*
err
){
for
(
int
i0
=
blockIdx
.
x
;
i0
<
d0
;
i0
+=
gridDim
.
x
){
int
idx
=
(
int
)
indices
[
i0
];
npy_int64
idx
=
indices
[
i0
];
if
(
idx
<
0
)
idx
+=
dB0
;
// To allow negative indexing.
if
((
idx
<
0
)
||
(
idx
>=
dB0
))
...
...
@@ -737,8 +737,9 @@ static int* err_var = NULL;
// We try to be similat to the PyArray_TakeFrom function
//http://docs.scipy.org/doc/numpy/reference/c-api.array.html
//TODO: support other clip mode then raise(clip, wrap)
//TODO: what if the indices take more then 32 bits?
//self is the input that we copy data from.
//The indices that we receive MUST be an CudaNdarray(float32)
// that is in fact a view to int64 indices
PyObject
*
CudaNdarray_TakeFrom
(
CudaNdarray
*
self
,
PyObject
*
args
){
int
verbose
=
0
;
...
...
@@ -761,7 +762,7 @@ CudaNdarray_TakeFrom(CudaNdarray * self, PyObject *args){
if
(
verbose
)
printf
(
"cudandarray indices
\n
"
);
indices
=
(
CudaNdarray
*
)
indices_obj
;
Py_INCREF
(
indices
);
}
else
if
(
PyArray_Check
(
indices_obj
))
{
}
else
if
(
0
&&
PyArray_Check
(
indices_obj
))
{
PyErr_SetString
(
PyExc_NotImplementedError
,
"CudaNdarray_TakeFrom: The indices must cudandarray with float32 value."
);
return
NULL
;
...
...
@@ -800,9 +801,10 @@ CudaNdarray_TakeFrom(CudaNdarray * self, PyObject *args){
return
NULL
;
}
Py_DECREF
(
indices_float32
);
}
else
{
PyErr_SetString
(
PyExc_TypeError
,
"CudaNdarray_TakeFrom: need a CudaNdarray for indices"
);
PyErr_SetString
(
PyExc_TypeError
,
"CudaNdarray_TakeFrom: need a CudaNdarray(float32) that"
" is a view from int64 data for indices"
);
return
NULL
;
}
...
...
@@ -815,11 +817,12 @@ CudaNdarray_TakeFrom(CudaNdarray * self, PyObject *args){
}
if
(
verbose
)
printf
(
"after print of object
\n
"
);
if
(
!
CudaNdarray_is_c_contiguous
(
indices
)
!=
0
)
{
PyErr_SetString
(
PyExc_NotImplementedError
,
"CudaNdarray_TakeFrom: The indices must be contiguous in memory."
);
PyErr_SetString
(
PyExc_NotImplementedError
,
"CudaNdarray_TakeFrom: The indices must be contiguous in memory."
);
Py_DECREF
(
indices_obj
);
return
NULL
;
}
int
nb_indices
=
CudaNdarray_SIZE
((
CudaNdarray
*
)
indices
)
;
int
nb_indices
=
CudaNdarray_SIZE
((
CudaNdarray
*
)
indices
)
/
2
;
// int64 are 8 bytes, float32 are 4 bytes
//Check argument axis
//TODO: implement the default and other axis
...
...
@@ -885,7 +888,7 @@ CudaNdarray_TakeFrom(CudaNdarray * self, PyObject *args){
Py_DECREF
(
clipmode_obj
);
}
void
(
*
k3
)(
const
int
,
const
int
,
const
int
,
const
float
*
,
const
npy_int64
*
,
float
*
,
const
int
,
const
int
,
const
int
,
const
float
*
,
const
int
,
const
int
,
const
int
,
const
int
,
...
...
@@ -923,7 +926,7 @@ CudaNdarray_TakeFrom(CudaNdarray * self, PyObject *args){
dims
[
0
],
1
,
1
,
CudaNdarray_DEV_DATA
(
indices
),
(
npy_int64
*
)
CudaNdarray_DEV_DATA
(
indices
),
CudaNdarray_DEV_DATA
(
out
),
CudaNdarray_HOST_STRIDES
(
out
)[
0
],
//strides
1
,
...
...
@@ -947,7 +950,7 @@ CudaNdarray_TakeFrom(CudaNdarray * self, PyObject *args){
dims
[
0
],
//dimensions
dims
[
1
],
1
,
CudaNdarray_DEV_DATA
(
indices
),
(
npy_int64
*
)
CudaNdarray_DEV_DATA
(
indices
),
CudaNdarray_DEV_DATA
(
out
),
CudaNdarray_HOST_STRIDES
(
out
)[
0
],
//strides
CudaNdarray_HOST_STRIDES
(
out
)[
1
],
...
...
@@ -973,7 +976,7 @@ CudaNdarray_TakeFrom(CudaNdarray * self, PyObject *args){
dims
[
0
],
//dimensions
dims
[
1
],
dims
[
2
],
CudaNdarray_DEV_DATA
(
indices
),
(
npy_int64
*
)
CudaNdarray_DEV_DATA
(
indices
),
CudaNdarray_DEV_DATA
(
out
),
CudaNdarray_HOST_STRIDES
(
out
)[
0
],
//strides
CudaNdarray_HOST_STRIDES
(
out
)[
1
],
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论