Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
6550577b
提交
6550577b
authored
3月 22, 2012
作者:
Frederic
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
make gpu reshape reuse gpu copy that is faster.
This also fix an not understood crash in gpu reshape code when the input is not contiguous.
上级
5001ae40
隐藏空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
29 行增加
和
91 行删除
+29
-91
cuda_ndarray.cu
theano/sandbox/cuda/cuda_ndarray.cu
+29
-91
没有找到文件。
theano/sandbox/cuda/cuda_ndarray.cu
浏览文件 @
6550577b
...
@@ -567,38 +567,6 @@ PyObject * CudaNdarray_ReduceSum(CudaNdarray * self, PyObject * py_reduce_mask)
...
@@ -567,38 +567,6 @@ PyObject * CudaNdarray_ReduceSum(CudaNdarray * self, PyObject * py_reduce_mask)
return
(
PyObject
*
)
self_sum
;
return
(
PyObject
*
)
self_sum
;
}
}
__global__
void
k_copy_reshape_rowmajor
(
unsigned
int
numEls
,
unsigned
int
a_nd
,
const
float
*
a_data
,
const
int
*
a_dim
,
const
int
*
a_str
,
unsigned
int
z_nd
,
float
*
z_data
,
const
int
*
z_dim
,
const
int
*
z_str
)
{
const
unsigned
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
const
unsigned
int
numThreads
=
blockDim
.
x
*
gridDim
.
x
;
for
(
unsigned
int
i
=
idx
;
i
<
numEls
;
i
+=
numThreads
)
{
const
float
*
a_i
=
a_data
;
unsigned
int
a_ii
=
i
;
for
(
unsigned
int
_d
=
0
;
_d
<
a_nd
;
++
_d
)
//make the rightmost coords change fastest
{
unsigned
int
d
=
a_nd
-
_d
-
1
;
unsigned
int
a_i_d
=
a_ii
%
a_dim
[
d
];
a_ii
=
a_ii
/
a_dim
[
d
];
a_i
+=
a_i_d
*
a_str
[
d
];
}
unsigned
int
z_ii
=
i
;
float
*
z_i
=
z_data
;
for
(
unsigned
int
_d
=
0
;
_d
<
z_nd
;
++
_d
)
//make the rightmost coords change fastest
{
unsigned
int
d
=
z_nd
-
_d
-
1
;
//i tried to make the for loop count down, but it didn't work!?
unsigned
int
z_i_d
=
z_ii
%
z_dim
[
d
];
z_i
+=
z_i_d
*
z_str
[
d
];
z_ii
=
z_ii
/
z_dim
[
d
];
}
z_i
[
0
]
=
a_i
[
0
];
//copy one lousy float!
}
}
// Reshape self to the new shape gived by the tuple shape.
// Reshape self to the new shape gived by the tuple shape.
//
//
// If self is c contiguous, it return a view. Otherwise it always do a copy.
// If self is c contiguous, it return a view. Otherwise it always do a copy.
...
@@ -606,6 +574,22 @@ __global__ void k_copy_reshape_rowmajor(unsigned int numEls,
...
@@ -606,6 +574,22 @@ __global__ void k_copy_reshape_rowmajor(unsigned int numEls,
// c contiguous
// c contiguous
PyObject
*
CudaNdarray_Reshape
(
CudaNdarray
*
self
,
PyObject
*
shape
)
PyObject
*
CudaNdarray_Reshape
(
CudaNdarray
*
self
,
PyObject
*
shape
)
{
{
if
(
!
CudaNdarray_is_c_contiguous
(
self
))
{
// allocate new space
//TODO: test to see if we can re-use old one and take a new param to
// use this
CudaNdarray
*
rval
=
(
CudaNdarray
*
)
CudaNdarray_Copy
(
self
);
if
(
!
rval
)
{
return
NULL
;
}
CudaNdarray
*
ret
=
(
CudaNdarray
*
)
CudaNdarray_Reshape
(
rval
,
shape
);
Py_XDECREF
(
rval
);
return
(
PyObject
*
)
ret
;
}
// check shape tuple
// check shape tuple
unsigned
int
rval_nd
;
unsigned
int
rval_nd
;
unsigned
int
*
rval_dims
;
unsigned
int
*
rval_dims
;
...
@@ -656,75 +640,29 @@ PyObject * CudaNdarray_Reshape(CudaNdarray * self, PyObject * shape)
...
@@ -656,75 +640,29 @@ PyObject * CudaNdarray_Reshape(CudaNdarray * self, PyObject * shape)
return
rval
;
return
rval
;
}
}
if
(
CudaNdarray_is_c_contiguous
(
self
))
//return a view, not a copy
{
//we can do this as we checked self is c_contiguous
//return a view, not a copy
CudaNdarray
*
rval
=
(
CudaNdarray
*
)
CudaNdarray_New
(
rval_nd
);
CudaNdarray
*
rval
=
(
CudaNdarray
*
)
CudaNdarray_New
(
rval_nd
);
if
(
!
rval
||
0
!=
rval
->
data_allocated
||
CudaNdarray_set_device_data
(
rval
,
CudaNdarray_DEV_DATA
(
self
),
self
))
{
Py_XDECREF
(
rval
);
free
(
rval_dims
);
return
NULL
;
}
//set dim and stride
int
size
=
1
;
for
(
int
i
=
rval_nd
-
1
;
i
>=
0
;
--
i
)
{
CudaNdarray_set_stride
(
rval
,
i
,
(
rval_dims
[
i
]
==
1
)
?
0
:
size
);
CudaNdarray_set_dim
(
rval
,
i
,
rval_dims
[
i
]);
size
=
size
*
rval_dims
[
i
];
}
free
(
rval_dims
);
return
(
PyObject
*
)
rval
;
}
// allocate new space (TODO: test to see if we can re-use old one)
if
(
!
rval
||
0
!=
rval
->
data_allocated
CudaNdarray
*
rval
=
(
CudaNdarray
*
)
CudaNdarray_New
();
||
CudaNdarray_set_device_data
(
rval
,
CudaNdarray_DEV_DATA
(
self
),
self
))
if
(
!
rval
||
CudaNdarray_alloc_contiguous
(
rval
,
rval_nd
,
rval_dims
))
{
{
Py_XDECREF
(
rval
);
Py_XDECREF
(
rval
);
free
(
rval_dims
);
free
(
rval_dims
);
return
NULL
;
return
NULL
;
}
}
//set dim and stride
// call worker routine
int
size
=
1
;
unsigned
int
threads_per_block
=
std
::
min
(
rval_size
,
(
unsigned
int
)
NUM_VECTOR_OP_THREADS_PER_BLOCK
);
for
(
int
i
=
rval_nd
-
1
;
i
>=
0
;
--
i
)
unsigned
int
n_blocks
=
std
::
min
(
ceil_intdiv
(
rval_size
,
threads_per_block
),
(
unsigned
int
)
NUM_VECTOR_OP_BLOCKS
);
k_copy_reshape_rowmajor
<<<
n_blocks
,
threads_per_block
>>>
(
rval_size
,
self
->
nd
,
CudaNdarray_DEV_DATA
(
self
),
CudaNdarray_DEV_DIMS
(
self
),
CudaNdarray_DEV_STRIDES
(
self
),
rval
->
nd
,
CudaNdarray_DEV_DATA
(
rval
),
CudaNdarray_DEV_DIMS
(
rval
),
CudaNdarray_DEV_STRIDES
(
rval
));
CNDA_THREAD_SYNC
;
cudaError_t
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
{
Py_DECREF
(
rval
);
CudaNdarray_set_stride
(
rval
,
i
,
(
rval_dims
[
i
]
==
1
)
?
0
:
size
);
PyObject
*
shape_inp
=
CudaNdarray_get_shape
(
self
,
NULL
);
CudaNdarray_set_dim
(
rval
,
i
,
rval_dims
[
i
]);
PyObject
*
shape_inp2
=
PyObject_Str
(
shape_inp
);
size
=
size
*
rval_dims
[
i
];
PyObject
*
shape_dest
=
PyObject_Str
(
shape
);
PyErr_Format
(
PyExc_RuntimeError
,
"Cuda error in CudaNdarray_Reshape"
"()n_blocks=%d, n_threads=%d, input_shape=%s,"
" dest_shape=%s): %s: %s.
\n
"
,
n_blocks
,
threads_per_block
,
PyString_AsString
(
shape_inp2
),
PyString_AsString
(
shape_dest
),
"k_copy_reshape_rowmajor"
,
cudaGetErrorString
(
err
)
);
Py_DECREF
(
shape_dest
);
Py_DECREF
(
shape_inp
);
Py_DECREF
(
shape_inp2
);
free
(
rval_dims
);
return
NULL
;
}
}
free
(
rval_dims
);
free
(
rval_dims
);
return
(
PyObject
*
)
rval
;
return
(
PyObject
*
)
rval
;
}
}
PyObject
*
CudaNdarray_View
(
CudaNdarray
*
self
)
PyObject
*
CudaNdarray_View
(
CudaNdarray
*
self
)
{
{
CudaNdarray
*
rval
=
(
CudaNdarray
*
)
CudaNdarray_New
(
self
->
nd
);
CudaNdarray
*
rval
=
(
CudaNdarray
*
)
CudaNdarray_New
(
self
->
nd
);
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论