Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
ebad678c
提交
ebad678c
authored
5月 13, 2011
作者:
Frederic Bastien
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Allow GpuDownsampleFactorMaxGrad to work with more than 512 columns in its outputs.
上级
577aee4a
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
52 行增加
和
28 行删除
+52
-28
blas.py
theano/sandbox/cuda/blas.py
+39
-26
test_blas.py
theano/sandbox/cuda/tests/test_blas.py
+13
-2
没有找到文件。
theano/sandbox/cuda/blas.py
浏览文件 @
ebad678c
...
@@ -588,7 +588,7 @@ class GpuDownsampleFactorMaxGrad(Op):
...
@@ -588,7 +588,7 @@ class GpuDownsampleFactorMaxGrad(Op):
return
Apply
(
self
,
[
x
,
z
,
gz
],
[
x
.
type
()])
return
Apply
(
self
,
[
x
,
z
,
gz
],
[
x
.
type
()])
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
#return ()
#return ()
return
(
3
,)
return
(
4
,)
def
c_code
(
self
,
node
,
nodename
,
inp
,
out
,
sub
):
def
c_code
(
self
,
node
,
nodename
,
inp
,
out
,
sub
):
x
,
z
,
gz
=
inp
x
,
z
,
gz
=
inp
...
@@ -625,7 +625,8 @@ class GpuDownsampleFactorMaxGrad(Op):
...
@@ -625,7 +625,8 @@ class GpuDownsampleFactorMaxGrad(Op):
// make sure we cover every x row when ignore border isset and there's a border present to be ignored
// make sure we cover every x row when ignore border isset and there's a border present to be ignored
int needs_extra_z_col =
%(ignore_border)
s && (CudaNdarray_HOST_DIMS(
%(x)
s)[2]
%% %(ds0)
s);
int needs_extra_z_col =
%(ignore_border)
s && (CudaNdarray_HOST_DIMS(
%(x)
s)[2]
%% %(ds0)
s);
dim3 grid(CudaNdarray_HOST_DIMS(
%(z)
s)[0],CudaNdarray_HOST_DIMS(
%(z)
s)[2] + (needs_extra_z_col ? 1 : 0));
dim3 grid(CudaNdarray_HOST_DIMS(
%(z)
s)[0],CudaNdarray_HOST_DIMS(
%(z)
s)[2] + (needs_extra_z_col ? 1 : 0));
dim3 block(CudaNdarray_HOST_DIMS(
%(x)
s)[3]);
dim3 block(std::min(CudaNdarray_HOST_DIMS(
%(x)
s)[3], 512));
kDownsampleMaxGrad_
%(nodename)
s<
%(ds0)
s,
%(ds1)
s> <<<grid, block>>>(
kDownsampleMaxGrad_
%(nodename)
s<
%(ds0)
s,
%(ds1)
s> <<<grid, block>>>(
CudaNdarray_HOST_DIMS(
%(z)
s)[0],
CudaNdarray_HOST_DIMS(
%(z)
s)[0],
CudaNdarray_HOST_DIMS(
%(z)
s)[1],
CudaNdarray_HOST_DIMS(
%(z)
s)[1],
...
@@ -705,32 +706,44 @@ class GpuDownsampleFactorMaxGrad(Op):
...
@@ -705,32 +706,44 @@ class GpuDownsampleFactorMaxGrad(Op):
for (i1 = 0; i1 < D1; ++i1) // loop over images (same for z and x)
for (i1 = 0; i1 < D1; ++i1) // loop over images (same for z and x)
{
{
if (
%(ignore_border)
s && x_col >= ds1 * D3)
for(int col_iter = 0; col_iter * blockDim.x <= xD3 ; col_iter++){
{
//The if inside is to don't do the division if we need only 1 col_iter
// This happens only if x_col was ignored (via ignore_border)
if(blockDim.x != xD3)
// TODO: if ignore_border is False, this is impossible and we don't even
{
// need to generate this code.
x_col = threadIdx.x + col_iter * blockDim.x;
z_col = x_col/ds1;
}
my_gz = 0.0f;
if (
%(ignore_border)
s && x_col >= ds1 * D3)
//any fp number suffices for my_z, so we don't even need to set it to
{
//anything in particular.
// This happens only if x_col was ignored (via ignore_border)
}
// TODO: if ignore_border is False, this is impossible and we don't even
else
// need to generate this code.
{
// this is effectively:
my_gz = 0.0f;
// my_gz = gz[image_row][image_col][z_row][z_col]
//any fp number suffices for my_z, so we don't even need to set it to
// my_z = z[image_row][image_col][z_row][z_col]
//anything in particular.
my_gz = gz[i0 * gzS0 + i1 * gzS1 + i2 * gzS2 + z_col*gzS3];
}
my_z = z[i0 * zS0 + i1 * zS1 + i2 * zS2 + z_col* zS3];
else
}
{
// this is effectively:
// my_gz = gz[image_row][image_col][z_row][z_col]
// my_z = z[image_row][image_col][z_row][z_col]
my_gz = gz[i0 * gzS0 + i1 * gzS1 + i2 * gzS2 + z_col*gzS3];
my_z = z[i0 * zS0 + i1 * zS1 + i2 * zS2 + z_col* zS3];
}
if(x_col<xD3){
for (int x_row = i2*ds0; (x_row < i2*ds0+ds0) && (x_row < xD2); ++x_row)
{
// this is effectively:
// gx[image_row][image_col][x_row][x_col]
// = (my_z == x[image_row][image_col][x_row][x_col]) ? my_gz : 0.0f;
gx[i0 * D1*xD2*xD3 + i1*xD2*xD3 + x_row*xD3 + x_col]
= (my_z == x[i0*xS0 + i1*xS1 + x_row*xS2 + x_col*xS3]) ? my_gz : 0.0f;
}
//gx[i0 * D1*xD2*xD3 + i1*xD2*xD3 + x_row*xD3 + x_col] = -999;
}
for (int x_row = i2*ds0; (x_row < i2*ds0+ds0) && (x_row < xD2); ++x_row)
{
// this is effectively:
// gx[image_row][image_col][x_row][x_col]
// = (my_z == x[image_row][image_col][x_row][x_col]) ? my_gz : 0.0f;
gx[i0 * D1*xD2*xD3 + i1*xD2*xD3 + x_row*xD3 + x_col]
= (my_z == x[i0*xS0 + i1*xS1 + x_row*xS2 + x_col*xS3]) ? my_gz : 0.0f;
}
}
}
}
}
}
...
...
theano/sandbox/cuda/tests/test_blas.py
浏览文件 @
ebad678c
...
@@ -12,7 +12,7 @@ if cuda_ndarray.cuda_available == False:
...
@@ -12,7 +12,7 @@ if cuda_ndarray.cuda_available == False:
import
theano.sandbox.cuda
as
tcn
import
theano.sandbox.cuda
as
tcn
from
theano.tensor.signal.downsample
import
DownsampleFactorMax
from
theano.tensor.signal.downsample
import
DownsampleFactorMax
,
DownsampleFactorMaxGrad
import
theano.compile.mode
import
theano.compile.mode
...
@@ -163,7 +163,12 @@ def test_downsample():
...
@@ -163,7 +163,12 @@ def test_downsample():
(
30
,
6
,
12
,
12
),
(
30
,
6
,
12
,
12
),
(
30
,
2
,
24
,
24
),
(
30
,
2
,
24
,
24
),
(
30
,
6
,
24
,
24
),
(
30
,
6
,
24
,
24
),
(
10
,
10
,
10
,
11
)]
(
10
,
10
,
10
,
11
),
(
1
,
1
,
10
,
1025
),
(
1
,
1
,
10
,
1023
),
(
1
,
1
,
1025
,
10
),
(
1
,
1
,
1023
,
10
),
]
numpy
.
random
.
RandomState
(
unittest_tools
.
fetch_seed
())
.
shuffle
(
shps
)
numpy
.
random
.
RandomState
(
unittest_tools
.
fetch_seed
())
.
shuffle
(
shps
)
...
@@ -171,6 +176,8 @@ def test_downsample():
...
@@ -171,6 +176,8 @@ def test_downsample():
for
ds
in
(
2
,
2
),
(
3
,
2
),
(
1
,
1
):
for
ds
in
(
2
,
2
),
(
3
,
2
),
(
1
,
1
):
if
ds
[
0
]
>
shp
[
2
]:
continue
if
ds
[
0
]
>
shp
[
2
]:
continue
if
ds
[
1
]
>
shp
[
3
]:
continue
if
ds
[
1
]
>
shp
[
3
]:
continue
#GpuDownsampleFactorMax don't having more then 512 columns in the output tensor
if
float
(
shp
[
3
])
/
ds
[
1
]
>
512
:
continue
for
ignore_border
in
(
True
,
False
):
for
ignore_border
in
(
True
,
False
):
print
'test_downsample'
,
shp
,
ds
,
ignore_border
print
'test_downsample'
,
shp
,
ds
,
ignore_border
ds_op
=
DownsampleFactorMax
(
ds
,
ignore_border
=
ignore_border
)
ds_op
=
DownsampleFactorMax
(
ds
,
ignore_border
=
ignore_border
)
...
@@ -180,12 +187,16 @@ def test_downsample():
...
@@ -180,12 +187,16 @@ def test_downsample():
f2
=
pfunc
([],
ds_op
(
tensor
.
as_tensor_variable
(
a
)),
mode
=
mode_without_gpu
)
f2
=
pfunc
([],
ds_op
(
tensor
.
as_tensor_variable
(
a
)),
mode
=
mode_without_gpu
)
assert
any
([
isinstance
(
node
.
op
,
tcn
.
blas
.
GpuDownsampleFactorMax
)
for
node
in
assert
any
([
isinstance
(
node
.
op
,
tcn
.
blas
.
GpuDownsampleFactorMax
)
for
node
in
f
.
maker
.
env
.
toposort
()])
f
.
maker
.
env
.
toposort
()])
assert
any
([
isinstance
(
node
.
op
,
DownsampleFactorMax
)
for
node
in
f2
.
maker
.
env
.
toposort
()])
assert
numpy
.
allclose
(
f
(),
f2
())
assert
numpy
.
allclose
(
f
(),
f2
())
g
=
pfunc
([],
tensor
.
grad
(
ds_op
(
tensor
.
as_tensor_variable
(
a
))
.
sum
(),
a
),
mode
=
mode_with_gpu
)
g
=
pfunc
([],
tensor
.
grad
(
ds_op
(
tensor
.
as_tensor_variable
(
a
))
.
sum
(),
a
),
mode
=
mode_with_gpu
)
g2
=
pfunc
([],
tensor
.
grad
(
ds_op
(
tensor
.
as_tensor_variable
(
a
))
.
sum
(),
a
),
mode
=
mode_without_gpu
)
g2
=
pfunc
([],
tensor
.
grad
(
ds_op
(
tensor
.
as_tensor_variable
(
a
))
.
sum
(),
a
),
mode
=
mode_without_gpu
)
assert
any
([
isinstance
(
node
.
op
,
tcn
.
blas
.
GpuDownsampleFactorMaxGrad
)
assert
any
([
isinstance
(
node
.
op
,
tcn
.
blas
.
GpuDownsampleFactorMaxGrad
)
for
node
in
g
.
maker
.
env
.
toposort
()])
for
node
in
g
.
maker
.
env
.
toposort
()])
assert
any
([
isinstance
(
node
.
op
,
DownsampleFactorMaxGrad
)
for
node
in
g2
.
maker
.
env
.
toposort
()])
assert
numpy
.
allclose
(
g
(),
g2
())
assert
numpy
.
allclose
(
g
(),
g2
())
#We already check that the gpu version return the same value as the gpu version
#We already check that the gpu version return the same value as the gpu version
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论