Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
6807ba72
提交
6807ba72
authored
11月 03, 2016
作者:
abergeron
提交者:
GitHub
11月 03, 2016
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #5159 from aam-at/gpuarray_pool_grad_grad
Gpuarray pool grad grad
上级
2e06c87e
a888141a
隐藏空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
398 行增加
和
4 行删除
+398
-4
blas.py
theano/gpuarray/blas.py
+46
-0
opt.py
theano/gpuarray/opt.py
+30
-2
pool_grad_grad.c
theano/gpuarray/pool_grad_grad.c
+189
-0
test_blas.py
theano/gpuarray/tests/test_blas.py
+133
-2
没有找到文件。
theano/gpuarray/blas.py
浏览文件 @
6807ba72
...
...
@@ -1537,6 +1537,52 @@ class GpuCorr3dMM_gradInputs(BaseGpuCorr3dMM):
return
[[
1
],
[
1
],
[
0
],
[
0
],
[
0
]]
# no connection to height, width, depth
class
GpuDownsampleFactorMaxGradGrad
(
CGpuKernelBase
):
"""
Implement the grad of downsample with max on the gpu.
"""
__props__
=
(
'ignore_border'
,
'mode'
,
'ndim'
)
def
__init__
(
self
,
ignore_border
,
mode
=
'max'
,
ndim
=
2
):
self
.
ndim
=
ndim
self
.
ignore_border
=
ignore_border
self
.
mode
=
mode
CGpuKernelBase
.
__init__
(
self
,
[
'pool_grad_grad.c'
],
'APPLY_SPECIFIC(pool_grad_grad)'
)
assert
self
.
mode
==
'max'
assert
self
.
ndim
in
[
2
,
3
]
def
c_headers
(
self
):
return
[
'gpuarray_api.h'
,
'gpuarray_helper.h'
,
'numpy_compat.h'
]
def
c_header_dirs
(
self
):
return
[
os
.
path
.
dirname
(
__file__
),
pygpu
.
get_include
()]
def
make_node
(
self
,
inp
,
out
,
out_grad
,
ws
,
stride
,
pad
):
ctx_name
=
infer_context_name
(
inp
,
out
,
out_grad
)
inp
=
as_gpuarray_variable
(
inp
,
ctx_name
)
assert
(
inp
.
ndim
in
[
4
,
5
])
out
=
as_gpuarray_variable
(
out
,
ctx_name
)
assert
(
out_grad
.
ndim
in
[
4
,
5
])
out_grad
=
as_gpuarray_variable
(
out_grad
,
ctx_name
)
assert
(
out
.
ndim
in
[
4
,
5
])
assert
(
out_grad
.
ndim
==
inp
.
ndim
)
assert
(
inp
.
ndim
==
out
.
ndim
)
ws
=
as_tensor_variable
(
ws
)
stride
=
as_tensor_variable
(
stride
)
pad
=
as_tensor_variable
(
pad
)
assert
ws
.
type
.
ndim
==
stride
.
type
.
ndim
and
ws
.
type
.
ndim
==
pad
.
type
.
ndim
assert
ws
.
type
.
ndim
==
1
return
Apply
(
self
,
[
inp
,
out
,
out_grad
,
ws
,
stride
,
pad
],
[
inp
.
type
()])
def
get_params
(
self
,
node
):
return
node
.
inputs
[
0
]
.
type
.
context
@inplace_allocempty
(
GpuGemv
,
0
)
def
local_inplace_gpuagemv
(
node
,
inputs
):
return
[
gpugemv_inplace
(
*
inputs
)]
...
...
theano/gpuarray/opt.py
浏览文件 @
6807ba72
...
...
@@ -29,6 +29,7 @@ from theano.tensor.nnet.abstract_conv import (BaseAbstractConv,
AbstractConv3d
,
AbstractConv3d_gradWeights
,
AbstractConv3d_gradInputs
)
import
theano.tensor.signal.pool
as
pool
from
theano.tests.breakpoint
import
PdbBreakpoint
...
...
@@ -46,7 +47,8 @@ from .blas import (gpu_dot22, GpuGemm, GpuGer, GpuGemmBatch,
gpugemmbatch_no_inplace
,
gpugemv_no_inplace
,
gpugemv_inplace
,
GpuCorrMM
,
GpuCorrMM_gradInputs
,
GpuCorrMM_gradWeights
,
GpuCorr3dMM
,
GpuCorr3dMM_gradInputs
,
GpuCorr3dMM_gradWeights
)
GpuCorr3dMM
,
GpuCorr3dMM_gradInputs
,
GpuCorr3dMM_gradWeights
,
GpuDownsampleFactorMaxGradGrad
)
from
.blocksparse
import
(
GpuSparseBlockGemv
,
GpuSparseBlockOuter
,
gpu_sparse_block_outer
,
gpu_sparse_block_outer_inplace
,
...
...
@@ -62,7 +64,7 @@ from .subtensor import (GpuIncSubtensor, GpuSubtensor,
GpuAdvancedSubtensor1
,
GpuAdvancedIncSubtensor1
,
GpuAdvancedIncSubtensor1_dev20
)
from
.opt_util
import
alpha_merge
,
output_merge
from
.opt_util
import
alpha_merge
,
output_merge
,
pad_dims
,
unpad_dims
_logger
=
logging
.
getLogger
(
"theano.gpuarray.opt"
)
...
...
@@ -1589,6 +1591,32 @@ def local_gpua_lift_abstractconv_graph(op, context_name, inputs, outputs):
return
[
op
(
*
inps
)]
@register_opt
()
@op_lifter
([
pool
.
DownsampleFactorMaxGradGrad
])
@register_opt2
([
pool
.
DownsampleFactorMaxGradGrad
])
def
local_gpu_downsample_factor_max_grad_grad
(
op
,
ctx_name
,
inputs
,
outputs
):
assert
op
.
__props__
==
(
'ignore_border'
,
'mode'
,
'ndim'
)
inp
,
out
,
out_grad
,
ws
,
stride
,
pad
=
inputs
nd
=
op
.
ndim
if
nd
not
in
(
2
,
3
):
return
inp
=
gpu_contiguous
(
as_gpuarray_variable
(
inp
,
ctx_name
))
out
=
gpu_contiguous
(
as_gpuarray_variable
(
out
,
ctx_name
))
out_grad
=
gpu_contiguous
(
as_gpuarray_variable
(
out_grad
,
ctx_name
))
op
=
GpuDownsampleFactorMaxGradGrad
(
op
.
ignore_border
,
op
.
mode
,
op
.
ndim
)
if
inp
.
ndim
==
nd
+
2
:
return
op
(
inp
,
out
,
out_grad
,
ws
,
stride
,
pad
)
else
:
# reshape to 4D or 5D with 2 non-pooling dimensions
inp_padded
=
pad_dims
(
inp
,
2
,
nd
)
out_padded
=
pad_dims
(
out
,
2
,
nd
)
out_grad_padded
=
pad_dims
(
out_grad
,
2
,
nd
)
ret_padded
=
op
(
inp_padded
,
out_padded
,
out_grad_padded
,
ws
,
stride
,
pad
)
return
unpad_dims
(
ret_padded
,
inp
,
2
,
nd
)
@register_opt
(
"low_memory"
)
@local_optimizer
([
GpuCAReduceCuda
])
def
local_gpu_elemwise_careduce
(
node
):
...
...
theano/gpuarray/pool_grad_grad.c
0 → 100644
浏览文件 @
6807ba72
#section kernels
#kernel max_pool2d_grad_grad_kernel : size, size, size, size, size, size, size, *, *, *, size, size, size, size, size, size, * :
KERNEL
void
max_pool2d_grad_grad_kernel
(
const
ga_size
nthreads
,
const
ga_size
num
,
const
ga_size
channels
,
const
ga_size
pooled_height
,
const
ga_size
pooled_width
,
const
ga_size
height
,
const
ga_size
width
,
GLOBAL_MEM
const
DTYPE_i0
*
x
,
GLOBAL_MEM
const
DTYPE_i1
*
z
,
GLOBAL_MEM
const
DTYPE_i2
*
gx
,
const
ga_size
kernel_h
,
const
ga_size
kernel_w
,
const
ga_size
stride_h
,
const
ga_size
stride_w
,
const
ga_size
pad_h
,
const
ga_size
pad_w
,
GLOBAL_MEM
DTYPE_o0
*
gz
)
{
// grid stride looping
for
(
ga_size
index
=
GID_0
*
LDIM_0
+
LID_0
;
index
<
nthreads
;
index
+=
LDIM_0
*
GDIM_0
)
{
const
ga_size
pw
=
index
%
pooled_width
;
const
ga_size
ph
=
(
index
/
pooled_width
)
%
pooled_height
;
const
ga_size
c
=
(
index
/
pooled_width
/
pooled_height
)
%
channels
;
const
ga_size
n
=
(
index
/
pooled_width
/
pooled_height
/
channels
);
ga_int
hstart
=
static_cast
<
ga_int
>
(
ph
*
stride_h
)
-
static_cast
<
ga_int
>
(
pad_h
);
hstart
=
max
(
hstart
,
0
);
const
ga_size
hend
=
min
(
hstart
+
kernel_h
,
height
);
ga_int
wstart
=
static_cast
<
ga_int
>
(
pw
*
stride_w
)
-
static_cast
<
ga_int
>
(
pad_w
);
wstart
=
max
(
wstart
,
0
);
const
ga_size
wend
=
min
(
wstart
+
kernel_w
,
width
);
const
ga_size
offset
=
(
n
*
channels
+
c
)
*
height
*
width
;
const
DTYPE_i0
*
x_slice
=
x
+
offset
;
const
DTYPE_i2
*
gx_slice
=
gx
+
offset
;
DTYPE_o0
gradient
=
0
;
for
(
ga_size
h
=
hstart
;
h
<
hend
;
++
h
)
{
for
(
ga_size
w
=
wstart
;
w
<
wend
;
++
w
)
{
// maximum in the region
if
(
z
[
index
]
==
x_slice
[
h
*
width
+
w
])
{
gradient
+=
gx_slice
[
h
*
width
+
w
];
}
}
}
gz
[
index
]
=
gradient
;
}
}
#kernel max_pool3d_grad_grad_kernel : size, size, size, size, size, size, size, size, size, *, *, *, size, size, size, size, size, size, size, size, size, * :
KERNEL
void
max_pool3d_grad_grad_kernel
(
const
ga_size
nthreads
,
const
ga_size
num
,
const
ga_size
channels
,
const
ga_size
pooled_depth
,
const
ga_size
pooled_height
,
const
ga_size
pooled_width
,
const
ga_size
depth
,
const
ga_size
height
,
const
ga_size
width
,
GLOBAL_MEM
const
DTYPE_i0
*
x
,
GLOBAL_MEM
const
DTYPE_i1
*
z
,
GLOBAL_MEM
const
DTYPE_i2
*
gx
,
const
ga_size
kernel_d
,
const
ga_size
kernel_h
,
const
ga_size
kernel_w
,
const
ga_size
stride_d
,
const
ga_size
stride_h
,
const
ga_size
stride_w
,
const
ga_size
pad_d
,
const
ga_size
pad_h
,
const
ga_size
pad_w
,
GLOBAL_MEM
DTYPE_o0
*
gz
)
{
// grid stride looping
for
(
ga_size
index
=
GID_0
*
LDIM_0
+
LID_0
;
index
<
nthreads
;
index
+=
LDIM_0
*
GDIM_0
)
{
const
ga_size
pw
=
index
%
pooled_width
;
const
ga_size
ph
=
(
index
/
pooled_width
)
%
pooled_height
;
const
ga_size
pd
=
(
index
/
pooled_width
/
pooled_height
)
%
pooled_depth
;
const
ga_size
c
=
(
index
/
pooled_width
/
pooled_height
/
pooled_depth
)
%
channels
;
const
ga_size
n
=
(
index
/
pooled_width
/
pooled_height
/
pooled_depth
/
channels
);
ga_int
dstart
=
static_cast
<
ga_int
>
(
pd
*
stride_d
)
-
static_cast
<
ga_int
>
(
pad_d
);
dstart
=
max
(
dstart
,
0
);
const
ga_size
dend
=
min
(
dstart
+
kernel_d
,
depth
);
ga_int
hstart
=
static_cast
<
ga_int
>
(
ph
*
stride_h
)
-
static_cast
<
ga_int
>
(
pad_h
);
hstart
=
max
(
hstart
,
0
);
const
ga_size
hend
=
min
(
hstart
+
kernel_h
,
height
);
ga_int
wstart
=
static_cast
<
ga_int
>
(
pw
*
stride_w
)
-
static_cast
<
ga_int
>
(
pad_w
);
wstart
=
max
(
wstart
,
0
);
const
ga_size
wend
=
min
(
wstart
+
kernel_w
,
width
);
const
ga_size
offset
=
(
n
*
channels
+
c
)
*
depth
*
height
*
width
;
const
DTYPE_i0
*
x_slice
=
x
+
offset
;
const
DTYPE_i2
*
gx_slice
=
gx
+
offset
;
DTYPE_o0
gradient
=
0
;
for
(
ga_size
d
=
dstart
;
d
<
dend
;
++
d
)
{
for
(
ga_size
h
=
hstart
;
h
<
hend
;
++
h
)
{
for
(
ga_size
w
=
wstart
;
w
<
wend
;
++
w
)
{
// maximum in the region
if
(
z
[
index
]
==
x_slice
[(
d
*
height
+
h
)
*
width
+
w
])
{
gradient
+=
gx_slice
[(
d
*
height
+
h
)
*
width
+
w
];
}
}
}
}
gz
[
index
]
=
gradient
;
}
}
#section support_code_struct
int
APPLY_SPECIFIC
(
pool_grad_grad
)(
PyGpuArrayObject
*
x
,
PyGpuArrayObject
*
z
,
PyGpuArrayObject
*
gx
,
PyArrayObject
*
ws
,
PyArrayObject
*
stride
,
PyArrayObject
*
pad
,
PyGpuArrayObject
**
gz
,
PyGpuContextObject
*
ctx
)
{
if
(
!
GpuArray_IS_C_CONTIGUOUS
(
&
x
->
ga
)
||
!
GpuArray_IS_C_CONTIGUOUS
(
&
z
->
ga
)
||
!
GpuArray_IS_C_CONTIGUOUS
(
&
gx
->
ga
))
{
PyErr_Format
(
PyExc_ValueError
,
"GpuPoolingGradGrad: requires data to be C-contiguous"
);
return
1
;
}
size_t
ndims
=
PyArray_DIM
(
ws
,
0
);
if
(
PyGpuArray_NDIM
(
x
)
!=
ndims
+
2
||
PyGpuArray_NDIM
(
z
)
!=
ndims
+
2
||
PyGpuArray_NDIM
(
gx
)
!=
ndims
+
2
)
{
PyErr_SetString
(
PyExc_ValueError
,
"GpuPoolingGradGrad: rank error"
);
return
1
;
}
if
(
theano_prep_output
(
gz
,
PyGpuArray_NDIM
(
z
),
PyGpuArray_DIMS
(
z
),
z
->
ga
.
typecode
,
GA_C_ORDER
,
ctx
)
!=
0
)
{
PyErr_SetString
(
PyExc_RuntimeError
,
"GpuPoolingGradGrad: failed to allocate memory"
);
return
1
;
}
{
// scope for running kernel
size_t
w
[
3
];
size_t
s
[
3
];
size_t
p
[
3
];
for
(
int
i
=
0
;
i
<
ndims
;
i
++
)
{
w
[
i
]
=
*
((
npy_intp
*
)
PyArray_GETPTR1
(
ws
,
i
));
s
[
i
]
=
*
((
npy_intp
*
)
PyArray_GETPTR1
(
stride
,
i
));
p
[
i
]
=
*
((
npy_intp
*
)
PyArray_GETPTR1
(
pad
,
i
));
}
size_t
max_threads_dim
;
int
err
;
const
size_t
*
z_dims
=
PyGpuArray_DIMS
(
z
);
const
size_t
*
x_dims
=
PyGpuArray_DIMS
(
x
);
// Get the max threads per blocks
err
=
gpucontext_property
(
ctx
->
ctx
,
GA_CTX_PROP_MAXLSIZE0
,
&
max_threads_dim
);
if
(
err
!=
GA_NO_ERROR
){
PyErr_SetString
(
PyExc_RuntimeError
,
"Could not fetch max_threads_dims"
);
return
1
;
}
size_t
threads_per_block
=
max_threads_dim
;
if
(
ndims
==
2
)
{
size_t
num_kernels
=
z_dims
[
0
]
*
z_dims
[
1
]
*
z_dims
[
2
]
*
z_dims
[
3
];
size_t
n_blocks
=
(
num_kernels
+
threads_per_block
-
1
)
/
threads_per_block
;
err
=
max_pool2d_grad_grad_kernel_call
(
1
,
&
n_blocks
,
&
threads_per_block
,
0
,
num_kernels
,
z_dims
[
0
],
z_dims
[
1
],
z_dims
[
2
],
z_dims
[
3
],
x_dims
[
2
],
x_dims
[
3
],
x
->
ga
.
data
,
z
->
ga
.
data
,
gx
->
ga
.
data
,
w
[
0
],
w
[
1
],
s
[
0
],
s
[
1
],
p
[
0
],
p
[
1
],
(
*
gz
)
->
ga
.
data
);
if
(
err
!=
GA_NO_ERROR
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"GpuPoolingGradGrad: max_pool2d_grad_grad_kernel %s."
,
GpuKernel_error
(
&
k_max_pool2d_grad_grad_kernel
,
err
));
return
1
;
}
}
else
if
(
ndims
==
3
)
{
size_t
num_kernels
=
z_dims
[
0
]
*
z_dims
[
1
]
*
z_dims
[
2
]
*
z_dims
[
3
]
*
z_dims
[
4
];
size_t
n_blocks
=
(
num_kernels
+
threads_per_block
-
1
)
/
threads_per_block
;
err
=
max_pool3d_grad_grad_kernel_call
(
1
,
&
n_blocks
,
&
threads_per_block
,
0
,
num_kernels
,
z_dims
[
0
],
z_dims
[
1
],
z_dims
[
2
],
z_dims
[
3
],
z_dims
[
4
],
x_dims
[
2
],
x_dims
[
3
],
x_dims
[
4
],
x
->
ga
.
data
,
z
->
ga
.
data
,
gx
->
ga
.
data
,
w
[
0
],
w
[
1
],
w
[
2
],
s
[
0
],
s
[
1
],
s
[
2
],
p
[
0
],
p
[
1
],
p
[
2
],
(
*
gz
)
->
ga
.
data
);
if
(
err
!=
GA_NO_ERROR
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"GpuPoolingGradGrad: max_pool3d_grad_grad_kernel %s."
,
GpuKernel_error
(
&
k_max_pool3d_grad_grad_kernel
,
err
));
return
1
;
}
}
}
return
0
;
}
theano/gpuarray/tests/test_blas.py
浏览文件 @
6807ba72
...
...
@@ -2,21 +2,26 @@ from __future__ import absolute_import, print_function, division
from
unittest
import
TestCase
from
nose.plugins.skip
import
SkipTest
import
itertools
import
copy
import
numpy
import
theano
from
theano
import
gradient
from
theano
import
tensor
from
theano.tests
import
unittest_tools
as
utt
from
theano.tensor.blas
import
gemv_inplace
,
gemm_inplace
,
_dot22
,
batched_dot
from
theano.tensor.tests.test_blas
import
TestGer
,
BaseGemv
from
theano.tensor.signal.pool
import
Pool
,
DownsampleFactorMaxGradGrad
from
..
import
gpuarray_shared_constructor
from
.config
import
mode_with_gpu
from
.config
import
mode_with_gpu
,
mode_without_gpu
from
.test_basic_ops
import
makeTester
,
rand
from
..blas
import
(
gpugemv_inplace
,
gpugemv_no_inplace
,
gpugemm_inplace
,
gpugemmbatch_no_inplace
,
gpuger_inplace
,
gpuger_no_inplace
,
GpuGer
,
gpu_dot22
)
GpuGer
,
gpu_dot22
,
GpuDownsampleFactorMaxGradGrad
)
GpuGemvTester
=
makeTester
(
...
...
@@ -128,3 +133,129 @@ GpuDot22Tester = makeTester(
# test9=[rand(0, 0), rand(0, 0)],
)
)
def
test_max_pool2d_grad_grad
():
shps
=
[(
1
,
12
),
(
1
,
1
,
12
),
(
1
,
1
,
1
,
12
),
(
1
,
1
,
2
,
2
),
(
1
,
1
,
1
,
1
),
(
1
,
1
,
4
,
4
),
(
1
,
1
,
10
,
11
),
(
1
,
2
,
2
,
2
),
(
3
,
5
,
4
,
4
),
(
25
,
1
,
7
,
7
),
(
1
,
1
,
12
,
12
),
(
1
,
1
,
2
,
14
),
(
1
,
1
,
12
,
14
),
(
1
,
1
,
14
,
14
),
(
1
,
1
,
16
,
16
),
(
1
,
1
,
18
,
18
),
(
1
,
1
,
24
,
24
),
(
1
,
6
,
24
,
24
),
(
10
,
1
,
24
,
24
),
(
10
,
6
,
24
,
24
),
(
30
,
6
,
12
,
12
),
(
30
,
2
,
24
,
24
),
(
30
,
6
,
24
,
24
),
(
10
,
10
,
10
,
11
),
(
1
,
1
,
10
,
1025
),
(
1
,
1
,
10
,
1023
),
(
1
,
1
,
1025
,
10
),
(
1
,
1
,
1023
,
10
),
]
numpy
.
random
.
RandomState
(
utt
.
fetch_seed
())
.
shuffle
(
shps
)
test_ds
=
(
2
,
2
),
(
3
,
2
),
(
1
,
1
)
test_st
=
(
2
,
2
),
(
3
,
2
),
(
1
,
1
)
for
shp
in
shps
:
for
ds
,
st
in
itertools
.
product
(
test_ds
,
test_st
):
if
ds
[
0
]
>
shp
[
-
2
]
or
ds
[
1
]
>
shp
[
-
1
]:
continue
for
ignore_border
,
pad
in
zip
((
True
,
False
),
[(
1
,
1
),
(
0
,
0
)]):
if
pad
[
0
]
>=
ds
[
0
]
or
pad
[
1
]
>=
ds
[
1
]:
continue
# print('test_downsample', shp, ds, st, pad, ignore_border)
ds_op
=
Pool
(
ndim
=
len
(
ds
),
ignore_border
=
ignore_border
)
a
=
theano
.
shared
(
rand
(
*
shp
),
'a'
)
ggf
=
gradient
.
Lop
(
tensor
.
grad
((
ds_op
(
tensor
.
as_tensor_variable
(
a
),
ds
,
st
,
pad
)
**
2
)
.
sum
(),
a
),
a
,
a
)
ref_mode
=
copy
.
copy
(
mode_without_gpu
)
ref_mode
.
check_py_code
=
False
gpu_mode
=
copy
.
copy
(
mode_with_gpu
)
gpu_mode
.
check_py_code
=
False
gg
=
theano
.
function
([],
ggf
,
mode
=
gpu_mode
)
gg2
=
theano
.
function
([],
ggf
,
mode
=
ref_mode
)
assert
any
([
isinstance
(
node
.
op
,
GpuDownsampleFactorMaxGradGrad
)
for
node
in
gg
.
maker
.
fgraph
.
toposort
()
])
assert
any
([
isinstance
(
node
.
op
,
DownsampleFactorMaxGradGrad
)
for
node
in
gg2
.
maker
.
fgraph
.
toposort
()
])
assert
numpy
.
allclose
(
gg
(),
gg2
()),
(
shp
,
ds
,
st
,
ignore_border
)
def
test_max_pool3d_grad_grad
():
shps
=
[(
1
,
1
,
12
),
(
1
,
1
,
1
,
1
,
1
),
(
1
,
1
,
1
,
1
,
1025
),
(
1
,
1
,
2
,
2
,
2
),
(
1
,
1
,
7
,
7
,
7
),
(
1
,
1
,
9
,
10
,
11
),
(
1
,
6
,
18
,
18
,
18
),
(
1
,
1
,
6
,
24
,
24
),
(
1
,
10
,
1
,
24
,
24
),
(
1
,
10
,
6
,
24
,
24
),
(
1
,
30
,
6
,
12
,
12
),
(
1
,
30
,
2
,
24
,
24
),
(
1
,
30
,
6
,
24
,
24
),
(
1
,
10
,
10
,
10
,
11
),
(
1
,
1
,
10
,
10
,
1025
),
(
1
,
1
,
10
,
10
,
1023
),
(
1
,
1
,
10
,
1025
,
10
),
(
1
,
1
,
10
,
1023
,
10
),
]
numpy
.
random
.
RandomState
(
utt
.
fetch_seed
())
.
shuffle
(
shps
)
test_ds
=
(
2
,
2
,
2
),
(
3
,
2
,
3
),
(
1
,
1
,
1
)
test_st
=
(
2
,
2
,
2
),
(
2
,
3
,
2
),
(
1
,
1
,
1
)
for
shp
in
shps
:
for
ds
,
st
in
itertools
.
product
(
test_ds
,
test_st
):
if
ds
[
0
]
>
shp
[
-
3
]
or
ds
[
1
]
>
shp
[
-
2
]
or
ds
[
2
]
>
shp
[
-
1
]:
continue
for
ignore_border
,
pad
in
zip
((
True
,
False
),
[(
1
,
1
,
1
),
(
0
,
0
,
0
)]):
if
pad
[
0
]
>=
ds
[
0
]
or
pad
[
1
]
>=
ds
[
1
]
or
pad
[
2
]
>=
ds
[
2
]:
continue
# print('test_downsample', shp, ds, st, pad, ignore_border)
ds_op
=
Pool
(
ndim
=
len
(
ds
),
ignore_border
=
ignore_border
)
a
=
theano
.
shared
(
rand
(
*
shp
),
'a'
)
ggf
=
gradient
.
Lop
(
tensor
.
grad
((
ds_op
(
tensor
.
as_tensor_variable
(
a
),
ds
,
st
,
pad
)
**
2
)
.
sum
(),
a
),
a
,
a
)
ref_mode
=
copy
.
copy
(
mode_without_gpu
)
ref_mode
.
check_py_code
=
False
gpu_mode
=
copy
.
copy
(
mode_with_gpu
)
gpu_mode
.
check_py_code
=
False
gg
=
theano
.
function
([],
ggf
,
mode
=
gpu_mode
)
gg2
=
theano
.
function
([],
ggf
,
mode
=
ref_mode
)
assert
any
([
isinstance
(
node
.
op
,
GpuDownsampleFactorMaxGradGrad
)
for
node
in
gg
.
maker
.
fgraph
.
toposort
()
])
assert
any
([
isinstance
(
node
.
op
,
DownsampleFactorMaxGradGrad
)
for
node
in
gg2
.
maker
.
fgraph
.
toposort
()
])
assert
numpy
.
allclose
(
gg
(),
gg2
()),
(
shp
,
ds
,
st
,
ignore_border
)
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论