Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
7b5919e9
提交
7b5919e9
authored
1月 25, 2017
作者:
Frédéric Bastien
提交者:
GitHub
1月 25, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #5323 from aam-at/max_pool_rop
Pooling rop
上级
51ac3abd
2dabc825
隐藏空白字符变更
内嵌
并排
正在显示
6 个修改的文件
包含
818 行增加
和
3 行删除
+818
-3
opt.py
theano/gpuarray/opt.py
+24
-1
pool.py
theano/gpuarray/pool.py
+89
-0
pool_max_rop.c
theano/gpuarray/pool_max_rop.c
+197
-0
test_pool.py
theano/gpuarray/tests/test_pool.py
+32
-2
pool.py
theano/tensor/signal/pool.py
+433
-0
test_rop.py
theano/tests/test_rop.py
+43
-0
没有找到文件。
theano/gpuarray/opt.py
浏览文件 @
7b5919e9
...
...
@@ -51,7 +51,7 @@ from .blas import (gpu_dot22, GpuGemm, GpuGer, GpuGemmBatch,
gpugemv_no_inplace
,
gpugemv_inplace
,
GpuCorrMM
,
GpuCorrMM_gradInputs
,
GpuCorrMM_gradWeights
,
GpuCorr3dMM
,
GpuCorr3dMM_gradInputs
,
GpuCorr3dMM_gradWeights
)
from
.pool
import
(
GpuPool
,
GpuMaxPoolGrad
,
GpuAveragePoolGrad
,
from
.pool
import
(
GpuPool
,
GpuMaxPoolGrad
,
GpuAveragePoolGrad
,
GpuMaxPoolRop
,
GpuDownsampleFactorMaxGradGrad
)
from
.blocksparse
import
(
GpuSparseBlockGemv
,
GpuSparseBlockOuter
,
gpu_sparse_block_outer
,
...
...
@@ -1747,6 +1747,29 @@ def local_gpu_downsample_factor_max_grad_grad(op, ctx_name, inputs, outputs):
return
unpad_dims
(
ret_padded
,
inp
,
2
,
nd
)
@register_opt
()
@op_lifter
([
pool
.
MaxPoolRop
])
@register_opt2
([
pool
.
MaxPoolRop
])
def
local_gpu_max_pool_rop
(
op
,
ctx_name
,
inputs
,
outputs
):
assert
op
.
__props__
==
(
'ignore_border'
,
'mode'
,
'ndim'
)
inp
,
eval_inp
,
ws
,
stride
,
pad
=
inputs
nd
=
op
.
ndim
if
nd
not
in
(
2
,
3
):
return
inp
=
gpu_contiguous
(
as_gpuarray_variable
(
inp
,
ctx_name
))
eval_inp
=
gpu_contiguous
(
as_gpuarray_variable
(
eval_inp
,
ctx_name
))
op
=
GpuMaxPoolRop
(
op
.
ignore_border
,
op
.
mode
,
op
.
ndim
)
if
inp
.
ndim
==
nd
+
2
:
return
op
(
inp
,
eval_inp
,
ws
,
stride
,
pad
)
else
:
# reshape to 4D or 5D with 2 non-pooling dimensions
inp_padded
=
pad_dims
(
inp
,
2
,
nd
)
eval_inp_padded
=
pad_dims
(
eval_inp
,
2
,
nd
)
ret_padded
=
op
(
inp_padded
,
eval_inp_padded
,
ws
,
stride
,
pad
)
return
unpad_dims
(
ret_padded
,
inp
,
2
,
nd
)
@register_opt
(
"low_memory"
)
@local_optimizer
([
GpuCAReduceCuda
])
def
local_gpu_elemwise_careduce
(
node
):
...
...
theano/gpuarray/pool.py
浏览文件 @
7b5919e9
...
...
@@ -112,6 +112,26 @@ class GpuPool(CGpuKernelBase):
def
connection_pattern
(
self
,
node
):
return
[[
1
],
[
0
],
[
0
],
[
0
]]
def
R_op
(
self
,
inputs
,
eval_points
):
if
self
.
mode
!=
'max'
:
# Rop for average or sum is simply pooling evaluated at eval point
eval_inputs
=
[
eval_points
[
0
]]
+
inputs
[
1
:]
return
[
self
(
*
eval_inputs
)]
# R_op can receive None as eval_points.
# That mean there is no diferientiable path through that input
# If this imply that you cannot compute some outputs,
# return None for those.
if
eval_points
[
0
]
is
None
:
return
[
None
]
z
=
self
(
*
inputs
)
x
,
ws
,
stride
,
pad
=
inputs
return
[
GpuDownsampleFactorMaxGradGrad
(
self
.
ignore_border
,
self
.
mode
,
self
.
ndim
)(
x
,
z
,
eval_points
[
0
],
ws
,
stride
,
pad
)
]
class
GpuMaxPoolGrad
(
CGpuKernelBase
):
"""
...
...
@@ -334,3 +354,72 @@ class GpuDownsampleFactorMaxGradGrad(CGpuKernelBase):
def
connection_pattern
(
self
,
node
):
return
[[
1
],
[
1
],
[
1
],
[
0
],
[
0
],
[
0
]]
class
GpuMaxPoolRop
(
CGpuKernelBase
):
"""
Implements the R-operator for the downsample operation.
"""
__props__
=
(
'ignore_border'
,
'mode'
,
'ndim'
)
def
__init__
(
self
,
ignore_border
,
mode
=
'max'
,
ndim
=
2
):
self
.
ndim
=
ndim
self
.
ignore_border
=
ignore_border
self
.
mode
=
mode
CGpuKernelBase
.
__init__
(
self
,
[
'pool_max_rop.c'
],
'APPLY_SPECIFIC(max_pool_rop)'
)
assert
mode
==
'max'
assert
ndim
in
[
2
,
3
]
def
c_headers
(
self
):
return
[
'gpuarray_api.h'
,
'gpuarray_helper.h'
,
'numpy_compat.h'
]
def
c_header_dirs
(
self
):
return
[
os
.
path
.
dirname
(
__file__
),
pygpu
.
get_include
()]
def
make_node
(
self
,
inp
,
eval_point
,
ws
,
stride
=
None
,
pad
=
None
):
ctx_name
=
infer_context_name
(
inp
)
nd
=
self
.
ndim
inp
=
as_gpuarray_variable
(
inp
,
ctx_name
)
assert
(
inp
.
ndim
==
nd
+
2
)
eval_point
=
as_gpuarray_variable
(
eval_point
,
ctx_name
)
assert
(
eval_point
.
ndim
==
nd
+
2
)
if
stride
is
None
:
stride
=
ws
if
pad
is
None
:
pad
=
(
0
,)
*
nd
elif
isinstance
(
pad
,
(
tuple
,
list
)):
if
max
(
pad
)
!=
0
and
not
self
.
ignore_border
:
raise
ValueError
(
'Padding works only with ignore_border=True'
)
if
isinstance
(
ws
,
(
tuple
,
list
)):
if
any
(
pad
[
i
]
>=
ws
[
i
]
for
i
in
range
(
nd
)):
raise
ValueError
(
'Padding must be smaller than strides'
)
ws
=
as_tensor_variable
(
ws
)
stride
=
as_tensor_variable
(
stride
)
pad
=
as_tensor_variable
(
pad
)
assert
ws
.
ndim
==
stride
.
ndim
and
ws
.
ndim
==
pad
.
ndim
assert
ws
.
ndim
==
1
if
not
ws
.
dtype
.
startswith
(
'int'
):
raise
TypeError
(
'Window shape parameters must be ints.'
)
if
not
stride
.
dtype
.
startswith
(
'int'
):
raise
TypeError
(
'Stride parameters must be ints.'
)
if
not
pad
.
dtype
.
startswith
(
'int'
):
raise
TypeError
(
'Padding parameters must be ints.'
)
return
Apply
(
self
,
[
inp
,
eval_point
,
ws
,
stride
,
pad
],
[
eval_point
.
type
()])
def
get_params
(
self
,
node
):
return
node
.
inputs
[
0
]
.
type
.
context
def
get_op_params
(
self
):
ignore_border
=
int
(
self
.
ignore_border
)
return
[(
'IGNORE_BORDER'
,
ignore_border
)]
def
infer_shape
(
self
,
node
,
in_shapes
):
ws
,
stride
,
pad
=
[
node
.
inputs
[
2
],
node
.
inputs
[
3
],
node
.
inputs
[
4
]]
shp
=
Pool
.
out_shape
(
in_shapes
[
0
],
ws
,
self
.
ignore_border
,
stride
,
pad
,
self
.
ndim
)
return
[
shp
]
theano/gpuarray/pool_max_rop.c
0 → 100644
浏览文件 @
7b5919e9
#section kernels
#kernel max_pool2d_rop_kernel : size, size, size, size, size, size, size, *, *, size, size, size, size, size, size, * :
// (borrowed from Caffe: https://github.com/BVLC/caffe/blob/master/src/caffe/layers/pooling_layer.cu)
KERNEL
void
max_pool2d_rop_kernel
(
const
ga_size
nthreads
,
const
ga_size
num
,
const
ga_size
channels
,
const
ga_size
pooled_height
,
const
ga_size
pooled_width
,
const
ga_size
height
,
const
ga_size
width
,
GLOBAL_MEM
const
DTYPE_i0
*
x
,
GLOBAL_MEM
const
DTYPE_i1
*
ex
,
const
ga_size
kernel_h
,
const
ga_size
kernel_w
,
const
ga_size
stride_h
,
const
ga_size
stride_w
,
const
ga_size
pad_h
,
const
ga_size
pad_w
,
GLOBAL_MEM
DTYPE_o0
*
z
)
{
// grid stride looping
for
(
ga_size
index
=
GID_0
*
LDIM_0
+
LID_0
;
index
<
nthreads
;
index
+=
LDIM_0
*
GDIM_0
)
{
const
ga_size
pw
=
index
%
pooled_width
;
const
ga_size
ph
=
(
index
/
pooled_width
)
%
pooled_height
;
const
ga_size
c
=
(
index
/
pooled_width
/
pooled_height
)
%
channels
;
const
ga_size
n
=
(
index
/
pooled_width
/
pooled_height
/
channels
);
ga_int
hstart
=
static_cast
<
ga_int
>
(
ph
*
stride_h
)
-
static_cast
<
ga_int
>
(
pad_h
);
const
ga_size
hend
=
min
(
hstart
+
kernel_h
,
height
);
ga_int
wstart
=
static_cast
<
ga_int
>
(
pw
*
stride_w
)
-
static_cast
<
ga_int
>
(
pad_w
);
const
ga_size
wend
=
min
(
wstart
+
kernel_w
,
width
);
hstart
=
max
(
hstart
,
0
);
wstart
=
max
(
wstart
,
0
);
const
ga_size
offset
=
(
n
*
channels
+
c
)
*
height
*
width
;
const
DTYPE_i0
*
x_slice
=
x
+
offset
;
const
DTYPE_i1
*
ex_slice
=
ex
+
offset
;
DTYPE_o0
maxval
=
x_slice
[
hstart
*
width
+
wstart
];
DTYPE_o0
collector
=
ex_slice
[
hstart
*
width
+
wstart
];
for
(
ga_size
h
=
hstart
;
h
<
hend
;
++
h
)
{
for
(
ga_size
w
=
wstart
;
w
<
wend
;
++
w
)
{
// maximum in the region
if
(
x_slice
[
h
*
width
+
w
]
>
maxval
)
{
maxval
=
x_slice
[
h
*
width
+
w
];
collector
=
ex_slice
[
h
*
width
+
w
];
}
}
}
z
[
index
]
=
collector
;
}
}
#kernel max_pool3d_rop_kernel : size, size, size, size, size, size, size, size, size, *, *, size, size, size, size, size, size, size, size, size, * :
// (adopted from Caffe: https://github.com/BVLC/caffe/blob/master/src/caffe/layers/pooling_layer.cu)
KERNEL
void
max_pool3d_rop_kernel
(
const
ga_size
nthreads
,
const
ga_size
num
,
const
ga_size
channels
,
const
ga_size
pooled_depth
,
const
ga_size
pooled_height
,
const
ga_size
pooled_width
,
const
ga_size
depth
,
const
ga_size
height
,
const
ga_size
width
,
GLOBAL_MEM
const
DTYPE_i0
*
x
,
GLOBAL_MEM
const
DTYPE_i1
*
ex
,
const
ga_size
kernel_d
,
const
ga_size
kernel_h
,
const
ga_size
kernel_w
,
const
ga_size
stride_d
,
const
ga_size
stride_h
,
const
ga_size
stride_w
,
const
ga_size
pad_d
,
const
ga_size
pad_h
,
const
ga_size
pad_w
,
GLOBAL_MEM
DTYPE_o0
*
z
)
{
// grid stride looping
for
(
ga_size
index
=
GID_0
*
LDIM_0
+
LID_0
;
index
<
nthreads
;
index
+=
LDIM_0
*
GDIM_0
)
{
const
ga_size
pw
=
index
%
pooled_width
;
const
ga_size
ph
=
(
index
/
pooled_width
)
%
pooled_height
;
const
ga_size
pd
=
(
index
/
pooled_width
/
pooled_height
)
%
pooled_depth
;
const
ga_size
c
=
(
index
/
pooled_width
/
pooled_height
/
pooled_depth
)
%
channels
;
const
ga_size
n
=
(
index
/
pooled_width
/
pooled_height
/
pooled_depth
/
channels
);
ga_int
dstart
=
static_cast
<
ga_int
>
(
pd
*
stride_d
)
-
static_cast
<
ga_int
>
(
pad_d
);
const
ga_size
dend
=
min
(
dstart
+
kernel_d
,
depth
);
ga_int
hstart
=
static_cast
<
ga_int
>
(
ph
*
stride_h
)
-
static_cast
<
ga_int
>
(
pad_h
);
const
ga_size
hend
=
min
(
hstart
+
kernel_h
,
height
);
ga_int
wstart
=
static_cast
<
ga_int
>
(
pw
*
stride_w
)
-
static_cast
<
ga_int
>
(
pad_w
);
const
ga_size
wend
=
min
(
wstart
+
kernel_w
,
width
);
dstart
=
max
(
dstart
,
0
);
hstart
=
max
(
hstart
,
0
);
wstart
=
max
(
wstart
,
0
);
const
ga_size
offset
=
(
n
*
channels
+
c
)
*
depth
*
height
*
width
;
const
DTYPE_i0
*
x_slice
=
x
+
offset
;
const
DTYPE_i1
*
ex_slice
=
ex
+
offset
;
DTYPE_o0
maxval
=
x_slice
[(
dstart
*
height
+
hstart
)
*
width
+
wstart
];
DTYPE_o0
collector
=
ex_slice
[(
dstart
*
height
+
hstart
)
*
width
+
wstart
];
for
(
ga_size
d
=
dstart
;
d
<
dend
;
++
d
)
{
for
(
ga_size
h
=
hstart
;
h
<
hend
;
++
h
)
{
for
(
ga_size
w
=
wstart
;
w
<
wend
;
++
w
)
{
// maximum in the region
if
(
x_slice
[(
d
*
height
+
h
)
*
width
+
w
]
>
maxval
)
{
maxval
=
x_slice
[(
d
*
height
+
h
)
*
width
+
w
];
collector
=
ex_slice
[(
d
*
height
+
h
)
*
width
+
w
];
}
}
}
}
z
[
index
]
=
collector
;
}
}
#section support_code
// output shape for a given input padded shape, window shape and stride
#define OUTPUT_DIMS(in_dim, ws, st) \
(IGNORE_BORDER ? (in_dim - ws)/st + 1 : \
(st > ws ? (in_dim - 1)/st + 1 : \
std::max<size_t>(0, (in_dim - 1 - ws + st)/st) + 1))
#section support_code_struct
int
APPLY_SPECIFIC
(
max_pool_rop
)(
PyGpuArrayObject
*
x
,
PyGpuArrayObject
*
ex
,
PyArrayObject
*
ws
,
PyArrayObject
*
stride
,
PyArrayObject
*
pad
,
PyGpuArrayObject
**
z
,
PyGpuContextObject
*
ctx
)
{
if
(
!
GpuArray_IS_C_CONTIGUOUS
(
&
x
->
ga
)
||
!
GpuArray_IS_C_CONTIGUOUS
(
&
ex
->
ga
))
{
PyErr_Format
(
PyExc_ValueError
,
"GpuMaxPoolRop: requires data to be C-contiguous"
);
return
1
;
}
size_t
ndims
=
PyArray_DIM
(
ws
,
0
);
if
(
PyGpuArray_NDIM
(
x
)
!=
ndims
+
2
||
PyGpuArray_NDIM
(
ex
)
!=
ndims
+
2
)
{
PyErr_SetString
(
PyExc_ValueError
,
"GpuMaxPoolRop: rank error"
);
return
1
;
}
// prepare output
const
size_t
*
x_dims
=
PyGpuArray_DIMS
(
x
);
size_t
z_dims
[
5
];
// avoid warning if use 2 + nd
size_t
w
[
3
];
size_t
s
[
3
];
size_t
p
[
3
];
z_dims
[
0
]
=
x_dims
[
0
];
z_dims
[
1
]
=
x_dims
[
1
];
int
nonzero_padding
=
0
;
for
(
int
i
=
0
;
i
<
ndims
;
i
++
)
{
w
[
i
]
=
*
((
npy_intp
*
)
PyArray_GETPTR1
(
ws
,
i
));
s
[
i
]
=
*
((
npy_intp
*
)
PyArray_GETPTR1
(
stride
,
i
));
p
[
i
]
=
*
((
npy_intp
*
)
PyArray_GETPTR1
(
pad
,
i
));
z_dims
[
2
+
i
]
=
OUTPUT_DIMS
(
x_dims
[
2
+
i
]
+
2
*
p
[
i
],
w
[
i
],
s
[
i
]);
if
(
p
[
i
]
>
0
)
{
nonzero_padding
=
1
;
}
}
if
(
!
IGNORE_BORDER
&&
nonzero_padding
)
{
PyErr_SetString
(
PyExc_ValueError
,
"GpuMaxPoolRop: padding works only with ignore_border=True"
);
return
1
;
}
if
(
theano_prep_output
(
z
,
PyGpuArray_NDIM
(
ex
),
z_dims
,
ex
->
ga
.
typecode
,
GA_C_ORDER
,
ctx
)
!=
0
)
{
PyErr_SetString
(
PyExc_RuntimeError
,
"GpuMaxPoolRop: failed to allocate memory"
);
return
1
;
}
{
// scope for running kernel
int
err
;
if
(
ndims
==
2
)
{
size_t
num_kernels
=
z_dims
[
0
]
*
z_dims
[
1
]
*
z_dims
[
2
]
*
z_dims
[
3
];
err
=
max_pool2d_rop_kernel_scall
(
1
,
&
num_kernels
,
0
,
num_kernels
,
z_dims
[
0
],
z_dims
[
1
],
z_dims
[
2
],
z_dims
[
3
],
x_dims
[
2
],
x_dims
[
3
],
x
->
ga
.
data
,
ex
->
ga
.
data
,
w
[
0
],
w
[
1
],
s
[
0
],
s
[
1
],
p
[
0
],
p
[
1
],
(
*
z
)
->
ga
.
data
);
if
(
err
!=
GA_NO_ERROR
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"GpuMaxPoolRop: max_pool2d_rop_kernel %s."
,
GpuKernel_error
(
&
k_max_pool2d_rop_kernel
,
err
));
return
1
;
}
}
else
if
(
ndims
==
3
)
{
size_t
num_kernels
=
z_dims
[
0
]
*
z_dims
[
1
]
*
z_dims
[
2
]
*
z_dims
[
3
]
*
z_dims
[
4
];
err
=
max_pool3d_rop_kernel_scall
(
1
,
&
num_kernels
,
0
,
num_kernels
,
z_dims
[
0
],
z_dims
[
1
],
z_dims
[
2
],
z_dims
[
3
],
z_dims
[
4
],
x_dims
[
2
],
x_dims
[
3
],
x_dims
[
4
],
x
->
ga
.
data
,
ex
->
ga
.
data
,
w
[
0
],
w
[
1
],
w
[
2
],
s
[
0
],
s
[
1
],
s
[
2
],
p
[
0
],
p
[
1
],
p
[
2
],
(
*
z
)
->
ga
.
data
);
if
(
err
!=
GA_NO_ERROR
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"GpuMaxPoolRop: max_pool3d_rop_kernel %s."
,
GpuKernel_error
(
&
k_max_pool2d_rop_kernel
,
err
));
return
1
;
}
}
}
return
0
;
}
theano/gpuarray/tests/test_pool.py
浏览文件 @
7b5919e9
...
...
@@ -133,11 +133,26 @@ def test_pool2d():
assert
numpy
.
allclose
(
g
(),
g2
()),
(
shp
,
ws
,
st
,
pad
,
mode
,
ignore_border
)
# test grad grad for max pooling
# test
rop and
grad grad for max pooling
# for average pooling grad grad is just average pooling grad
if
mode
!=
'max'
:
continue
ea
=
theano
.
shared
(
rand
(
*
shp
),
'ea'
)
gr
=
theano
.
function
([],
tensor
.
Rop
(
a_pooled
,
a
,
ea
),
mode
=
gpu_mode
)
gr2
=
theano
.
function
([],
tensor
.
Rop
(
a_pooled
,
a
,
ea
),
mode
=
ref_mode
)
assert
any
([
isinstance
(
node
.
op
,
GpuDownsampleFactorMaxGradGrad
)
for
node
in
gr
.
maker
.
fgraph
.
toposort
()
])
assert
any
([
isinstance
(
node
.
op
,
DownsampleFactorMaxGradGrad
)
for
node
in
gr2
.
maker
.
fgraph
.
toposort
()
])
assert
numpy
.
allclose
(
gr
(),
gr2
()),
(
shp
,
ws
,
st
,
pad
,
mode
,
ignore_border
)
ggf
=
gradient
.
Lop
(
tensor
.
grad
((
a_pooled
**
2
)
.
sum
(),
a
),
a
,
a
)
gg
=
theano
.
function
([],
ggf
,
mode
=
gpu_mode
)
...
...
@@ -228,11 +243,26 @@ def test_pool3d():
assert
numpy
.
allclose
(
g
(),
g2
()),
(
shp
,
ws
,
st
,
pad
,
mode
,
ignore_border
)
# test grad grad for max pooling
# test
rop and
grad grad for max pooling
# for average pooling grad grad is just average pooling grad
if
mode
!=
'max'
:
continue
ea
=
theano
.
shared
(
rand
(
*
shp
),
'ea'
)
gr
=
theano
.
function
([],
tensor
.
Rop
(
a_pooled
,
a
,
ea
),
mode
=
gpu_mode
)
gr2
=
theano
.
function
([],
tensor
.
Rop
(
a_pooled
,
a
,
ea
),
mode
=
ref_mode
)
assert
any
([
isinstance
(
node
.
op
,
GpuDownsampleFactorMaxGradGrad
)
for
node
in
gr
.
maker
.
fgraph
.
toposort
()
])
assert
any
([
isinstance
(
node
.
op
,
DownsampleFactorMaxGradGrad
)
for
node
in
gr2
.
maker
.
fgraph
.
toposort
()
])
assert
numpy
.
allclose
(
gr
(),
gr2
()),
(
shp
,
ws
,
st
,
pad
,
mode
,
ignore_border
)
ggf
=
gradient
.
Lop
(
tensor
.
grad
((
a_pooled
**
2
)
.
sum
(),
a
),
a
,
a
)
gg
=
theano
.
function
([],
ggf
,
mode
=
gpu_mode
)
...
...
theano/tensor/signal/pool.py
浏览文件 @
7b5919e9
...
...
@@ -580,6 +580,26 @@ class Pool(OpenMPOp):
def
connection_pattern
(
self
,
node
):
return
[[
1
],
[
0
],
[
0
],
[
0
]]
def
R_op
(
self
,
inputs
,
eval_points
):
if
self
.
mode
!=
'max'
:
# Rop for average or sum is simply pooling evaluated at eval point
eval_inputs
=
[
eval_points
[
0
]]
+
inputs
[
1
:]
return
[
self
(
*
eval_inputs
)]
# R_op can receive None as eval_points.
# That mean there is no diferientiable path through that input
# If this imply that you cannot compute some outputs,
# return None for those.
if
eval_points
[
0
]
is
None
:
return
[
None
]
z
=
self
(
*
inputs
)
x
,
ws
,
stride
,
pad
=
inputs
return
[
DownsampleFactorMaxGradGrad
(
self
.
ignore_border
,
self
.
mode
,
self
.
ndim
)(
x
,
z
,
eval_points
[
0
],
ws
,
stride
,
pad
)
]
def
c_headers
(
self
):
headers
=
[
'<algorithm>'
]
headers
+=
super
(
Pool
,
self
)
.
c_headers
()
...
...
@@ -2006,3 +2026,416 @@ class DownsampleFactorMaxGradGrad(OpenMPOp):
def
c_code_cache_version
(
self
):
return
(
0
,
4
,
self
.
openmp
)
class
MaxPoolRop
(
OpenMPOp
):
"""
Implements the R-operator for the downsample operation.
Parameters
----------
ws : list or tuple of N ints
Downsample factor over rows, columns etc.
ws indicates the size of the pooling region.
ignore_border : bool
If ws doesn't divide imgshape, do we include an extra row/col/slice
of partial downsampling (False) or ignore it (True).
stride : list or tuple of N ints or None
Stride size, which is the number of shifts over rows/cols/slices to get the
next pool region. If stride is None, it is considered equal to ws
(no overlap on pooling regions).
pad : tuple of N ints or None
For each downsampling dimension, this specifies the number of zeros to
add as padding on both sides. For 2D and (pad_h, pad_w), pad_h specifies the
size of the top and bottom margins, pad_w specifies the size of the left and
right margins. No padding is added if pad is None.
mode : {'max', 'sum', 'average_inc_pad', 'average_exc_pad'}
('average_inc_pad' excludes the padding from the count,
'average_exc_pad' include it)
ndim : int
The number of pooling dimensions N.
The default is 2.
"""
__props__
=
(
'ignore_border'
,
'mode'
,
'ndim'
)
def
__init__
(
self
,
ignore_border
=
False
,
mode
=
'max'
,
ndim
=
2
,
openmp
=
None
):
super
(
MaxPoolRop
,
self
)
.
__init__
(
openmp
=
openmp
)
self
.
ndim
=
ndim
self
.
ignore_border
=
ignore_border
self
.
mode
=
mode
assert
mode
==
'max'
def
make_node
(
self
,
x
,
eval_point
,
ws
,
stride
=
None
,
pad
=
None
):
# TODO: consider restricting the dtype?
x
=
tensor
.
as_tensor_variable
(
x
)
eval_point
=
tensor
.
as_tensor_variable
(
eval_point
)
nd
=
self
.
ndim
if
stride
is
None
:
stride
=
ws
if
pad
is
None
:
pad
=
(
0
,)
*
nd
elif
isinstance
(
pad
,
(
tuple
,
list
)):
if
max
(
pad
)
!=
0
and
not
self
.
ignore_border
:
raise
NotImplementedError
(
'padding works only with ignore_border=True'
)
if
isinstance
(
ws
,
(
tuple
,
list
)):
if
any
(
pad
[
i
]
>=
ws
[
i
]
for
i
in
range
(
nd
)):
raise
NotImplementedError
(
'padding must be smaller than strides'
)
ws
=
tensor
.
as_tensor_variable
(
ws
)
stride
=
tensor
.
as_tensor_variable
(
stride
)
pad
=
tensor
.
as_tensor_variable
(
pad
)
assert
ws
.
ndim
==
1
assert
stride
.
ndim
==
1
assert
pad
.
ndim
==
1
if
x
.
type
.
ndim
<
nd
:
raise
TypeError
()
if
not
ws
.
dtype
.
startswith
(
'int'
):
raise
TypeError
(
'Pool downsample parameters must be ints.'
)
if
not
stride
.
dtype
.
startswith
(
'int'
):
raise
TypeError
(
'Stride parameters must be ints.'
)
if
not
pad
.
dtype
.
startswith
(
'int'
):
raise
TypeError
(
'Padding parameters must be ints.'
)
# If the input shape are broadcastable we can have 0 in the output shape
broad
=
x
.
broadcastable
[:
-
nd
]
+
(
False
,)
*
nd
out
=
tensor
.
TensorType
(
eval_point
.
dtype
,
broad
)
return
gof
.
Apply
(
self
,
[
x
,
eval_point
,
ws
,
stride
,
pad
],
[
out
()])
def
perform
(
self
,
node
,
inp
,
out
):
x
,
ex
,
ws
,
stride
,
pad
=
inp
z
,
=
out
nd
=
self
.
ndim
assert
ws
.
shape
==
stride
.
shape
==
pad
.
shape
==
(
nd
,)
if
len
(
x
.
shape
)
<
nd
:
raise
NotImplementedError
(
'Pool requires input with {} or more dimensions'
.
format
(
nd
))
z_shape
=
Pool
.
out_shape
(
x
.
shape
,
ws
,
self
.
ignore_border
,
stride
,
pad
,
nd
)
if
not
self
.
ignore_border
:
assert
all
(
z
>
0
for
z
in
z_shape
[
-
nd
:])
if
(
z
[
0
]
is
None
)
or
(
z
[
0
]
.
shape
!=
z_shape
):
z
[
0
]
=
numpy
.
empty
(
z_shape
,
dtype
=
x
.
dtype
)
zz
=
z
[
0
]
# size of pooling output
pool_out_shp
=
zz
.
shape
[
-
nd
:]
img_shp
=
tuple
(
x
.
shape
[
-
nd
+
i
]
+
2
*
pad
[
i
]
for
i
in
xrange
(
nd
))
inc_pad
=
self
.
mode
==
'average_inc_pad'
# pad the image and the eval point
if
max
(
pad
)
!=
0
:
y
=
numpy
.
zeros
(
x
.
shape
[:
-
nd
]
+
img_shp
,
dtype
=
x
.
dtype
)
y
[(
slice
(
None
),)
*
(
len
(
x
.
shape
)
-
nd
)
+
tuple
(
slice
(
pad
[
i
],
img_shp
[
i
]
-
pad
[
i
])
for
i
in
xrange
(
nd
))]
=
x
ey
=
numpy
.
zeros
(
ex
.
shape
[:
-
nd
]
+
img_shp
,
dtype
=
ex
.
dtype
)
ey
[(
slice
(
None
),)
*
(
len
(
ex
.
shape
)
-
nd
)
+
tuple
(
slice
(
pad
[
i
],
img_shp
[
i
]
-
pad
[
i
])
for
i
in
xrange
(
nd
))]
=
ex
else
:
y
=
x
ey
=
ex
# precompute the region boundaries for each dimension
region_slices
=
[[]
for
i
in
xrange
(
nd
)]
for
i
in
xrange
(
nd
):
for
j
in
xrange
(
pool_out_shp
[
i
]):
start
=
j
*
stride
[
i
]
end
=
builtins
.
min
(
start
+
ws
[
i
],
img_shp
[
i
])
if
not
inc_pad
:
start
=
builtins
.
max
(
start
,
pad
[
i
])
end
=
builtins
.
min
(
end
,
img_shp
[
i
]
-
pad
[
i
])
region_slices
[
i
]
.
append
(
slice
(
start
,
end
))
# iterate over non-pooling dimensions
for
k
in
numpy
.
ndindex
(
*
x
.
shape
[:
-
nd
]):
zzk
=
zz
[
k
]
yk
=
y
[
k
]
eyk
=
ey
[
k
]
# iterate over pooling regions
for
r
in
numpy
.
ndindex
(
*
pool_out_shp
):
# current slice in padded input
ykslice
=
yk
[[
region_slices
[
i
][
r
[
i
]]
for
i
in
xrange
(
nd
)]]
# current slice in eval points
eykslice
=
eyk
[[
region_slices
[
i
][
r
[
i
]]
for
i
in
xrange
(
nd
)]]
# indices of maximum
idx
=
numpy
.
unravel_index
(
numpy
.
argmax
(
ykslice
),
ykslice
.
shape
)
zzk
[
r
]
=
eykslice
[
idx
]
def
c_headers
(
self
):
headers
=
[
'<algorithm>'
]
headers
+=
super
(
MaxPoolRop
,
self
)
.
c_headers
()
return
headers
def
c_code
(
self
,
node
,
name
,
inp
,
out
,
sub
):
if
self
.
mode
!=
'max'
:
raise
theano
.
gof
.
utils
.
MethodNotDefined
()
x
,
ex
,
ws
,
stride
,
pad
=
inp
z
,
=
out
nd
=
self
.
ndim
total_ndim
=
node
.
inputs
[
0
]
.
ndim
non_pool_ndim
=
total_ndim
-
nd
fail
=
sub
[
'fail'
]
ignore_border
=
int
(
self
.
ignore_border
)
if
self
.
openmp
:
# run in parallel over each pooling block
omp_parallel
=
'#pragma omp parallel for private(r_st, r_end, r_idx, i_idx, o_idx, collector, eval_collector) schedule(static)'
else
:
omp_parallel
=
''
ccode
=
"""
int typenum = PyArray_ObjectType((PyObject*)
%(x)
s, 0);
if(PyArray_NDIM(
%(x)
s)!=
%(total_ndim)
s)
{
PyErr_SetString(PyExc_ValueError, "x must be a
%(total_ndim)
sD ndarray");
%(fail)
s;
}
if(PyArray_NDIM(
%(ex)
s)!=
%(total_ndim)
s)
{
PyErr_SetString(PyExc_ValueError, "eval_point must be a
%(total_ndim)
sD ndarray");
%(fail)
s;
}
if(PyArray_DIM(
%(ws)
s, 0)!=
%(nd)
s)
{
PyErr_SetString(PyExc_ValueError, "ws must be a vector of size
%(nd)
s");
%(fail)
s;
}
if(PyArray_DIM(
%(stride)
s, 0)!=
%(nd)
s)
{
PyErr_SetString(PyExc_ValueError, "stride must be a vector of size
%(nd)
s");
%(fail)
s;
}
if(PyArray_DIM(
%(pad)
s, 0)!=
%(nd)
s)
{
PyErr_SetString(PyExc_ValueError, "pad must be a vector of size
%(nd)
s");
%(fail)
s;
}
int z[
%(nd)
s]; // shape of the output
int r[
%(nd)
s]; // shape of the padded_input
int ws[
%(nd)
s];
int st[
%(nd)
s];
int pd[
%(nd)
s];
int nonzero_padding;
nonzero_padding = 0;
for (int i=0; i<
%(nd)
s; i++)
{
ws[i] = *((npy_intp*)PyArray_GETPTR1(
%(ws)
s, i));
st[i] = *((npy_intp*)PyArray_GETPTR1(
%(stride)
s, i));
pd[i] = *((npy_intp*)PyArray_GETPTR1(
%(pad)
s, i));
r[i] = PyArray_DIMS(
%(x)
s)[
%(non_pool_ndim)
s + i] + 2 * pd[i];
if (pd[i]>0)
nonzero_padding = 1;
}
if (!
%(ignore_border)
s && nonzero_padding)
{
PyErr_SetString(PyExc_ValueError,
"padding must be zero when ignore border is False");
%(fail)
s;
}
if (
%(ignore_border)
s)
{
for (int i=0; i<
%(nd)
s; i++)
{
// '/' in C is different from '/' in python
if (r[i] - ws[i] < 0)
{
z[i] = 0;
}
else
{
z[i] = (r[i] - ws[i]) / st[i] + 1;
}
}
}
else
{
for (int i=0; i<
%(nd)
s; i++)
{
// decide how many rows/cols the output has
if (st[i] >= ws[i])
{
z[i] = (r[i] - 1) / st[i] + 1;
}
else
{
z[i] = std::max(0, (r[i] - 1 - ws[i] + st[i]) / st[i]) + 1;
}
assert(z[i] > 0);
}
}
// memory allocation of z if necessary
int mem_nec;
mem_nec = 0;
if ((!
%(z)
s) || *PyArray_DIMS(
%(z)
s)!=
%(total_ndim)
s)
{
mem_nec = 1;
}
if (!mem_nec)
{
for (int i=0; i<
%(non_pool_ndim)
s; i++)
{
if (PyArray_DIMS(
%(z)
s)[i] != PyArray_DIMS(
%(x)
s)[i])
{
mem_nec = 1;
break;
}
}
}
if (!mem_nec)
{
for (int i=0; i<
%(nd)
s; i++)
{
if (PyArray_DIMS(
%(z)
s)[
%(non_pool_ndim)
s + i] != z[i])
{
mem_nec = 1;
break;
}
}
}
if (mem_nec)
{
if (
%(z)
s) Py_XDECREF(
%(z)
s);
npy_intp dims[
%(total_ndim)
s];
for (int i=0; i<
%(non_pool_ndim)
s; i++)
{
dims[i] = PyArray_DIMS(
%(x)
s)[i];
}
for (int i=0; i<
%(nd)
s; i++)
{
dims[
%(non_pool_ndim)
s + i] = z[i];
}
//TODO: zeros not necessary
%(z)
s = (PyArrayObject*) PyArray_ZEROS(
%(total_ndim)
s, dims, typenum,0);
}
// initialize temp var for the value in a region
dtype_
%(x)
s collector;
dtype_
%(ex)
s eval_collector;
int z_prod;
// do not run if any z[i] is zero
z_prod = 1;
for (int i=0; i<
%(nd)
s; i++)
{
z_prod *= z[i];
}
if (z_prod)
{
// will be used to hold start and end index of a region
int r_st[
%(nd)
s];
int r_end[
%(nd)
s];
// index for iterating over the pooling regions
int r_idx[
%(nd)
s];
// placeholder for PyArray indexing (output)
npy_intp o_idx[
%(total_ndim)
s];
// placeholder for PyArray indexing (input)
npy_intp i_idx[
%(total_ndim)
s];
// loop over non-pooling dimensions
int non_pooling_prod = 1;
for (int i=0; i<
%(non_pool_ndim)
s; i++)
{
non_pooling_prod *= PyArray_DIMS(
%(x)
s)[i];
}
%(omp_parallel)
s
// first loop over non-pooling dimensions
for (int t=0; t<non_pooling_prod; t++)
{
// compute the non-pooling index in each dimension
if (
%(non_pool_ndim)
s!=0)
{
o_idx[0] = t;
i_idx[0] = t;
for (int i=1; i<
%(non_pool_ndim)
s; i++)
{
o_idx[i] = o_idx[i - 1] / PyArray_DIMS(
%(x)
s)[i - 1];
o_idx[i - 1] = o_idx[i - 1]
%%
PyArray_DIMS(
%(x)
s)[i - 1];
i_idx[i] = o_idx[i];
i_idx[i - 1] = o_idx[i - 1];
}
}
// then loop over each region in each pooling dimension
"""
for
i
in
xrange
(
nd
):
ccode
+=
"""
for (r_idx[
%(i)
s]=0; r_idx[
%(i)
s] < z[
%(i)
s]; r_idx[
%(i)
s]++) {
r_st[
%(i)
s] = r_idx[
%(i)
s] * st[
%(i)
s];
r_end[
%(i)
s] = r_st[
%(i)
s] + ws[
%(i)
s];
// skip the padding
r_st[
%(i)
s] = r_st[
%(i)
s] < pd[
%(i)
s] ? pd[
%(i)
s] : r_st[
%(i)
s];
r_end[
%(i)
s] = r_end[
%(i)
s] > (r[
%(i)
s] - pd[
%(i)
s]) ? r[
%(i)
s] - pd[
%(i)
s] : r_end[
%(i)
s];
// from padded_img space to img space
r_st[
%(i)
s] -= pd[
%(i)
s];
r_end[
%(i)
s] -= pd[
%(i)
s];
// handle the case where no padding, ignore border is True
if (
%(ignore_border)
s)
{
r_end[
%(i)
s] = r_end[
%(i)
s] > r[
%(i)
s] ? r[
%(i)
s] : r_end[
%(i)
s];
}
// use the index to find the correct position in the output
o_idx[
%(non_pool_ndim)
s +
%(i)
s] = r_idx[
%(i)
s];
"""
%
dict
(
i
=
i
,
ignore_border
=
ignore_border
,
non_pool_ndim
=
non_pool_ndim
)
ccode
+=
"""
// get a pointer to the correct position in the output
dtype_
%(z)
s * z;
if (
%(total_ndim)
s == 4)
z = ((dtype_
%(z)
s*)(PyArray_GETPTR4(
%(z)
s, o_idx[0], o_idx[1], o_idx[2], o_idx[3])));
else
z = ((dtype_
%(z)
s*)(PyArray_GetPtr(
%(z)
s, o_idx)));
"""
for
i
in
xrange
(
nd
):
ccode
+=
"""
// set the first index of dimension
%(i)
s
i_idx[
%(non_pool_ndim)
s +
%(i)
s] = r_st[
%(i)
s];
"""
%
dict
(
i
=
i
,
non_pool_ndim
=
non_pool_ndim
)
ccode
+=
"""
// use the first element as the initial value of collector
if (
%(total_ndim)
s == 4) {
collector = ((dtype_
%(x)
s*)(PyArray_GETPTR4(
%(x)
s,i_idx[0],i_idx[1],i_idx[2],i_idx[3])))[0];
eval_collector = ((dtype_
%(ex)
s*)(PyArray_GETPTR4(
%(ex)
s,i_idx[0],i_idx[1],i_idx[2],i_idx[3])))[0];
} else {
collector = ((dtype_
%(x)
s*)(PyArray_GetPtr(
%(x)
s,i_idx)))[0];
eval_collector = ((dtype_
%(ex)
s*)(PyArray_GetPtr(
%(ex)
s,i_idx)))[0];
}
"""
for
i
in
xrange
(
nd
):
ccode
+=
"""
// go through the pooled region in the unpadded input
for(int m
%(i)
s=r_st[
%(i)
s]; m
%(i)
s<r_end[
%(i)
s]; m
%(i)
s++)
{
i_idx[
%(non_pool_ndim)
s +
%(i)
s] = m
%(i)
s;
"""
%
dict
(
i
=
i
,
non_pool_ndim
=
non_pool_ndim
)
ccode
+=
"""
// update maximum
dtype_
%(x)
s a;
dtype_
%(ex)
s ea;
if (
%(total_ndim)
s == 4) {
a = ((dtype_
%(x)
s*)(PyArray_GETPTR4(
%(x)
s,i_idx[0],i_idx[1],i_idx[2],i_idx[3])))[0];
ea = ((dtype_
%(ex)
s*)(PyArray_GETPTR4(
%(ex)
s,i_idx[0],i_idx[1],i_idx[2],i_idx[3])))[0];
}
else {
a = ((dtype_
%(x)
s*)(PyArray_GetPtr(
%(x)
s,i_idx)))[0];
ea = ((dtype_
%(ex)
s*)(PyArray_GetPtr(
%(ex)
s,i_idx)))[0];
}
if (a > collector) {
collector = a;
eval_collector = ea;
}
"""
for
i
in
xrange
(
nd
):
ccode
+=
"""
} // for loop over region
"""
ccode
+=
"""
z[0] = eval_collector;
"""
for
i
in
xrange
(
nd
):
ccode
+=
"""
} // loop over pooling dimension
"""
ccode
+=
"""
} // for loop over non-pooling dimensions
} // if z_prod
"""
return
ccode
%
locals
()
def
c_code_cache_version
(
self
):
return
(
0
,
self
.
openmp
)
theano/tests/test_rop.py
浏览文件 @
7b5919e9
...
...
@@ -17,10 +17,12 @@ from theano.tests import unittest_tools as utt
from
theano
import
function
import
theano
from
theano
import
tensor
import
itertools
import
numpy
from
theano.gof
import
Op
,
Apply
from
theano.gradient
import
grad_undefined
from
theano.tests.unittest_tools
import
SkipTest
from
theano.tensor.signal.pool
import
Pool
from
theano.tensor.nnet
import
conv
,
conv2d
'''
...
...
@@ -255,6 +257,47 @@ class test_RopLop(RopLop_checker):
self
.
x
[:
4
]
.
dimshuffle
(
'x'
,
0
),
0
)
.
sum
(
axis
=
1
),
(
1
,))
def
test_downsample
(
self
):
rng
=
numpy
.
random
.
RandomState
(
utt
.
fetch_seed
())
# ws, shp
examples
=
(
((
2
,),
(
16
,)),
((
2
,),
(
4
,
16
,)),
((
2
,),
(
4
,
2
,
16
,)),
((
1
,
1
),
(
4
,
2
,
16
,
16
)),
((
2
,
2
),
(
4
,
2
,
16
,
16
)),
((
3
,
3
),
(
4
,
2
,
16
,
16
)),
((
3
,
2
),
(
4
,
2
,
16
,
16
)),
((
3
,
2
,
2
),
(
3
,
2
,
16
,
16
,
16
)),
((
2
,
3
,
2
),
(
3
,
2
,
16
,
16
,
16
)),
((
2
,
2
,
3
),
(
3
,
2
,
16
,
16
,
16
)),
((
2
,
2
,
3
,
2
),
(
3
,
2
,
6
,
6
,
6
,
5
)),
)
for
example
,
ignore_border
in
itertools
.
product
(
examples
,
[
True
,
False
]):
(
ws
,
shp
)
=
example
vx
=
rng
.
rand
(
*
shp
)
vex
=
rng
.
rand
(
*
shp
)
x
=
theano
.
shared
(
vx
)
ex
=
theano
.
shared
(
vex
)
maxpool_op
=
Pool
(
ignore_border
,
ndim
=
len
(
ws
))
a_pooled
=
maxpool_op
(
x
,
ws
)
.
flatten
()
yv
=
tensor
.
Rop
(
a_pooled
,
x
,
ex
)
mode
=
None
if
theano
.
config
.
mode
==
"FAST_COMPILE"
:
mode
=
"FAST_RUN"
rop_f
=
function
([],
yv
,
on_unused_input
=
'ignore'
,
mode
=
mode
)
sy
,
_
=
theano
.
scan
(
lambda
i
,
y
,
x
,
v
:
(
tensor
.
grad
(
y
[
i
],
x
)
*
v
)
.
sum
(),
sequences
=
tensor
.
arange
(
a_pooled
.
shape
[
0
]),
non_sequences
=
[
a_pooled
,
x
,
ex
])
scan_f
=
function
([],
sy
,
on_unused_input
=
'ignore'
,
mode
=
mode
)
v1
=
rop_f
()
v2
=
scan_f
()
assert
numpy
.
allclose
(
v1
,
v2
),
(
"Rop mismatch:
%
s
%
s"
%
(
v1
,
v2
))
def
test_conv
(
self
):
for
conv_op
in
[
conv
.
conv2d
,
conv2d
]:
for
border_mode
in
[
'valid'
,
'full'
]:
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论