Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
771d5c27
提交
771d5c27
authored
6月 25, 2014
作者:
Marc-Alexandre Cote
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Added unit tests and comments.
上级
2ad36935
显示空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
52 行增加
和
30 行删除
+52
-30
extra_ops.py
theano/sandbox/cuda/extra_ops.py
+30
-14
test_extra_ops.py
theano/sandbox/cuda/tests/test_extra_ops.py
+22
-16
没有找到文件。
theano/sandbox/cuda/extra_ops.py
浏览文件 @
771d5c27
...
@@ -9,7 +9,7 @@ from theano.sandbox.cuda import GpuFlatten
...
@@ -9,7 +9,7 @@ from theano.sandbox.cuda import GpuFlatten
if
cuda_available
:
if
cuda_available
:
from
theano.sandbox.cuda
import
CudaNdarrayType
from
theano.sandbox.cuda
import
CudaNdarrayType
from
theano.sandbox.cuda.basic_ops
import
host_from_gpu
,
gpu_from_host
from
theano.sandbox.cuda.basic_ops
import
host_from_gpu
,
gpu_from_host
,
HostFromGpu
from
theano.sandbox.cuda.opt
import
register_opt
as
register_gpu_opt
from
theano.sandbox.cuda.opt
import
register_opt
as
register_gpu_opt
...
@@ -78,12 +78,13 @@ class GpuCumsum(CumsumOp, GpuOp):
...
@@ -78,12 +78,13 @@ class GpuCumsum(CumsumOp, GpuOp):
compute_map
,
no_recycling
)
compute_map
,
no_recycling
)
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
return
(
3
,)
return
(
4
,)
def
c_support_code_apply
(
self
,
node
,
nodename
):
def
c_support_code_apply
(
self
,
node
,
nodename
):
return
"""
return
"""
__device__
__device__
void k_reductionPhase_
%(nodename)
s(float* partialCumSum) {
void k_reductionPhase_
%(nodename)
s(float* partialCumSum) {
// Traverse down from leaves to root building partial sums at internal nodes in the tree.
for (unsigned int stride = 1; stride <= blockDim.x; stride *= 2) {
for (unsigned int stride = 1; stride <= blockDim.x; stride *= 2) {
__syncthreads();
__syncthreads();
unsigned int index = (threadIdx.x + 1) * (stride * 2) - 1;
unsigned int index = (threadIdx.x + 1) * (stride * 2) - 1;
...
@@ -95,6 +96,7 @@ class GpuCumsum(CumsumOp, GpuOp):
...
@@ -95,6 +96,7 @@ class GpuCumsum(CumsumOp, GpuOp):
__device__
__device__
void k_reversePhase_
%(nodename)
s(float* partialCumSum) {
void k_reversePhase_
%(nodename)
s(float* partialCumSum) {
// Traverse back up the tree building the scan from the partial sums
for (unsigned int stride = exp2(ceil(log2((float)blockDim.x))); stride > 0; stride /= 2) {
for (unsigned int stride = exp2(ceil(log2((float)blockDim.x))); stride > 0; stride /= 2) {
__syncthreads();
__syncthreads();
unsigned int index = (threadIdx.x + 1) * (stride * 2) - 1;
unsigned int index = (threadIdx.x + 1) * (stride * 2) - 1;
...
@@ -151,7 +153,7 @@ class GpuCumsum(CumsumOp, GpuOp):
...
@@ -151,7 +153,7 @@ class GpuCumsum(CumsumOp, GpuOp):
__global__
__global__
void k_blockCumSum_
%(nodename)
s(float* input, float* output, int numElements, dim3 dataStrides, int dataOffset, float* blockSum) {
void k_blockCumSum_
%(nodename)
s(float* input, float* output, int numElements, dim3 dataStrides, int dataOffset, float* blockSum) {
// Regarding blockIdx and threadIdx, 'Cumsum' is always perform along the X axis.
// Regarding blockIdx and threadIdx, 'Cumsum' is always perform
ed
along the X axis.
// The Y axis will contain all the independent cumsums of the 2D case.
// The Y axis will contain all the independent cumsums of the 2D case.
int globalThreadID = blockIdx.x * blockDim.x + threadIdx.x;
int globalThreadID = blockIdx.x * blockDim.x + threadIdx.x;
...
@@ -166,6 +168,9 @@ class GpuCumsum(CumsumOp, GpuOp):
...
@@ -166,6 +168,9 @@ class GpuCumsum(CumsumOp, GpuOp):
// Load data in shared memory
// Load data in shared memory
k_fetchData_
%(nodename)
s(partialCumSum, input, globalThreadID, dataStrides, dataOffset);
k_fetchData_
%(nodename)
s(partialCumSum, input, globalThreadID, dataStrides, dataOffset);
// Use a dichotomy approach to compute the cumsum (i.e. balanced binary tree).
// The tree is sweeped from the leaves to the root and from the root to the leaves.
// Similar to http://www.umiacs.umd.edu/~ramani/cmsc828e_gpusci/ScanTalk.pdf
k_reductionPhase_
%(nodename)
s(partialCumSum);
k_reductionPhase_
%(nodename)
s(partialCumSum);
k_reversePhase_
%(nodename)
s(partialCumSum);
k_reversePhase_
%(nodename)
s(partialCumSum);
...
@@ -179,7 +184,7 @@ class GpuCumsum(CumsumOp, GpuOp):
...
@@ -179,7 +184,7 @@ class GpuCumsum(CumsumOp, GpuOp):
}
}
}
}
void
cumSum_
%(nodename)
s(CudaNdarray* input, CudaNdarray* output, int maxThreads, int axis, int maxGridY) {
int
cumSum_
%(nodename)
s(CudaNdarray* input, CudaNdarray* output, int maxThreads, int axis, int maxGridY) {
int shape[2] = { 1, 1 };
int shape[2] = { 1, 1 };
dim3 dataStrides(0,0,0);
dim3 dataStrides(0,0,0);
...
@@ -195,12 +200,14 @@ class GpuCumsum(CumsumOp, GpuOp):
...
@@ -195,12 +200,14 @@ class GpuCumsum(CumsumOp, GpuOp):
dataStrides.x = CudaNdarray_HOST_STRIDES(input)[0];
dataStrides.x = CudaNdarray_HOST_STRIDES(input)[0];
dataStrides.y = CudaNdarray_HOST_STRIDES(input)[1];
dataStrides.y = CudaNdarray_HOST_STRIDES(input)[1];
break;
break;
default: printf("Only 1D and 2D cumsum is implemented yet.
\\
n");
default:
printf("Only 1D and 2D cumsum is implemented yet.
\\
n");
return -1;
}
}
if (shape[axis] <= 1) {
if (shape[axis] <= 1) {
CudaNdarray_CopyFromCudaNdarray(output, input);
CudaNdarray_CopyFromCudaNdarray(output, input);
return;
return
0
;
}
}
if (axis == 1) {
if (axis == 1) {
...
@@ -215,8 +222,7 @@ class GpuCumsum(CumsumOp, GpuOp):
...
@@ -215,8 +222,7 @@ class GpuCumsum(CumsumOp, GpuOp):
int dimGridY = shape[1-axis]; // Nb. of independent cumsums.
int dimGridY = shape[1-axis]; // Nb. of independent cumsums.
const int shapeBlockSum[2] = { dimGridX, dimGridY };
const int shapeBlockSum[2] = { dimGridX, dimGridY };
//CudaNdarray* deviceBlockSum = (CudaNdarray*) CudaNdarray_NewDims(2, shapeBlockSum);
CudaNdarray* deviceBlockSum = (CudaNdarray*) CudaNdarray_NewDims(2, shapeBlockSum);
CudaNdarray* deviceBlockSum = (CudaNdarray*) CudaNdarray_ZEROS(2, (int*)shapeBlockSum);
for (int dataOffset = 0; dataOffset < dimGridY; dataOffset += maxGridY){
for (int dataOffset = 0; dataOffset < dimGridY; dataOffset += maxGridY){
int localDimGridY = min(dimGridY - dataOffset, maxGridY);
int localDimGridY = min(dimGridY - dataOffset, maxGridY);
...
@@ -224,7 +230,6 @@ class GpuCumsum(CumsumOp, GpuOp):
...
@@ -224,7 +230,6 @@ class GpuCumsum(CumsumOp, GpuOp):
dim3 dimGrid(dimGridX, localDimGridY, 1);
dim3 dimGrid(dimGridX, localDimGridY, 1);
int sharedBytes = (2*blockSize) * sizeof(float);
int sharedBytes = (2*blockSize) * sizeof(float);
cudaThreadSynchronize();
k_blockCumSum_
%(nodename)
s<<<dimGrid, dimBlock, sharedBytes>>>
k_blockCumSum_
%(nodename)
s<<<dimGrid, dimBlock, sharedBytes>>>
(
(
CudaNdarray_DEV_DATA(input),
CudaNdarray_DEV_DATA(input),
...
@@ -237,8 +242,12 @@ class GpuCumsum(CumsumOp, GpuOp):
...
@@ -237,8 +242,12 @@ class GpuCumsum(CumsumOp, GpuOp):
if (dimGridX > 1) {
if (dimGridX > 1) {
// Do a cumsum over the blockSum (recursive).
// Do a cumsum over the blockSum (recursive).
cumSum_
%(nodename)
s(deviceBlockSum, deviceBlockSum, maxThreads, 0, maxGridY);
if (cumSum_
%(nodename)
s(deviceBlockSum, deviceBlockSum, maxThreads, 0, maxGridY) == -1){
return -1;
}
// Since there are more than one block (i.e. `dimGridX > 1`)
// report partial cumsums of previous blocks to subsequents ones.
dim3 dimGrid(dimGridX, dimGridY, 1);
dim3 dimGrid(dimGridX, dimGridY, 1);
dim3 dimBlock(blockSize, 1, 1);
dim3 dimBlock(blockSize, 1, 1);
k_finalCumSum_
%(nodename)
s<<<dimGrid, dimBlock>>>
k_finalCumSum_
%(nodename)
s<<<dimGrid, dimBlock>>>
...
@@ -253,7 +262,7 @@ class GpuCumsum(CumsumOp, GpuOp):
...
@@ -253,7 +262,7 @@ class GpuCumsum(CumsumOp, GpuOp):
// If shape[axis] is odd, the last element is compute manually
// If shape[axis] is odd, the last element is compute manually
if (shape[axis] != numElements){
if (shape[axis] != numElements){
cudaThreadSynchronize()
;
CNDA_THREAD_SYNC
;
dim3 dimGrid(1, localDimGridY, 1);
dim3 dimGrid(1, localDimGridY, 1);
dim3 dimBlock(1, 1, 1);
dim3 dimBlock(1, 1, 1);
k_cumadd_
%(nodename)
s<<<dimGrid, dimBlock>>>
k_cumadd_
%(nodename)
s<<<dimGrid, dimBlock>>>
...
@@ -269,7 +278,8 @@ class GpuCumsum(CumsumOp, GpuOp):
...
@@ -269,7 +278,8 @@ class GpuCumsum(CumsumOp, GpuOp):
}
}
cudaFree(CudaNdarray_DEV_DATA(deviceBlockSum));
cudaFree(CudaNdarray_DEV_DATA(deviceBlockSum));
cudaThreadSynchronize();
CNDA_THREAD_SYNC;
return 0;
}
}
"""
%
locals
()
"""
%
locals
()
...
@@ -321,7 +331,9 @@ class GpuCumsum(CumsumOp, GpuOp):
...
@@ -321,7 +331,9 @@ class GpuCumsum(CumsumOp, GpuOp):
}
}
{ // Namespace for kernel calls //
{ // Namespace for kernel calls //
cumSum_
%(nodename)
s(
%(x)
s,
%(z)
s,
%(max_threads_dim0)
s,
%(axis)
s,
%(max_grid_size1)
s);
if (cumSum_
%(nodename)
s(
%(x)
s,
%(z)
s,
%(max_threads_dim0)
s,
%(axis)
s,
%(max_grid_size1)
s) == -1){
%(fail)
s;
}
cudaError_t sts = cudaGetLastError();
cudaError_t sts = cudaGetLastError();
if (cudaSuccess != sts)
if (cudaSuccess != sts)
...
@@ -340,7 +352,11 @@ class GpuCumsum(CumsumOp, GpuOp):
...
@@ -340,7 +352,11 @@ class GpuCumsum(CumsumOp, GpuOp):
@local_optimizer
([
CumsumOp
])
@local_optimizer
([
CumsumOp
])
def
use_gpu_cumsum
(
node
):
def
use_gpu_cumsum
(
node
):
if
type
(
node
.
op
)
is
CumsumOp
and
node
.
inputs
[
0
]
.
dtype
==
'float32'
:
if
type
(
node
.
op
)
is
CumsumOp
\
and
node
.
inputs
[
0
]
.
dtype
==
'float32'
\
and
node
.
inputs
[
0
]
.
owner
\
and
isinstance
(
node
.
inputs
[
0
]
.
owner
.
op
,
HostFromGpu
):
axis
=
node
.
op
.
axis
axis
=
node
.
op
.
axis
x
=
node
.
inputs
[
0
]
x
=
node
.
inputs
[
0
]
...
...
theano/sandbox/cuda/tests/test_extra_ops.py
浏览文件 @
771d5c27
...
@@ -17,12 +17,11 @@ from theano import tensor as T
...
@@ -17,12 +17,11 @@ from theano import tensor as T
import
numpy
as
np
import
numpy
as
np
import
theano
import
theano
from
theano
import
config
from
theano
import
config
from
theano.tensor.extra_ops
import
cumsum
from
theano.tensor.extra_ops
import
cumsum
,
CumsumOp
class
TestGpuCumsum
(
theano
.
tensor
.
tests
.
test_extra_ops
.
TestCumsumOp
):
class
TestGpuCumsum
(
theano
.
tensor
.
tests
.
test_extra_ops
.
TestCumsumOp
):
mode
=
mode_with_gpu
mode
=
mode_with_gpu
op
=
GpuCumsum
def
setUp
(
self
):
def
setUp
(
self
):
super
(
TestGpuCumsum
,
self
)
.
setUp
()
super
(
TestGpuCumsum
,
self
)
.
setUp
()
...
@@ -68,8 +67,8 @@ class TestGpuCumsum(theano.tensor.tests.test_extra_ops.TestCumsumOp):
...
@@ -68,8 +67,8 @@ class TestGpuCumsum(theano.tensor.tests.test_extra_ops.TestCumsumOp):
assert
[
n
for
n
in
f
.
maker
.
fgraph
.
toposort
()
assert
[
n
for
n
in
f
.
maker
.
fgraph
.
toposort
()
if
isinstance
(
n
.
op
,
GpuCumsum
)]
if
isinstance
(
n
.
op
,
GpuCumsum
)]
# Extensive testing for the first 1
k
sizes
# Extensive testing for the first 1
025
sizes
a
=
np
.
ones
((
int
(
1e3
),),
dtype
=
"float32"
)
a
=
np
.
random
.
random
(
1025
)
.
astype
(
"float32"
)
for
i
in
xrange
(
a
.
shape
[
0
]):
for
i
in
xrange
(
a
.
shape
[
0
]):
assert
np
.
allclose
(
np
.
cumsum
(
a
[:
i
]),
f
(
a
[:
i
]))
assert
np
.
allclose
(
np
.
cumsum
(
a
[:
i
]),
f
(
a
[:
i
]))
...
@@ -86,36 +85,43 @@ class TestGpuCumsum(theano.tensor.tests.test_extra_ops.TestCumsumOp):
...
@@ -86,36 +85,43 @@ class TestGpuCumsum(theano.tensor.tests.test_extra_ops.TestCumsumOp):
block_max_size
=
self
.
max_threads_dim0
*
2
block_max_size
=
self
.
max_threads_dim0
*
2
x
=
T
.
fmatrix
(
'x'
)
x
=
T
.
fmatrix
(
'x'
)
for
axis
in
xrange
(
2
):
for
shape_axis
,
axis
in
zip
([
0
,
1
,
0
],
[
0
,
1
,
None
]
):
f
=
theano
.
function
([
x
],
cumsum
(
x
,
axis
=
axis
),
mode
=
self
.
mode
)
f
=
theano
.
function
([
x
],
cumsum
(
x
,
axis
=
axis
),
mode
=
self
.
mode
)
assert
[
n
for
n
in
f
.
maker
.
fgraph
.
toposort
()
assert
[
n
for
n
in
f
.
maker
.
fgraph
.
toposort
()
if
isinstance
(
n
.
op
,
GpuCumsum
)]
if
isinstance
(
n
.
op
,
GpuCumsum
)]
# Extensive testing for the first 1
k
sizes
# Extensive testing for the first 1
025
sizes
a_shape
=
[
5
,
5
]
a_shape
=
[
5
,
5
]
a_shape
[
axis
]
=
int
(
1e3
)
a_shape
[
shape_axis
]
=
1025
a
=
np
.
ones
(
a_shape
,
dtype
=
"float32"
)
a
=
np
.
random
.
random
(
a_shape
)
.
astype
(
"float32"
)
slices
=
[
slice
(
None
),
slice
(
None
)]
slices
=
[
slice
(
None
),
slice
(
None
)]
for
i
in
xrange
(
a
.
shape
[
axis
]):
for
i
in
xrange
(
a
.
shape
[
shape_
axis
]):
slices
[
axis
]
=
slice
(
i
)
slices
[
shape_
axis
]
=
slice
(
i
)
fa
=
f
(
a
[
slices
])
fa
=
f
(
a
[
slices
])
npa
=
np
.
cumsum
(
a
[
slices
],
axis
=
axis
)
npa
=
np
.
cumsum
(
a
[
slices
],
axis
=
axis
)
assert
np
.
allclose
(
npa
,
fa
)
assert
np
.
allclose
(
npa
,
fa
)
# Use multiple GPU threadblocks
# Use multiple GPU threadblocks
a_shape
=
[
5
,
5
]
a_shape
=
[
5
,
5
]
a_shape
[
axis
]
=
block_max_size
+
2
a_shape
[
shape_
axis
]
=
block_max_size
+
2
a
=
np
.
ones
(
a_shape
,
dtype
=
"float32"
)
a
=
np
.
random
.
random
(
a_shape
)
.
astype
(
"float32"
)
assert
np
.
allclose
(
np
.
cumsum
(
a
,
axis
=
axis
),
f
(
a
))
assert
np
.
allclose
(
np
.
cumsum
(
a
,
axis
=
axis
),
f
(
a
))
# Use multiple GPU gridblocks
# Use multiple GPU gridblocks
a_shape
=
[
5
,
5
]
a_shape
=
[
5
,
5
]
a_shape
[
1
-
axis
]
=
self
.
max_grid_size1
+
1
a_shape
[
1
-
shape_
axis
]
=
self
.
max_grid_size1
+
1
a
=
np
.
ones
(
a_shape
,
dtype
=
"float32"
)
a
=
np
.
random
.
random
(
a_shape
)
.
astype
(
"float32"
)
assert
np
.
allclose
(
np
.
cumsum
(
a
,
axis
=
axis
),
f
(
a
))
assert
np
.
allclose
(
np
.
cumsum
(
a
,
axis
=
axis
),
f
(
a
))
# Use recursive cumsum
# Use recursive cumsum
a_shape
=
[
5
,
5
]
a_shape
=
[
5
,
3
]
a_shape
[
axis
]
=
block_max_size
*
(
block_max_size
+
1
)
+
2
a_shape
[
shape_
axis
]
=
block_max_size
*
(
block_max_size
+
1
)
+
2
a
=
np
.
ones
(
a_shape
,
dtype
=
"float32"
)
a
=
np
.
ones
(
a_shape
,
dtype
=
"float32"
)
assert
np
.
allclose
(
np
.
cumsum
(
a
,
axis
=
axis
),
f
(
a
))
assert
np
.
allclose
(
np
.
cumsum
(
a
,
axis
=
axis
),
f
(
a
))
def
test_GpuCumsum3D
(
self
):
# Should not use the GPU version.
x
=
T
.
ftensor3
(
'x'
)
f
=
theano
.
function
([
x
],
cumsum
(
x
,
axis
=
1
),
mode
=
self
.
mode
)
assert
[
n
for
n
in
f
.
maker
.
fgraph
.
toposort
()
if
isinstance
(
n
.
op
,
CumsumOp
)]
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论