Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
a7cdae3f
提交
a7cdae3f
authored
5月 30, 2012
作者:
lamblin
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #659 from nouiz/gpu_limit
Gpu limit
上级
c1ee7dd5
1a7a5b1c
显示空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
39 行增加
和
11 行删除
+39
-11
blas.py
theano/sandbox/cuda/blas.py
+21
-10
test_blas.py
theano/sandbox/cuda/tests/test_blas.py
+18
-1
没有找到文件。
theano/sandbox/cuda/blas.py
浏览文件 @
a7cdae3f
...
...
@@ -843,7 +843,7 @@ class GpuDownsampleFactorMax(GpuOp):
#def perform(self, node, input_storage, output_storage):
#raise NotImplementedError('only C is implemented')
def
c_code_cache_version
(
self
):
return
(
4
)
return
(
5
)
def
c_code
(
self
,
node
,
nodename
,
inp
,
out
,
sub
):
x
,
=
inp
...
...
@@ -896,7 +896,8 @@ class GpuDownsampleFactorMax(GpuOp):
}
}
{
dim3 grid(dims[0] * dims[1], dims[2]);
dim3 grid(std::min(dims[0] * dims[1], 65535),
dims[2]);
//dim3 block(std::min(dims[3], 512));
//TODO: implement this by supporting more outputs than threads
dim3 block(dims[3]);
...
...
@@ -943,8 +944,12 @@ class GpuDownsampleFactorMax(GpuOp):
float *z, int zS0, int zS1, int zS2, int zS3)
{
float cur_max, cur_x;
int i0 = blockIdx.x
%%
D0;
int i1 = blockIdx.x / D0;
for(int block_x_idx = blockIdx.x;
block_x_idx < D0 * D1;
block_x_idx += gridDim.x){
int i0 = block_x_idx
%%
D0;
int i1 = block_x_idx / D0;
int i2 = blockIdx.y;
extern __shared__ float xbuf[]; //size [xD3]
...
...
@@ -961,7 +966,8 @@ class GpuDownsampleFactorMax(GpuOp):
}
__syncthreads();
// initialize our max if this is the first row we're loading
// initialize our max if this is the
// first row we're loading
cur_max = (r2 == 0) ? xbuf[threadIdx.x*pf3] : cur_max;
// do a mini-reduction over the pf3 relevant elements
...
...
@@ -991,6 +997,7 @@ class GpuDownsampleFactorMax(GpuOp):
//store the result to global memory
z[i0*zS0 + i1*zS1 + i2*zS2 + threadIdx.x*zS3] = cur_max;
}
}
"""
%
locals
()
...
...
@@ -1019,8 +1026,7 @@ class GpuDownsampleFactorMaxGrad(GpuOp):
return
Apply
(
self
,
[
x
,
z
,
gz
],
[
x
.
type
()])
def
c_code_cache_version
(
self
):
#return ()
return
(
5
,)
return
(
6
,)
def
c_code
(
self
,
node
,
nodename
,
inp
,
out
,
sub
):
x
,
z
,
gz
=
inp
...
...
@@ -1062,7 +1068,8 @@ class GpuDownsampleFactorMaxGrad(GpuOp):
// make sure we cover every x row when ignore border isset and
// there's a border present to be ignored
int needs_extra_z_col =
%(ignore_border)
s && (CudaNdarray_HOST_DIMS(
%(x)
s)[2]
%% %(ds0)
s);
dim3 grid(CudaNdarray_HOST_DIMS(
%(z)
s)[0],CudaNdarray_HOST_DIMS(
%(z)
s)[2] + (needs_extra_z_col ? 1 : 0));
dim3 grid(std::min(CudaNdarray_HOST_DIMS(
%(z)
s)[0], 65535),
CudaNdarray_HOST_DIMS(
%(z)
s)[2] + (needs_extra_z_col ? 1 : 0));
dim3 block(std::min(CudaNdarray_HOST_DIMS(
%(x)
s)[3], 512));
kDownsampleMaxGrad_
%(nodename)
s<
%(ds0)
s,
%(ds1)
s> <<<grid, block>>>(
...
...
@@ -1136,7 +1143,10 @@ class GpuDownsampleFactorMaxGrad(GpuOp):
// various .S. variables are strides
float cur_max, cur_x, my_z, my_gz;
int i0 = blockIdx.x; // image row
for(int i0 = blockIdx.x;
i0 < D0;
i0 += gridDim.x){
int i1 = 0; // image col
// row wrt z and/or gz, ranges from 0 to D2 - 1 OR D2
// (as needed to cover all x rows)
...
...
@@ -1200,9 +1210,10 @@ class GpuDownsampleFactorMaxGrad(GpuOp):
}
//gx[i0 * D1*xD2*xD3 + i1*xD2*xD3 +
// x_row*xD3 + x_col] = -999;
}
}
}
}
}
}
"""
%
locals
()
theano/sandbox/cuda/tests/test_blas.py
浏览文件 @
a7cdae3f
import
copy
from
unittest
import
TestCase
from
theano.compile.pfunc
import
pfunc
...
...
@@ -32,6 +33,12 @@ else:
mode_with_gpu
=
theano
.
compile
.
mode
.
get_default_mode
()
.
including
(
'gpu'
)
mode_without_gpu
=
theano
.
compile
.
mode
.
get_default_mode
()
.
excluding
(
'gpu'
)
#The CPU tests already compare C/Py, so we only check C/GPU
mode_with_gpu
=
copy
.
copy
(
mode_with_gpu
)
mode_without_gpu
=
copy
.
copy
(
mode_without_gpu
)
mode_with_gpu
.
check_py_code
=
False
mode_without_gpu
.
check_py_code
=
False
def
my_rand
(
*
shape
):
return
theano
.
_asarray
(
numpy
.
random
.
rand
(
*
shape
),
dtype
=
'float32'
)
...
...
@@ -269,6 +276,8 @@ def test_downsample():
(
1
,
1
,
10
,
1023
),
(
1
,
1
,
1025
,
10
),
(
1
,
1
,
1023
,
10
),
(
65536
,
1
,
10
,
10
),
(
1
,
65536
,
10
,
10
),
]
numpy
.
random
.
RandomState
(
unittest_tools
.
fetch_seed
())
.
shuffle
(
shps
)
...
...
@@ -299,6 +308,14 @@ def test_downsample():
for
node
in
f2
.
maker
.
env
.
toposort
()])
assert
numpy
.
allclose
(
f
(),
f2
())
# The grad is too slow on GT220 GPU
# This cause the computer to freeze...
# Remove this when it get optimized enought
# This only bypass the last 2 checks
# Those tests where passing in all Mode on a GTX470
if
shp
[
0
]
>
30000
or
shp
[
1
]
>
30000
:
continue
g
=
pfunc
(
[],
tensor
.
grad
(
ds_op
(
tensor
.
as_tensor_variable
(
a
))
.
sum
(),
...
...
@@ -314,7 +331,7 @@ def test_downsample():
for
node
in
g
.
maker
.
env
.
toposort
()])
assert
any
([
isinstance
(
node
.
op
,
DownsampleFactorMaxGrad
)
for
node
in
g2
.
maker
.
env
.
toposort
()])
assert
numpy
.
allclose
(
g
(),
g2
())
assert
numpy
.
allclose
(
g
(),
g2
())
,
shp
# We already check that the gpu version return
# the same value as the gpu version for
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论