Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
289c3bd4
提交
289c3bd4
authored
8月 15, 2016
作者:
Gijs van Tulder
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Introduce AbstractConv3D and related changes.
Add abstract convolution classes, reuse this for 2D and 3D.
上级
d3fb7189
隐藏空白字符变更
内嵌
并排
正在显示
5 个修改的文件
包含
1346 行增加
和
287 行删除
+1346
-287
opt.py
theano/sandbox/cuda/opt.py
+3
-3
__init__.py
theano/tensor/nnet/__init__.py
+1
-0
abstract_conv.py
theano/tensor/nnet/abstract_conv.py
+785
-204
opt.py
theano/tensor/nnet/opt.py
+116
-2
test_abstract_conv.py
theano/tensor/nnet/tests/test_abstract_conv.py
+441
-78
没有找到文件。
theano/sandbox/cuda/opt.py
浏览文件 @
289c3bd4
...
...
@@ -87,7 +87,7 @@ from theano.tensor import slinalg
from
theano.tensor.nnet.Conv3D
import
Conv3D
from
theano.tests.breakpoint
import
PdbBreakpoint
from
theano.tensor.nnet.abstract_conv
import
(
BaseAbstractConv
2d
,
from
theano.tensor.nnet.abstract_conv
import
(
BaseAbstractConv
,
AbstractConv2d
,
AbstractConv2d_gradWeights
,
AbstractConv2d_gradInputs
)
...
...
@@ -2736,7 +2736,7 @@ def local_conv2d_gpu_conv(node):
if
isinstance
(
node
.
op
,
GpuFromHost
):
host_input
=
node
.
inputs
[
0
]
if
host_input
.
owner
and
isinstance
(
host_input
.
owner
.
op
,
BaseAbstractConv
2d
):
BaseAbstractConv
):
conv
=
host_input
.
owner
.
op
inps
=
list
(
host_input
.
owner
.
inputs
)
...
...
@@ -2749,7 +2749,7 @@ def local_conv2d_gpu_conv(node):
out
.
tag
.
values_eq_approx
=
values_eq_approx_high_tol
return
[
out
]
if
isinstance
(
node
.
op
,
BaseAbstractConv
2d
):
if
isinstance
(
node
.
op
,
BaseAbstractConv
):
# conv(host_from_gpu) -> host_from_gpu(gpu_conv)
inp1
=
node
.
inputs
[
0
]
inp2
=
node
.
inputs
[
1
]
...
...
theano/tensor/nnet/__init__.py
浏览文件 @
289c3bd4
...
...
@@ -32,6 +32,7 @@ from .bn import batch_normalization
import
warnings
from
.abstract_conv
import
conv2d
as
abstract_conv2d
from
.abstract_conv
import
conv3d
as
abstract_conv3d
def
conv2d
(
input
,
filters
,
input_shape
=
None
,
filter_shape
=
None
,
...
...
theano/tensor/nnet/abstract_conv.py
浏览文件 @
289c3bd4
...
...
@@ -20,7 +20,7 @@ import numpy
import
numpy
as
np
try
:
from
scipy.signal.signaltools
import
_valfrommode
,
_bvalfromboundary
from
scipy.signal.signaltools
import
_valfrommode
,
_bvalfromboundary
,
convolve
from
scipy.signal.sigtools
import
_convolve2d
imported_scipy_signal
=
True
except
ImportError
:
...
...
@@ -163,6 +163,33 @@ def conv2d(input,
return
conv_op
(
input
,
filters
)
def
conv3d
(
input
,
filters
,
input_shape
=
None
,
filter_shape
=
None
,
border_mode
=
'valid'
,
subsample
=
(
1
,
1
,
1
),
filter_flip
=
True
,
filter_dilation
=
(
1
,
1
,
1
)):
"""This function will build the symbolic graph for convolving a mini-batch of a
stack of 3D inputs with a set of 3D filters. The implementation is modelled
after Convolutional Neural Networks (CNN).
TODO
Refer to :func:`nnet.conv3d <theano.tensor.nnet.conv2d>` for a more detailed documentation.
"""
input
=
as_tensor_variable
(
input
)
filters
=
as_tensor_variable
(
filters
)
conv_op
=
AbstractConv3d
(
imshp
=
input_shape
,
kshp
=
filter_shape
,
border_mode
=
border_mode
,
subsample
=
subsample
,
filter_flip
=
filter_flip
,
filter_dilation
=
filter_dilation
)
return
conv_op
(
input
,
filters
)
def
conv2d_grad_wrt_inputs
(
output_grad
,
filters
,
input_shape
,
...
...
@@ -298,6 +325,141 @@ def conv2d_grad_wrt_inputs(output_grad,
return
grad_input_op
(
filters
,
output_grad
,
input_shape
[
-
2
:])
def
conv3d_grad_wrt_inputs
(
output_grad
,
filters
,
input_shape
,
filter_shape
=
None
,
border_mode
=
'valid'
,
subsample
=
(
1
,
1
,
1
),
filter_flip
=
True
,
filter_dilation
=
(
1
,
1
,
1
)):
"""Compute conv output gradient w.r.t its inputs
This function builds the symbolic graph for getting the
gradient of the output of a convolution (namely output_grad)
w.r.t the input of the convolution, given a set of 3D filters
used by the convolution, such that the output_grad is upsampled
to the input_shape.
Parameters
----------
output_grad : symbolic 5D tensor
mini-batch of feature map stacks, of shape (batch size, input
channels, input depth, input rows, input columns). This is the
tensor that will be upsampled or the output gradient of the
convolution whose gradient will be taken with respect to the
input of the convolution.
filters : symbolic 5D tensor
set of filters used in CNN layer of shape (output channels,
input channels, filter depth, filter rows, filter columns).
See the optional parameter ``filter_shape``.
input_shape : [None/int/Constant] * 2 + [Tensor/int/Constant] * 2
The shape of the input (upsampled) parameter.
A tuple/list of len 5, with the first two dimensions
being None or int or Constant and the last three dimensions being
Tensor or int or Constant.
Not Optional, since given the output_grad shape
and the subsample values, multiple input_shape may be
plausible.
filter_shape : None or [None/int/Constant] * 5
The shape of the filters parameter. None or a tuple/list of len 5.
Optional, possibly used to choose an optimal implementation.
You can give ``None`` for any element of the list to specify that
this element is not known at compile time.
border_mode : str, int or tuple of three int
Either of the following:
``'valid'``
apply filter wherever it completely overlaps with the
input. Generates output of shape: input shape - filter
shape + 1
``'full'``
apply filter wherever it partly overlaps with the input.
Generates output of shape: input shape + filter shape - 1
``'half'``
pad input with a symmetric border of ``filter // 2``,
then perform a valid convolution. For filters with an odd
number of slices, rows and columns, this leads to the output
shape being equal to the input shape. It is known as 'same'
elsewhere.
``int``
pad input with a symmetric border of zeros of the given
width, then perform a valid convolution.
``(int1, int2, int3)``
pad input with a symmetric border of ``int1``, ``int2`` and
``int3`` columns, then perform a valid convolution.
subsample : tuple of len 3
The subsampling used in the forward pass. Also called strides
elsewhere.
filter_flip : bool
If ``True``, will flip the filter x, y and z dimensions before
sliding them over the input. This operation is normally
referred to as a convolution, and this is the default. If
``False``, the filters are not flipped and the operation is
referred to as a cross-correlation.
filter_dilation : tuple of len 3
The filter dilation used in the forward pass.
Also known as input striding.
Returns
-------
symbolic 5D tensor
set of feature maps generated by convolutional layer. Tensor
is of shape (batch size, output channels, output depth,
output rows, output columns)
Notes
-----
:note: If cuDNN is available, it will be used on the
GPU. Otherwise, it is the *CorrMM* convolution that will be used
"caffe style convolution".
:note: This is only supported in Theano 0.8 or the development
version until it is released.
"""
filters
=
as_tensor_variable
(
filters
)
output_grad
=
as_tensor_variable
(
output_grad
)
# checking the type of input_shape
for
dim
in
[
0
,
1
]:
assert
isinstance
(
input_shape
[
dim
],
(
theano
.
tensor
.
TensorConstant
,
integer_types
,
type
(
None
)))
for
dim
in
[
2
,
3
,
4
]:
assert
isinstance
(
input_shape
[
dim
],
(
theano
.
tensor
.
TensorVariable
,
theano
.
tensor
.
TensorConstant
,
integer_types
))
# checking the type of filter_shape
if
filter_shape
is
not
None
:
for
dim
in
[
0
,
1
,
2
,
3
,
4
]:
assert
isinstance
(
filter_shape
[
dim
],
(
theano
.
tensor
.
TensorConstant
,
integer_types
,
type
(
None
)))
# setting the last three dimensions of input_shape to None, if
# the type of these dimensions is TensorVariable.
numerical_input_shape
=
list
(
input_shape
)
for
dim
in
[
2
,
3
,
4
]:
if
isinstance
(
input_shape
[
dim
],
theano
.
tensor
.
TensorVariable
):
numerical_input_shape
[
dim
]
=
None
grad_input_op
=
AbstractConv3d_gradInputs
(
imshp
=
numerical_input_shape
,
kshp
=
filter_shape
,
border_mode
=
border_mode
,
subsample
=
subsample
,
filter_flip
=
filter_flip
,
filter_dilation
=
filter_dilation
)
return
grad_input_op
(
filters
,
output_grad
,
input_shape
[
-
3
:])
def
conv2d_grad_wrt_weights
(
input
,
output_grad
,
filter_shape
,
...
...
@@ -425,6 +587,132 @@ def conv2d_grad_wrt_weights(input,
return
gradWeight_op
(
input
,
output_grad
,
filter_shape
[
-
2
:])
def
conv3d_grad_wrt_weights
(
input
,
output_grad
,
filter_shape
,
input_shape
=
None
,
border_mode
=
'valid'
,
subsample
=
(
1
,
1
,
1
),
filter_flip
=
True
,
filter_dilation
=
(
1
,
1
,
1
)):
"""Compute conv output gradient w.r.t its weights
This function will build the symbolic graph for getting the
gradient of the output of a convolution (output_grad) w.r.t its weights.
Parameters
----------
input : symbolic 5D tensor
mini-batch of feature map stacks, of shape (batch size, input
channels, input depth, input rows, input columns). This is the input
of the convolution in the forward pass.
output_grad : symbolic 5D tensor
mini-batch of feature map stacks, of shape (batch size, input
channels, input depth, input rows, input columns). This is the
gradient of the output of convolution.
filter_shape : [None/int/Constant] * 2 + [Tensor/int/Constant] * 2
The shape of the filter parameter. A tuple/list of len 5, with the
first two dimensions being None or int or Constant and the last three
dimensions being Tensor or int or Constant.
Not Optional, since given the output_grad shape and
the input_shape, multiple filter_shape may be plausible.
input_shape : None or [None/int/Constant] * 5
The shape of the input parameter. None or a tuple/list of len 5.
Optional, possibly used to choose an optimal implementation.
You can give ``None`` for any element of the list to specify
that this element is not known at compile time.
border_mode : str, int or tuple of two ints
Either of the following:
``'valid'``
apply filter wherever it completely overlaps with the
input. Generates output of shape: input shape - filter
shape + 1
``'full'``
apply filter wherever it partly overlaps with the input.
Generates output of shape: input shape + filter shape - 1
``'half'``
pad input with a symmetric border of ``filter rows // 2``
rows and ``filter columns // 2`` columns, then perform a
valid convolution. For filters with an odd number of rows
and columns, this leads to the output shape being equal to
the input shape. It is known as 'same' elsewhere.
``int``
pad input with a symmetric border of zeros of the given
width, then perform a valid convolution.
``(int1, int2, int3)``
pad input with a symmetric border of ``int1``, ``int2`` and
``int3``, then perform a valid convolution.
subsample : tuple of len 3
The subsampling used in the forward pass of the convolutional
operation. Also called strides elsewhere.
filter_flip : bool
If ``True``, will flip the filters before sliding them over the
input. This operation is normally referred to as a convolution,
and this is the default. If ``False``, the filters are not
flipped and the operation is referred to as a cross-correlation.
filter_dilation : tuple of len 3
The filter dilation used in the forward pass.
Also known as input striding.
Returns
-------
symbolic 5D tensor
set of feature maps generated by convolutional layer. Tensor
is of shape (batch size, output channels, output time, output
rows, output columns)
Notes
-----
:note: If cuDNN is available, it will be used on the
GPU. Otherwise, it is the *CorrMM* convolution that will be used
"caffe style convolution".
:note: This is only supported in Theano 0.8 or the development
version until it is released.
"""
input
=
as_tensor_variable
(
input
)
output_grad
=
as_tensor_variable
(
output_grad
)
# checking the type of filter_shape
for
dim
in
[
0
,
1
]:
assert
isinstance
(
filter_shape
[
dim
],
(
theano
.
tensor
.
TensorConstant
,
integer_types
,
type
(
None
)))
for
dim
in
[
2
,
3
,
4
]:
assert
isinstance
(
filter_shape
[
dim
],
(
theano
.
tensor
.
TensorVariable
,
theano
.
tensor
.
TensorConstant
,
integer_types
))
# checking the type of input_shape
if
input_shape
is
not
None
:
for
dim
in
[
0
,
1
,
2
,
3
,
4
]:
assert
isinstance
(
input_shape
[
dim
],
(
theano
.
tensor
.
TensorConstant
,
integer_types
,
type
(
None
)))
# setting the last three dimensions of filter_shape to None, if
# the type of these dimensions is TensorVariable.
numerical_filter_shape
=
list
(
filter_shape
)
for
dim
in
[
2
,
3
,
4
]:
if
isinstance
(
filter_shape
[
dim
],
theano
.
tensor
.
TensorVariable
):
numerical_filter_shape
[
dim
]
=
None
gradWeight_op
=
AbstractConv3d_gradWeights
(
imshp
=
input_shape
,
kshp
=
numerical_filter_shape
,
border_mode
=
border_mode
,
subsample
=
subsample
,
filter_flip
=
filter_flip
,
filter_dilation
=
filter_dilation
)
return
gradWeight_op
(
input
,
output_grad
,
filter_shape
[:
-
3
])
def
bilinear_kernel_2D
(
ratio
,
normalize
=
True
):
"""Compute 2D kernel for bilinear upsampling
...
...
@@ -608,45 +896,46 @@ def bilinear_upsampling(input,
row
*
ratio
,
col
*
ratio
))
class
BaseAbstractConv
2d
(
Op
):
class
BaseAbstractConv
(
Op
):
"""Base class for AbstractConv
Define an abstract convolution op that will be replaced with the
appropriate implementation
Parameters
----------
imshp: None, tuple/list of len 4 of int or Constant variable
convdim: The number of convolution dimensions (2 or 3).
imshp: None, tuple/list of len ``(2 + convdim)`` of int or Constant variable
The shape of the input parameter.
Optional, possibly used to choose an optimal implementation.
You can give ``None`` for any element of the list to specify that this
element is not known at compile time.
imshp is defined w.r.t the forward conv.
kshp: None, tuple/list of len
4
of int or Constant variable
kshp: None, tuple/list of len
``(2 + convdim)``
of int or Constant variable
The shape of the filters parameter.
Optional, possibly used to choose an optimal implementation.
You can give ``None`` for any element of the list to specify that this
element is not known at compile time.
kshp is defined w.r.t the forward conv.
border_mode: str, int or tuple of
two int
border_mode: str, int or tuple of
``convdim`` ints
Either of the following:
``'valid'``: apply filter wherever it completely overlaps with the
input. Generates output of shape: input shape - filter shape + 1
``'full'``: apply filter wherever it partly overlaps with the input.
Generates output of shape: input shape + filter shape - 1
``'half'``: pad input with a symmetric border of ``filter
rows
// 2``
rows and ``filter columns // 2`` columns, then perform a valid
convolution. For filters with an odd number of rows and columns, this
leads to the output
shape being equal to the input shape.
``'half'``: pad input with a symmetric border of ``filter
size
// 2``
in each convolution dimension, then perform a valid convolution.
For filters with an odd filter size, this leads to the output
shape being equal to the input shape.
``int``: pad input with a symmetric border of zeros of the given
width, then perform a valid convolution.
``(int1, int2)``: pad input with a symmetric border of ``int1`` rows
and ``int2`` columns, then perform a valid convolution.
``(int1, int2)``: (for 2D) pad input with a symmetric border of ``int1``,
``int2``, then perform a valid convolution.
``(int1, int2, int3)``: (for 3D) pad input with a symmetric border of
``int1``, ``int2`` and ``int3``, then perform a valid convolution.
subsample: tuple of len
2
subsample: tuple of len
``convdim``
Factor by which to subsample the output.
Also called strides elsewhere.
...
...
@@ -657,34 +946,46 @@ class BaseAbstractConv2d(Op):
are not flipped and the operation is referred to as a
cross-correlation.
filter_dilation: tuple of len
2
filter_dilation: tuple of len
``convdim``
Factor by which to subsample (stride) the input.
Also called dilation factor.
"""
check_broadcast
=
False
__props__
=
(
'border_mode'
,
'subsample'
,
'filter_flip'
,
__props__
=
(
'
convdim'
,
'
border_mode'
,
'subsample'
,
'filter_flip'
,
'imshp'
,
'kshp'
,
'filter_dilation'
)
def
__init__
(
self
,
def
__init__
(
self
,
convdim
,
imshp
=
None
,
kshp
=
None
,
border_mode
=
"valid"
,
subsample
=
(
1
,
1
),
filter_flip
=
True
,
filter_dilation
=
(
1
,
1
)):
subsample
=
None
,
filter_flip
=
True
,
filter_dilation
=
None
):
self
.
convdim
=
convdim
if
convdim
not
in
(
2
,
3
):
raise
ValueError
(
'convolution dimension {} is not supported'
,
convdim
)
if
subsample
is
None
:
subsample
=
(
1
,)
*
convdim
if
filter_dilation
is
None
:
filter_dilation
=
(
1
,)
*
convdim
if
isinstance
(
border_mode
,
integer_types
):
border_mode
=
(
border_mode
,
border_mode
)
border_mode
=
(
border_mode
,
)
*
convdim
if
isinstance
(
border_mode
,
tuple
):
pad_h
,
pad_w
=
map
(
int
,
border_mode
)
border_mode
=
(
pad_h
,
pad_w
)
if
border_mode
==
(
0
,
0
):
if
len
(
border_mode
)
!=
convdim
:
raise
ValueError
(
'border mode must have exactly {} values, '
'but was {}'
.
format
(
convdim
))
border_mode
=
tuple
(
map
(
int
,
border_mode
))
if
border_mode
==
(
0
,)
*
convdim
:
border_mode
=
'valid'
if
not
((
isinstance
(
border_mode
,
tuple
)
and
min
(
border_mode
)
>=
0
)
or
border_mode
in
(
'valid'
,
'full'
,
'half'
)):
raise
ValueError
(
'invalid border_mode {}, which must be either '
'"valid", "full", "half", an integer or a
pair of
'
' integers'
.
format
(
border_mode
))
'"valid", "full", "half", an integer or a
tuple of {}
'
' integers'
.
format
(
border_mode
,
convdim
))
self
.
imshp
=
tuple
(
imshp
)
if
imshp
else
(
None
,)
*
4
self
.
imshp
=
tuple
(
imshp
)
if
imshp
else
(
None
,)
*
(
2
+
convdim
)
for
imshp_i
in
self
.
imshp
:
if
imshp_i
is
not
None
:
# Components of imshp should be constant or ints
...
...
@@ -696,7 +997,7 @@ class BaseAbstractConv2d(Op):
ValueError
(
"imshp should be None or a tuple of "
"constant int values"
),
sys
.
exc_info
()[
2
])
self
.
kshp
=
tuple
(
kshp
)
if
kshp
else
(
None
,)
*
4
self
.
kshp
=
tuple
(
kshp
)
if
kshp
else
(
None
,)
*
(
2
+
convdim
)
for
kshp_i
in
self
.
kshp
:
if
kshp_i
is
not
None
:
# Components of kshp should be constant or ints
...
...
@@ -711,36 +1012,41 @@ class BaseAbstractConv2d(Op):
self
.
border_mode
=
border_mode
self
.
filter_flip
=
filter_flip
if
len
(
subsample
)
!=
2
:
raise
ValueError
(
"subsample must have
two elements"
)
if
len
(
subsample
)
!=
convdim
:
raise
ValueError
(
"subsample must have
{} elements"
.
format
(
convdim
)
)
self
.
subsample
=
tuple
(
subsample
)
if
len
(
filter_dilation
)
!=
2
:
raise
ValueError
(
"filter_dilation must have
two elements"
)
if
len
(
filter_dilation
)
!=
convdim
:
raise
ValueError
(
"filter_dilation must have
{} elements"
.
format
(
convdim
)
)
self
.
filter_dilation
=
tuple
(
filter_dilation
)
def
flops
(
self
,
inp
,
outp
):
""" Useful with the hack in profiling to print the MFlops"""
# if the output shape is correct, then this gives the correct
# flops for any direction, sampling, padding, and border mode
inputs
,
filters
=
inp
outputs
,
=
outp
assert
inputs
[
1
]
==
filters
[
1
]
# nb mul and add by output pixel
flops
=
filters
[
2
]
*
filters
[
3
]
*
2
# nb flops by output image
flops
*=
outputs
[
2
]
*
outputs
[
3
]
# nb patch multiplied
flops
*=
inputs
[
1
]
*
filters
[
0
]
*
inputs
[
0
]
return
flops
def
do_constant_folding
(
self
,
node
):
# Disable constant folding since there is no implementation.
# This may change in the future.
return
False
def
conv2d
(
self
,
img
,
kern
,
mode
=
"valid"
,
dilation
=
(
1
,
1
)):
def
flops
(
self
,
inp
,
outp
):
""" Useful with the hack in profiling to print the MFlops"""
if
self
.
convdim
==
2
:
# if the output shape is correct, then this gives the correct
# flops for any direction, sampling, padding, and border mode
inputs
,
filters
=
inp
outputs
,
=
outp
assert
inputs
[
1
]
==
filters
[
1
]
# nb mul and add by output pixel
flops
=
filters
[
2
]
*
filters
[
3
]
*
2
# nb flops by output image
flops
*=
outputs
[
2
]
*
outputs
[
3
]
# nb patch multiplied
flops
*=
inputs
[
1
]
*
filters
[
0
]
*
inputs
[
0
]
return
flops
else
:
# TODO implement for convdim == 3
raise
NotImplementedError
(
'flops not implemented for convdim={}'
,
self
.
convdim
)
def
conv
(
self
,
img
,
kern
,
mode
=
"valid"
,
dilation
=
1
):
"""
Basic slow
python implementata
tion for DebugMode
Basic slow
Python 2D or 3D convolu
tion for DebugMode
"""
if
not
imported_scipy_signal
:
...
...
@@ -751,48 +1057,70 @@ class BaseAbstractConv2d(Op):
raise
ValueError
(
'invalid mode {}, which must be either '
'"valid" or "full"'
.
format
(
mode
))
if
isinstance
(
dilation
,
integer_types
):
dilation
=
(
dilation
,)
*
self
.
convdim
if
len
(
dilation
)
!=
self
.
convdim
:
raise
ValueError
(
'invalid dilation {}, expected {} values'
.
format
(
dilation
,
self
.
convdim
))
out_shape
=
get_conv_output_shape
(
img
.
shape
,
kern
.
shape
,
mode
,
[
1
,
1
]
,
dilation
)
mode
,
[
1
]
*
self
.
convdim
,
dilation
)
out
=
numpy
.
zeros
(
out_shape
,
dtype
=
img
.
dtype
)
dil_kern_shp
=
kern
.
shape
[:
-
2
]
+
((
kern
.
shape
[
-
2
]
-
1
)
*
dilation
[
0
]
+
1
,
(
kern
.
shape
[
-
1
]
-
1
)
*
dilation
[
1
]
+
1
)
dil_kern_shp
=
kern
.
shape
[:
-
self
.
convdim
]
+
tuple
(
(
kern
.
shape
[
-
self
.
convdim
+
i
]
-
1
)
*
dilation
[
i
]
+
1
for
i
in
range
(
self
.
convdim
))
dilated_kern
=
numpy
.
zeros
(
dil_kern_shp
,
dtype
=
kern
.
dtype
)
dilated_kern
[:,
:,
::
dilation
[
0
],
::
dilation
[
1
]]
=
kern
val
=
_valfrommode
(
mode
)
bval
=
_bvalfromboundary
(
'fill'
)
with
warnings
.
catch_warnings
():
warnings
.
simplefilter
(
'ignore'
,
numpy
.
ComplexWarning
)
dilated_kern
[(
slice
(
None
),
slice
(
None
))
+
tuple
(
slice
(
None
,
None
,
dilation
[
i
])
for
i
in
range
(
self
.
convdim
))
]
=
kern
if
self
.
convdim
==
2
:
val
=
_valfrommode
(
mode
)
bval
=
_bvalfromboundary
(
'fill'
)
with
warnings
.
catch_warnings
():
warnings
.
simplefilter
(
'ignore'
,
numpy
.
ComplexWarning
)
for
b
in
xrange
(
img
.
shape
[
0
]):
for
n
in
xrange
(
kern
.
shape
[
0
]):
for
im0
in
xrange
(
img
.
shape
[
1
]):
# some cast generates a warning here
out
[
b
,
n
,
...
]
+=
_convolve2d
(
img
[
b
,
im0
,
...
],
dilated_kern
[
n
,
im0
,
...
],
1
,
val
,
bval
,
0
)
elif
self
.
convdim
==
3
:
for
b
in
xrange
(
img
.
shape
[
0
]):
for
n
in
xrange
(
kern
.
shape
[
0
]):
for
im0
in
xrange
(
img
.
shape
[
1
]):
# some cast generates a warning here
out
[
b
,
n
,
...
]
+=
_convolve2d
(
img
[
b
,
im0
,
...
],
dilated_kern
[
n
,
im0
,
...
],
1
,
val
,
bval
,
0
)
out
[
b
,
n
,
...
]
+=
convolve
(
img
[
b
,
im0
,
...
],
dilated_kern
[
n
,
im0
,
...
],
mode
)
else
:
raise
NotImplementedError
(
'only 2D and 3D convolution are implemented'
)
return
out
class
AbstractConv
2d
(
BaseAbstractConv2d
):
class
AbstractConv
(
BaseAbstractConv
):
""" Abstract Op for the forward convolution.
Refer to :func:`BaseAbstractConv
2d <theano.tensor.nnet.abstract_conv.BaseAbstractConv2d
>`
Refer to :func:`BaseAbstractConv
<theano.tensor.nnet.abstract_conv.BaseAbstractConv
>`
for a more detailed documentation.
"""
def
__init__
(
self
,
convdim
,
imshp
=
None
,
kshp
=
None
,
border_mode
=
"valid"
,
subsample
=
(
1
,
1
)
,
subsample
=
None
,
filter_flip
=
True
,
filter_dilation
=
(
1
,
1
)):
super
(
AbstractConv2d
,
self
)
.
__init__
(
imshp
,
kshp
,
border_mode
,
subsample
,
filter_flip
,
filter_dilation
)
filter_dilation
=
None
):
super
(
AbstractConv
,
self
)
.
__init__
(
convdim
=
convdim
,
imshp
=
imshp
,
kshp
=
kshp
,
border_mode
=
border_mode
,
subsample
=
subsample
,
filter_flip
=
filter_flip
,
filter_dilation
=
filter_dilation
)
def
make_node
(
self
,
img
,
kern
):
# Make sure both inputs are Variables with the same Type
...
...
@@ -804,14 +1132,13 @@ class AbstractConv2d(BaseAbstractConv2d):
broadcastable
=
kern
.
broadcastable
)
kern
=
ktype
.
filter_variable
(
kern
)
if
img
.
type
.
ndim
!=
4
:
raise
TypeError
(
'img must be
4D tensor'
)
if
kern
.
type
.
ndim
!=
4
:
raise
TypeError
(
'kern must be
4D tensor'
)
if
img
.
type
.
ndim
!=
2
+
self
.
convdim
:
raise
TypeError
(
'img must be
%
dD tensor'
%
(
2
+
self
.
convdim
)
)
if
kern
.
type
.
ndim
!=
2
+
self
.
convdim
:
raise
TypeError
(
'kern must be
%
dD tensor'
%
(
2
+
self
.
convdim
)
)
broadcastable
=
[
img
.
broadcastable
[
0
],
kern
.
broadcastable
[
0
],
False
,
False
]
kern
.
broadcastable
[
0
]]
+
([
False
]
*
self
.
convdim
)
output
=
img
.
type
.
clone
(
broadcastable
=
broadcastable
)()
return
Apply
(
self
,
[
img
,
kern
],
[
output
])
...
...
@@ -819,8 +1146,8 @@ class AbstractConv2d(BaseAbstractConv2d):
img
,
kern
=
inp
img
=
numpy
.
asarray
(
img
)
kern
=
numpy
.
asarray
(
kern
)
dil_kernshp
=
((
kern
.
shape
[
2
]
-
1
)
*
self
.
filter_dilation
[
0
]
+
1
,
(
kern
.
shape
[
3
]
-
1
)
*
self
.
filter_dilation
[
1
]
+
1
)
dil_kernshp
=
tuple
((
kern
.
shape
[
2
+
i
]
-
1
)
*
self
.
filter_dilation
[
i
]
+
1
for
i
in
range
(
self
.
convdim
)
)
o
,
=
out_
mode
=
self
.
border_mode
...
...
@@ -828,25 +1155,30 @@ class AbstractConv2d(BaseAbstractConv2d):
mode
in
(
'valid'
,
'full'
,
'half'
)):
raise
ValueError
(
'invalid border_mode {}, which must be either '
'"valid", "full", "half", an integer or a
pair
of'
'"valid", "full", "half", an integer or a
tuple
of'
' integers'
.
format
(
mode
))
if
mode
==
"full"
:
mode
=
(
dil_kernshp
[
0
]
-
1
,
dil_kernshp
[
1
]
-
1
)
mode
=
tuple
(
dil_kernshp
[
i
]
-
1
for
i
in
range
(
self
.
convdim
)
)
elif
mode
==
"half"
:
mode
=
(
dil_kernshp
[
0
]
//
2
,
dil_kernshp
[
1
]
//
2
)
mode
=
tuple
(
dil_kernshp
[
i
]
//
2
for
i
in
range
(
self
.
convdim
)
)
if
isinstance
(
mode
,
tuple
):
pad
_h
,
pad_w
=
map
(
int
,
mode
)
pad
=
tuple
(
int
(
mode
[
i
])
for
i
in
range
(
self
.
convdim
)
)
mode
=
"valid"
new_img
=
numpy
.
zeros
((
img
.
shape
[
0
],
img
.
shape
[
1
],
img
.
shape
[
2
]
+
2
*
pad_h
,
img
.
shape
[
3
]
+
2
*
pad_w
),
dtype
=
img
.
dtype
)
new_img
[:,
:,
pad_h
:
img
.
shape
[
2
]
+
pad_h
,
pad_w
:
img
.
shape
[
3
]
+
pad_w
]
=
img
new_img
=
numpy
.
zeros
((
img
.
shape
[
0
],
img
.
shape
[
1
])
+
tuple
(
img
.
shape
[
i
+
2
]
+
2
*
pad
[
i
]
for
i
in
range
(
self
.
convdim
)),
dtype
=
img
.
dtype
)
new_img
[(
slice
(
None
),
slice
(
None
))
+
tuple
(
slice
(
pad
[
i
],
img
.
shape
[
i
+
2
]
+
pad
[
i
])
for
i
in
range
(
self
.
convdim
))]
=
img
img
=
new_img
if
not
self
.
filter_flip
:
kern
=
kern
[:,
:,
::
-
1
,
::
-
1
]
conv_out
=
self
.
conv2d
(
img
,
kern
,
mode
=
"valid"
,
dilation
=
self
.
filter_dilation
)
conv_out
=
conv_out
[:,
:,
::
self
.
subsample
[
0
],
::
self
.
subsample
[
1
]]
kern
=
kern
[(
slice
(
None
),
slice
(
None
))
+
(
slice
(
None
,
None
,
-
1
),)
*
self
.
convdim
]
conv_out
=
self
.
conv
(
img
,
kern
,
mode
=
"valid"
,
dilation
=
self
.
filter_dilation
)
conv_out
=
conv_out
[(
slice
(
None
),
slice
(
None
))
+
tuple
(
slice
(
None
,
None
,
self
.
subsample
[
i
])
for
i
in
range
(
self
.
convdim
))]
o
[
0
]
=
node
.
outputs
[
0
]
.
type
.
filter
(
conv_out
)
...
...
@@ -861,6 +1193,42 @@ class AbstractConv2d(BaseAbstractConv2d):
rval
+=
self
.
make_node
(
inputs
[
0
],
eval_points
[
1
])
.
outputs
[
0
]
return
[
rval
]
def
infer_shape
(
self
,
node
,
input_shapes
):
imshp
=
input_shapes
[
0
]
kshp
=
input_shapes
[
1
]
# replace symbolic shapes with known constant shapes
if
self
.
imshp
is
not
None
:
imshp
=
[
imshp
[
i
]
if
self
.
imshp
[
i
]
is
None
else
self
.
imshp
[
i
]
for
i
in
range
(
2
+
self
.
convdim
)]
if
self
.
kshp
is
not
None
:
kshp
=
[
kshp
[
i
]
if
self
.
kshp
[
i
]
is
None
else
self
.
kshp
[
i
]
for
i
in
range
(
2
+
self
.
convdim
)]
res
=
get_conv_output_shape
(
imshp
,
kshp
,
self
.
border_mode
,
self
.
subsample
,
self
.
filter_dilation
)
return
[
res
]
class
AbstractConv2d
(
AbstractConv
):
""" Abstract Op for the forward convolution.
Refer to :func:`BaseAbstractConv <theano.tensor.nnet.abstract_conv.BaseAbstractConv>`
for a more detailed documentation.
"""
def
__init__
(
self
,
imshp
=
None
,
kshp
=
None
,
border_mode
=
"valid"
,
subsample
=
(
1
,
1
),
filter_flip
=
True
,
filter_dilation
=
(
1
,
1
)):
super
(
AbstractConv2d
,
self
)
.
__init__
(
convdim
=
2
,
imshp
=
imshp
,
kshp
=
kshp
,
border_mode
=
border_mode
,
subsample
=
subsample
,
filter_flip
=
filter_flip
,
filter_dilation
=
filter_dilation
)
def
grad
(
self
,
inp
,
grads
):
bottom
,
weights
=
inp
top
,
=
grads
...
...
@@ -889,25 +1257,59 @@ class AbstractConv2d(BaseAbstractConv2d):
d_weights
=
weights
.
type
.
filter_variable
(
d_weights
)
return
d_bottom
,
d_weights
def
infer_shape
(
self
,
node
,
input_shapes
):
imshp
=
input_shapes
[
0
]
kshp
=
input_shapes
[
1
]
# replace symbolic shapes with known constant shapes
if
self
.
imshp
is
not
None
:
imshp
=
[
imshp
[
i
]
if
self
.
imshp
[
i
]
is
None
else
self
.
imshp
[
i
]
for
i
in
range
(
4
)]
if
self
.
kshp
is
not
None
:
kshp
=
[
kshp
[
i
]
if
self
.
kshp
[
i
]
is
None
else
self
.
kshp
[
i
]
for
i
in
range
(
4
)]
res
=
get_conv_output_shape
(
imshp
,
kshp
,
self
.
border_mode
,
self
.
subsample
,
self
.
filter_dilation
)
return
[
res
]
class
AbstractConv3d
(
AbstractConv
):
""" Abstract Op for the forward convolution.
Refer to :func:`BaseAbstractConv <theano.tensor.nnet.abstract_conv.BaseAbstractConv>`
for a more detailed documentation.
"""
def
__init__
(
self
,
imshp
=
None
,
kshp
=
None
,
border_mode
=
"valid"
,
subsample
=
(
1
,
1
,
1
),
filter_flip
=
True
,
filter_dilation
=
(
1
,
1
,
1
)):
super
(
AbstractConv3d
,
self
)
.
__init__
(
convdim
=
3
,
imshp
=
imshp
,
kshp
=
kshp
,
border_mode
=
border_mode
,
subsample
=
subsample
,
filter_flip
=
filter_flip
,
filter_dilation
=
filter_dilation
)
class
AbstractConv2d_gradWeights
(
BaseAbstractConv2d
):
"""Gradient wrt. filters for `AbstractConv2d`.
Refer to :func:`BaseAbstractConv2d <theano.tensor.nnet.abstract_conv.BaseAbstractConv2d>`
def
grad
(
self
,
inp
,
grads
):
bottom
,
weights
=
inp
top
,
=
grads
d_bottom
=
AbstractConv3d_gradInputs
(
self
.
imshp
,
self
.
kshp
,
self
.
border_mode
,
self
.
subsample
,
self
.
filter_flip
,
self
.
filter_dilation
)(
weights
,
top
,
bottom
.
shape
[
-
3
:])
d_weights
=
AbstractConv3d_gradWeights
(
self
.
imshp
,
self
.
kshp
,
self
.
border_mode
,
self
.
subsample
,
self
.
filter_flip
,
self
.
filter_dilation
)(
bottom
,
top
,
weights
.
shape
[
-
3
:])
# Make sure that the broadcastable pattern of the inputs is used
# for the gradients, even if the grad opts are not able to infer
# that the dimensions are broadcastable.
# Also make sure that the gradient lives on the same device than
# the corresponding input.
d_bottom
=
patternbroadcast
(
d_bottom
,
bottom
.
broadcastable
)
d_bottom
=
bottom
.
type
.
filter_variable
(
d_bottom
)
d_weights
=
patternbroadcast
(
d_weights
,
weights
.
broadcastable
)
d_weights
=
weights
.
type
.
filter_variable
(
d_weights
)
return
d_bottom
,
d_weights
class
AbstractConv_gradWeights
(
BaseAbstractConv
):
"""Gradient wrt. filters for `AbstractConv`.
Refer to :func:`BaseAbstractConv <theano.tensor.nnet.abstract_conv.BaseAbstractConv>`
for a more detailed documentation.
:note: You will not want to use this directly, but rely on
...
...
@@ -916,17 +1318,19 @@ class AbstractConv2d_gradWeights(BaseAbstractConv2d):
"""
def
__init__
(
self
,
convdim
,
imshp
=
None
,
kshp
=
None
,
border_mode
=
"valid"
,
subsample
=
(
1
,
1
)
,
subsample
=
None
,
filter_flip
=
True
,
filter_dilation
=
(
1
,
1
)):
super
(
AbstractConv2d_gradWeights
,
self
)
.
__init__
(
imshp
,
kshp
,
border_mode
,
subsample
,
filter_flip
,
filter_dilation
)
filter_dilation
=
None
):
super
(
AbstractConv_gradWeights
,
self
)
.
__init__
(
convdim
=
convdim
,
imshp
=
imshp
,
kshp
=
kshp
,
border_mode
=
border_mode
,
subsample
=
subsample
,
filter_flip
=
filter_flip
,
filter_dilation
=
filter_dilation
)
# Update shape/height_width
def
make_node
(
self
,
img
,
topgrad
,
shape
):
...
...
@@ -939,15 +1343,14 @@ class AbstractConv2d_gradWeights(BaseAbstractConv2d):
broadcastable
=
topgrad
.
broadcastable
)
topgrad
=
gtype
.
filter_variable
(
topgrad
)
if
img
.
type
.
ndim
!=
4
:
raise
TypeError
(
'img must be
4D tensor'
)
if
topgrad
.
type
.
ndim
!=
4
:
raise
TypeError
(
'topgrad must be
4D tensor'
)
if
img
.
type
.
ndim
!=
2
+
self
.
convdim
:
raise
TypeError
(
'img must be
%
dD tensor'
%
(
2
+
self
.
convdim
)
)
if
topgrad
.
type
.
ndim
!=
2
+
self
.
convdim
:
raise
TypeError
(
'topgrad must be
%
dD tensor'
%
(
2
+
self
.
convdim
)
)
shape
=
as_tensor_variable
(
shape
)
broadcastable
=
[
topgrad
.
broadcastable
[
1
],
img
.
broadcastable
[
1
],
False
,
False
]
img
.
broadcastable
[
1
]]
+
([
False
]
*
self
.
convdim
)
output
=
img
.
type
.
clone
(
broadcastable
=
broadcastable
)()
return
Apply
(
self
,
[
img
,
topgrad
,
shape
],
[
output
])
...
...
@@ -963,45 +1366,97 @@ class AbstractConv2d_gradWeights(BaseAbstractConv2d):
mode
in
(
'valid'
,
'full'
,
'half'
)):
raise
ValueError
(
'invalid border_mode {}, which must be either '
'"valid", "full", "half", an integer or a
pair
of'
'"valid", "full", "half", an integer or a
tuple
of'
' integers'
.
format
(
mode
))
dil_shape
=
((
shape
[
0
]
-
1
)
*
self
.
filter_dilation
[
0
]
+
1
,
(
shape
[
1
]
-
1
)
*
self
.
filter_dilation
[
1
]
+
1
)
dil_shape
=
tuple
((
shape
[
i
]
-
1
)
*
self
.
filter_dilation
[
i
]
+
1
for
i
in
range
(
self
.
convdim
)
)
if
mode
==
"full"
:
mode
=
(
dil_shape
[
0
]
-
1
,
dil_shape
[
1
]
-
1
)
mode
=
tuple
(
dil_shape
[
i
]
-
1
for
i
in
range
(
self
.
convdim
)
)
elif
mode
==
"half"
:
mode
=
(
dil_shape
[
0
]
//
2
,
dil_shape
[
1
]
//
2
)
mode
=
tuple
(
dil_shape
[
i
]
//
2
for
i
in
range
(
self
.
convdim
)
)
if
isinstance
(
mode
,
tuple
):
pad
_h
,
pad_w
=
map
(
int
,
mode
)
pad
=
tuple
(
int
(
mode
[
i
])
for
i
in
range
(
self
.
convdim
)
)
mode
=
"valid"
new_img
=
numpy
.
zeros
((
img
.
shape
[
0
],
img
.
shape
[
1
],
img
.
shape
[
2
]
+
2
*
pad_h
,
img
.
shape
[
3
]
+
2
*
pad_w
),
dtype
=
img
.
dtype
)
new_img
[:,
:,
pad_h
:
img
.
shape
[
2
]
+
pad_h
,
pad_w
:
img
.
shape
[
3
]
+
pad_w
]
=
img
new_img
=
numpy
.
zeros
((
img
.
shape
[
0
],
img
.
shape
[
1
])
+
tuple
(
img
.
shape
[
i
+
2
]
+
2
*
pad
[
i
]
for
i
in
range
(
self
.
convdim
)),
dtype
=
img
.
dtype
)
new_img
[(
slice
(
None
),
slice
(
None
))
+
tuple
(
slice
(
pad
[
i
],
img
.
shape
[
i
+
2
]
+
pad
[
i
])
for
i
in
range
(
self
.
convdim
))]
=
img
img
=
new_img
if
self
.
subsample
[
0
]
>
1
or
self
.
subsample
[
1
]
>
1
:
new_shape
=
(
topgrad
.
shape
[
0
],
topgrad
.
shape
[
1
],
img
.
shape
[
2
]
-
dil_shape
[
0
]
+
1
,
img
.
shape
[
3
]
-
dil_shape
[
1
]
+
1
)
if
any
(
self
.
subsample
[
i
]
>
1
for
i
in
range
(
self
.
convdim
))
:
new_shape
=
(
(
topgrad
.
shape
[
0
],
topgrad
.
shape
[
1
])
+
tuple
(
img
.
shape
[
i
+
2
]
-
dil_shape
[
i
]
+
1
for
i
in
range
(
self
.
convdim
))
)
new_topgrad
=
numpy
.
zeros
((
new_shape
),
dtype
=
topgrad
.
dtype
)
new_topgrad
[:,
:,
::
self
.
subsample
[
0
],
::
self
.
subsample
[
1
]]
=
topgrad
new_topgrad
[(
slice
(
None
),
slice
(
None
))
+
tuple
(
slice
(
None
,
None
,
self
.
subsample
[
i
])
for
i
in
range
(
self
.
convdim
))]
=
topgrad
topgrad
=
new_topgrad
topgrad
=
topgrad
.
transpose
(
1
,
0
,
2
,
3
)[:,
:,
::
-
1
,
::
-
1
]
img
=
img
.
transpose
(
1
,
0
,
2
,
3
)
kern
=
self
.
conv2d
(
img
,
topgrad
,
mode
=
"valid"
)
if
self
.
filter_dilation
[
0
]
>
1
or
self
.
filter_dilation
[
1
]
>
1
:
kern
=
kern
[:,
:,
::
self
.
filter_dilation
[
0
],
::
self
.
filter_dilation
[
1
]]
axes_order
=
(
1
,
0
)
+
tuple
(
range
(
2
,
self
.
convdim
+
2
))
flip_filters
=
((
slice
(
None
),
slice
(
None
))
+
(
slice
(
None
,
None
,
-
1
),)
*
self
.
convdim
)
topgrad
=
topgrad
.
transpose
(
axes_order
)[
flip_filters
]
img
=
img
.
transpose
(
axes_order
)
kern
=
self
.
conv
(
img
,
topgrad
,
mode
=
"valid"
)
if
any
(
self
.
filter_dilation
[
i
]
>
1
for
i
in
range
(
self
.
convdim
)):
kern
=
kern
[(
slice
(
None
),
slice
(
None
))
+
tuple
(
slice
(
None
,
None
,
self
.
filter_dilation
[
i
])
for
i
in
range
(
self
.
convdim
))]
if
self
.
filter_flip
:
kern
=
kern
.
transpose
(
1
,
0
,
2
,
3
)[:,
:,
::
-
1
,
::
-
1
]
kern
=
kern
.
transpose
(
axes_order
)[
flip_filters
]
else
:
kern
=
kern
.
transpose
(
1
,
0
,
2
,
3
)
kern
=
kern
.
transpose
(
axes_order
)
o
[
0
]
=
node
.
outputs
[
0
]
.
type
.
filter
(
kern
)
def
connection_pattern
(
self
,
node
):
return
[[
1
],
[
1
],
[
0
]]
# no connection to height, width
def
infer_shape
(
self
,
node
,
input_shapes
):
# We use self.kshp (that was passed when creating the Op) if possible,
# or fall back to the `shape` input of the node.
# TODO: when there is no subsampling, try to infer the kernel shape
# from the shapes of inputs.
imshp
=
input_shapes
[
0
]
topshp
=
input_shapes
[
1
]
kshp
=
self
.
kshp
[:]
if
self
.
kshp
is
not
None
else
[
None
]
*
(
2
+
self
.
convdim
)
fallback_kshp
=
([
topshp
[
1
],
imshp
[
1
]]
+
[
node
.
inputs
[
2
][
i
]
for
i
in
range
(
self
.
convdim
)])
kshp
=
[
fallback_kshp
[
i
]
if
kshp
[
i
]
is
None
else
kshp
[
i
]
for
i
in
range
(
2
+
self
.
convdim
)]
return
[
kshp
]
class
AbstractConv2d_gradWeights
(
AbstractConv_gradWeights
):
"""Gradient wrt. filters for `AbstractConv2d`.
Refer to :func:`BaseAbstractConv <theano.tensor.nnet.abstract_conv.BaseAbstractConv>`
for a more detailed documentation.
:note: You will not want to use this directly, but rely on
Theano's automatic differentiation or graph optimization to
use it as needed.
"""
def
__init__
(
self
,
imshp
=
None
,
kshp
=
None
,
border_mode
=
"valid"
,
subsample
=
(
1
,
1
),
filter_flip
=
True
,
filter_dilation
=
(
1
,
1
)):
super
(
AbstractConv2d_gradWeights
,
self
)
.
__init__
(
convdim
=
2
,
imshp
=
imshp
,
kshp
=
kshp
,
border_mode
=
border_mode
,
subsample
=
subsample
,
filter_flip
=
filter_flip
,
filter_dilation
=
filter_dilation
)
def
grad
(
self
,
inp
,
grads
):
bottom
,
top
=
inp
[:
2
]
weights
,
=
grads
...
...
@@ -1031,26 +1486,64 @@ class AbstractConv2d_gradWeights(BaseAbstractConv2d):
d_height_width
=
(
theano
.
gradient
.
DisconnectedType
()(),)
return
(
d_bottom
,
d_top
)
+
d_height_width
def
connection_pattern
(
self
,
node
):
return
[[
1
],
[
1
],
[
0
]]
# no connection to height, width
def
infer_shape
(
self
,
node
,
input_shapes
):
# We use self.kshp (that was passed when creating the Op) if possible,
# or fall back to the `shape` input of the node.
# TODO: when there is no subsampling, try to infer the kernel shape
# from the shapes of inputs.
imshp
=
input_shapes
[
0
]
topshp
=
input_shapes
[
1
]
kshp
=
self
.
kshp
[:]
if
self
.
kshp
is
not
None
else
[
None
]
*
4
fallback_kshp
=
[
topshp
[
1
],
imshp
[
1
],
node
.
inputs
[
2
][
0
],
node
.
inputs
[
2
][
1
]]
kshp
=
[
fallback_kshp
[
i
]
if
kshp
[
i
]
is
None
else
kshp
[
i
]
for
i
in
range
(
4
)]
return
[
kshp
]
class
AbstractConv3d_gradWeights
(
AbstractConv_gradWeights
):
"""Gradient wrt. filters for `AbstractConv3d`.
Refer to :func:`BaseAbstractConv <theano.tensor.nnet.abstract_conv.BaseAbstractConv>`
for a more detailed documentation.
:note: You will not want to use this directly, but rely on
Theano's automatic differentiation or graph optimization to
use it as needed.
class
AbstractConv2d_gradInputs
(
BaseAbstractConv2d
):
"""Gradient wrt. inputs for `AbstractConv2d`.
Refer to :func:`BaseAbstractConv2d <theano.tensor.nnet.abstract_conv.BaseAbstractConv2d>`
"""
def
__init__
(
self
,
imshp
=
None
,
kshp
=
None
,
border_mode
=
"valid"
,
subsample
=
(
1
,
1
,
1
),
filter_flip
=
True
,
filter_dilation
=
(
1
,
1
,
1
)):
super
(
AbstractConv3d_gradWeights
,
self
)
.
__init__
(
convdim
=
3
,
imshp
=
imshp
,
kshp
=
kshp
,
border_mode
=
border_mode
,
subsample
=
subsample
,
filter_flip
=
filter_flip
,
filter_dilation
=
filter_dilation
)
def
grad
(
self
,
inp
,
grads
):
bottom
,
top
=
inp
[:
2
]
weights
,
=
grads
d_bottom
=
AbstractConv3d_gradInputs
(
self
.
imshp
,
self
.
kshp
,
self
.
border_mode
,
self
.
subsample
,
self
.
filter_flip
,
self
.
filter_dilation
)(
weights
,
top
,
bottom
.
shape
[
-
3
:])
d_top
=
AbstractConv3d
(
self
.
imshp
,
self
.
kshp
,
self
.
border_mode
,
self
.
subsample
,
self
.
filter_flip
,
self
.
filter_dilation
)(
bottom
,
weights
)
# Make sure that the broadcastable pattern of the inputs is used
# for the gradients, even if the grad opts are not able to infer
# that the dimensions are broadcastable.
# Also make sure that the gradient lives on the same device than
# the corresponding input.
d_bottom
=
patternbroadcast
(
d_bottom
,
bottom
.
broadcastable
)
d_bottom
=
bottom
.
type
.
filter_variable
(
d_bottom
)
d_top
=
patternbroadcast
(
d_top
,
top
.
broadcastable
)
d_top
=
top
.
type
.
filter_variable
(
d_top
)
d_depth_height_width
=
(
theano
.
gradient
.
DisconnectedType
()(),)
return
(
d_bottom
,
d_top
)
+
d_depth_height_width
class
AbstractConv_gradInputs
(
BaseAbstractConv
):
"""Gradient wrt. inputs for `AbstractConv`.
Refer to :func:`BaseAbstractConv <theano.tensor.nnet.abstract_conv.BaseAbstractConv>`
for a more detailed documentation.
:note: You will not want to use this directly, but rely on
...
...
@@ -1060,17 +1553,19 @@ class AbstractConv2d_gradInputs(BaseAbstractConv2d):
"""
def
__init__
(
self
,
convdim
,
imshp
=
None
,
kshp
=
None
,
border_mode
=
"valid"
,
subsample
=
(
1
,
1
)
,
subsample
=
None
,
filter_flip
=
True
,
filter_dilation
=
(
1
,
1
)):
super
(
AbstractConv2d_gradInputs
,
self
)
.
__init__
(
imshp
,
kshp
,
border_mode
,
subsample
,
filter_flip
,
filter_dilation
)
filter_dilation
=
None
):
super
(
AbstractConv_gradInputs
,
self
)
.
__init__
(
convdim
=
convdim
,
imshp
=
imshp
,
kshp
=
kshp
,
border_mode
=
border_mode
,
subsample
=
subsample
,
filter_flip
=
filter_flip
,
filter_dilation
=
filter_dilation
)
# Update shape/height_width
def
make_node
(
self
,
kern
,
topgrad
,
shape
):
...
...
@@ -1083,15 +1578,14 @@ class AbstractConv2d_gradInputs(BaseAbstractConv2d):
broadcastable
=
topgrad
.
broadcastable
)
topgrad
=
gtype
.
filter_variable
(
topgrad
)
if
kern
.
type
.
ndim
!=
4
:
raise
TypeError
(
'kern must be
4D tensor'
)
if
topgrad
.
type
.
ndim
!=
4
:
raise
TypeError
(
'topgrad must be
4D tensor'
)
if
kern
.
type
.
ndim
!=
2
+
self
.
convdim
:
raise
TypeError
(
'kern must be
%
dD tensor'
%
(
2
+
self
.
convdim
)
)
if
topgrad
.
type
.
ndim
!=
2
+
self
.
convdim
:
raise
TypeError
(
'topgrad must be
%
dD tensor'
%
(
2
+
self
.
convdim
)
)
shape
=
as_tensor_variable
(
shape
)
broadcastable
=
[
topgrad
.
type
.
broadcastable
[
0
],
kern
.
type
.
broadcastable
[
1
],
False
,
False
]
kern
.
type
.
broadcastable
[
1
]]
+
([
False
]
*
self
.
convdim
)
output
=
kern
.
type
.
clone
(
broadcastable
=
broadcastable
)()
return
Apply
(
self
,
[
kern
,
topgrad
,
shape
],
[
output
])
...
...
@@ -1106,35 +1600,86 @@ class AbstractConv2d_gradInputs(BaseAbstractConv2d):
mode
in
(
'valid'
,
'full'
,
'half'
)):
raise
ValueError
(
'invalid border_mode {}, which must be either '
'"valid", "full", "half", an integer or a
pair
of'
'"valid", "full", "half", an integer or a
tuple
of'
' integers'
.
format
(
mode
))
dil_kernshp
=
((
kern
.
shape
[
2
]
-
1
)
*
self
.
filter_dilation
[
0
]
+
1
,
(
kern
.
shape
[
3
]
-
1
)
*
self
.
filter_dilation
[
1
]
+
1
)
pad
_h
,
pad_w
=
0
,
0
dil_kernshp
=
tuple
((
kern
.
shape
[
i
+
2
]
-
1
)
*
self
.
filter_dilation
[
i
]
+
1
for
i
in
range
(
self
.
convdim
)
)
pad
=
(
0
,)
*
self
.
convdim
if
mode
==
"full"
:
pad
_h
,
pad_w
=
(
dil_kernshp
[
0
]
-
1
,
dil_kernshp
[
1
]
-
1
)
pad
=
tuple
(
dil_kernshp
[
i
]
-
1
for
i
in
range
(
self
.
convdim
)
)
elif
mode
==
"half"
:
pad
_h
,
pad_w
=
(
dil_kernshp
[
0
]
//
2
,
dil_kernshp
[
1
]
//
2
)
pad
=
tuple
(
dil_kernshp
[
i
]
//
2
for
i
in
range
(
self
.
convdim
)
)
elif
isinstance
(
mode
,
tuple
):
pad
_h
,
pad_w
=
map
(
int
,
self
.
border_mode
)
if
self
.
subsample
[
0
]
>
1
or
self
.
subsample
[
1
]
>
1
:
new_shape
=
(
topgrad
.
shape
[
0
],
topgrad
.
shape
[
1
],
shape
[
0
]
+
2
*
pad_h
-
dil_kernshp
[
0
]
+
1
,
shape
[
1
]
+
2
*
pad_w
-
dil_kernshp
[
1
]
+
1
)
pad
=
tuple
(
mode
[
i
]
for
i
in
range
(
self
.
convdim
)
)
if
any
(
self
.
subsample
[
i
]
>
1
for
i
in
range
(
self
.
convdim
))
:
new_shape
=
(
(
topgrad
.
shape
[
0
],
topgrad
.
shape
[
1
])
+
tuple
(
shape
[
i
]
+
2
*
pad
[
i
]
-
dil_kernshp
[
i
]
+
1
for
i
in
range
(
self
.
convdim
))
)
new_topgrad
=
numpy
.
zeros
((
new_shape
),
dtype
=
topgrad
.
dtype
)
new_topgrad
[:,
:,
::
self
.
subsample
[
0
],
::
self
.
subsample
[
1
]]
=
topgrad
new_topgrad
[(
slice
(
None
),
slice
(
None
))
+
tuple
(
slice
(
None
,
None
,
self
.
subsample
[
i
])
for
i
in
range
(
self
.
convdim
))]
=
topgrad
topgrad
=
new_topgrad
kern
=
kern
.
transpose
(
1
,
0
,
2
,
3
)
axes_order
=
(
1
,
0
)
+
tuple
(
range
(
2
,
self
.
convdim
+
2
))
flip_filters
=
((
slice
(
None
),
slice
(
None
))
+
(
slice
(
None
,
None
,
-
1
),)
*
self
.
convdim
)
kern
=
kern
.
transpose
(
axes_order
)
if
self
.
filter_flip
:
topgrad
=
topgrad
[
:,
:,
::
-
1
,
::
-
1
]
img
=
self
.
conv
2d
(
topgrad
,
kern
,
mode
=
"full"
,
dilation
=
self
.
filter_dilation
)
topgrad
=
topgrad
[
flip_filters
]
img
=
self
.
conv
(
topgrad
,
kern
,
mode
=
"full"
,
dilation
=
self
.
filter_dilation
)
if
self
.
filter_flip
:
img
=
img
[:,
:,
::
-
1
,
::
-
1
]
if
pad_h
>
0
or
pad_w
>
0
:
img
=
img
[:,
:,
pad_h
:
img
.
shape
[
2
]
-
pad_h
,
pad_w
:
img
.
shape
[
3
]
-
pad_w
]
img
=
img
[
flip_filters
]
if
any
(
p
>
0
for
p
in
pad
):
img
=
img
[(
slice
(
None
),
slice
(
None
))
+
tuple
(
slice
(
pad
[
i
],
img
.
shape
[
i
+
2
]
-
pad
[
i
])
for
i
in
range
(
self
.
convdim
))]
o
[
0
]
=
node
.
outputs
[
0
]
.
type
.
filter
(
img
)
def
connection_pattern
(
self
,
node
):
return
[[
1
],
[
1
],
[
0
]]
# no connection to height, width
def
infer_shape
(
self
,
node
,
input_shapes
):
# We use self.imshp (that was passed when creating the Op) if possible,
# or fall back to the `shape` input of the node.
# TODO: when there is no subsampling, try to infer the image shape
# from the shapes of inputs.
kshp
=
input_shapes
[
0
]
topshp
=
input_shapes
[
1
]
imshp
=
self
.
imshp
[:]
if
self
.
imshp
is
not
None
else
[
None
]
*
(
2
+
self
.
convdim
)
fallback_imshp
=
([
topshp
[
0
],
kshp
[
1
]]
+
[
node
.
inputs
[
2
][
i
]
for
i
in
range
(
self
.
convdim
)])
imshp
=
[
fallback_imshp
[
i
]
if
imshp
[
i
]
is
None
else
imshp
[
i
]
for
i
in
range
(
2
+
self
.
convdim
)]
return
[
imshp
]
class
AbstractConv2d_gradInputs
(
AbstractConv_gradInputs
):
"""Gradient wrt. inputs for `AbstractConv2d`.
Refer to :func:`BaseAbstractConv <theano.tensor.nnet.abstract_conv.BaseAbstractConv>`
for a more detailed documentation.
:note: You will not want to use this directly, but rely on
Theano's automatic differentiation or graph optimization to
use it as needed.
"""
def
__init__
(
self
,
imshp
=
None
,
kshp
=
None
,
border_mode
=
"valid"
,
subsample
=
(
1
,
1
),
filter_flip
=
True
,
filter_dilation
=
(
1
,
1
)):
super
(
AbstractConv2d_gradInputs
,
self
)
.
__init__
(
convdim
=
2
,
imshp
=
imshp
,
kshp
=
kshp
,
border_mode
=
border_mode
,
subsample
=
subsample
,
filter_flip
=
filter_flip
,
filter_dilation
=
filter_dilation
)
def
grad
(
self
,
inp
,
grads
):
weights
,
top
=
inp
[:
2
]
bottom
,
=
grads
...
...
@@ -1162,19 +1707,55 @@ class AbstractConv2d_gradInputs(BaseAbstractConv2d):
d_height_width
=
(
theano
.
gradient
.
DisconnectedType
()(),)
return
(
d_weights
,
d_top
)
+
d_height_width
def
connection_pattern
(
self
,
node
):
return
[[
1
],
[
1
],
[
0
]]
# no connection to height, width
def
infer_shape
(
self
,
node
,
input_shapes
):
# We use self.imshp (that was passed when creating the Op) if possible,
# or fall back to the `shape` input of the node.
# TODO: when there is no subsampling, try to infer the image shape
# from the shapes of inputs.
kshp
=
input_shapes
[
0
]
topshp
=
input_shapes
[
1
]
imshp
=
self
.
imshp
[:]
if
self
.
imshp
is
not
None
else
[
None
]
*
4
fallback_imshp
=
[
topshp
[
0
],
kshp
[
1
],
node
.
inputs
[
2
][
0
],
node
.
inputs
[
2
][
1
]]
imshp
=
[
fallback_imshp
[
i
]
if
imshp
[
i
]
is
None
else
imshp
[
i
]
for
i
in
range
(
4
)]
return
[
imshp
]
class
AbstractConv3d_gradInputs
(
AbstractConv_gradInputs
):
"""Gradient wrt. inputs for `AbstractConv3d`.
Refer to :func:`BaseAbstractConv <theano.tensor.nnet.abstract_conv.BaseAbstractConv>`
for a more detailed documentation.
:note: You will not want to use this directly, but rely on
Theano's automatic differentiation or graph optimization to
use it as needed.
"""
def
__init__
(
self
,
imshp
=
None
,
kshp
=
None
,
border_mode
=
"valid"
,
subsample
=
(
1
,
1
,
1
),
filter_flip
=
True
,
filter_dilation
=
(
1
,
1
,
1
)):
super
(
AbstractConv3d_gradInputs
,
self
)
.
__init__
(
convdim
=
3
,
imshp
=
imshp
,
kshp
=
kshp
,
border_mode
=
border_mode
,
subsample
=
subsample
,
filter_flip
=
filter_flip
,
filter_dilation
=
filter_dilation
)
def
grad
(
self
,
inp
,
grads
):
weights
,
top
=
inp
[:
2
]
bottom
,
=
grads
d_weights
=
AbstractConv3d_gradWeights
(
self
.
imshp
,
self
.
kshp
,
self
.
border_mode
,
self
.
subsample
,
self
.
filter_flip
,
self
.
filter_dilation
)(
bottom
,
top
,
weights
.
shape
[
-
3
:])
d_top
=
AbstractConv3d
(
self
.
imshp
,
self
.
kshp
,
self
.
border_mode
,
self
.
subsample
,
self
.
filter_flip
,
self
.
filter_dilation
)(
bottom
,
weights
)
# Make sure that the broadcastable pattern of the inputs is used
# for the gradients, even if the grad opts are not able to infer
# that the dimensions are broadcastable.
# Also make sure that the gradient lives on the same device than
# the corresponding input.
d_weights
=
patternbroadcast
(
d_weights
,
weights
.
broadcastable
)
d_weights
=
weights
.
type
.
filter_variable
(
d_weights
)
d_top
=
patternbroadcast
(
d_top
,
top
.
broadcastable
)
d_top
=
top
.
type
.
filter_variable
(
d_top
)
d_depth_height_width
=
(
theano
.
gradient
.
DisconnectedType
()(),)
return
(
d_weights
,
d_top
)
+
d_depth_height_width
theano/tensor/nnet/opt.py
浏览文件 @
289c3bd4
...
...
@@ -18,6 +18,9 @@ from theano.tensor.nnet.blocksparse import (
from
theano.tensor.nnet.abstract_conv
import
(
AbstractConv2d
,
AbstractConv2d_gradWeights
,
AbstractConv2d_gradInputs
)
from
theano.tensor.nnet.abstract_conv
import
(
AbstractConv3d
,
AbstractConv3d_gradWeights
,
AbstractConv3d_gradInputs
)
from
theano.tensor.nnet.abstract_conv
import
get_conv_output_shape
from
theano.tensor.opt
import
register_specialize_device
from
theano.tensor
import
TensorType
...
...
@@ -25,6 +28,7 @@ from theano.tensor import opt
# Cpu implementation
from
theano.tensor.nnet.conv
import
conv2d
,
ConvOp
from
theano.tensor.nnet.Conv3D
import
conv3D
from
theano.tensor.nnet.ConvGrad3D
import
convGrad3D
from
theano.tensor.nnet.ConvTransp3D
import
convTransp3D
...
...
@@ -159,6 +163,37 @@ def local_conv2d_cpu(node):
return
[
rval
]
@local_optimizer
([
AbstractConv3d
])
def
local_conv3d_cpu
(
node
):
if
not
isinstance
(
node
.
op
,
AbstractConv3d
):
return
None
img
,
kern
=
node
.
inputs
if
((
not
isinstance
(
img
.
type
,
TensorType
)
or
not
isinstance
(
kern
.
type
,
TensorType
))):
return
None
if
node
.
op
.
border_mode
not
in
[
'valid'
,
(
0
,
0
,
0
)]:
return
None
if
node
.
op
.
filter_dilation
!=
(
1
,
1
,
1
):
return
None
bias
=
theano
.
tensor
.
zeros_like
(
kern
[:,
0
,
0
,
0
,
0
])
# need to flip the kernel if necessary (conv3D does not flip)
if
node
.
op
.
filter_flip
:
kern
=
kern
[:,
:,
::
-
1
,
::
-
1
,
::
-
1
]
# conv3D expects shape (batch, row, column, time, channel)
img
=
img
.
dimshuffle
(
0
,
2
,
3
,
4
,
1
)
kern
=
kern
.
dimshuffle
(
0
,
2
,
3
,
4
,
1
)
rval
=
conv3D
(
img
,
kern
,
bias
,
node
.
op
.
subsample
)
copy_stack_trace
(
node
.
outputs
[
0
],
rval
)
rval
=
rval
.
dimshuffle
(
0
,
4
,
1
,
2
,
3
)
return
[
rval
]
@local_optimizer
([
AbstractConv2d_gradWeights
])
def
local_conv2d_gradweight_cpu
(
node
):
if
not
isinstance
(
node
.
op
,
AbstractConv2d_gradWeights
):
...
...
@@ -277,6 +312,39 @@ def local_conv2d_gradweight_cpu(node):
return
[
res
]
@local_optimizer
([
AbstractConv3d_gradWeights
])
def
local_conv3d_gradweight_cpu
(
node
):
if
not
isinstance
(
node
.
op
,
AbstractConv3d_gradWeights
):
return
None
img
,
topgrad
,
shape
=
node
.
inputs
if
((
not
isinstance
(
img
.
type
,
TensorType
)
or
not
isinstance
(
topgrad
.
type
,
TensorType
))):
return
None
if
node
.
op
.
border_mode
not
in
[
'valid'
,
(
0
,
0
,
0
)]:
return
None
if
node
.
op
.
filter_dilation
!=
(
1
,
1
,
1
):
return
None
# conv3D expects shape (batch, row, column, time, channel)
img
=
img
.
dimshuffle
(
0
,
2
,
3
,
4
,
1
)
topgrad
=
topgrad
.
dimshuffle
(
0
,
2
,
3
,
4
,
1
)
W_shape
=
(
topgrad
.
shape
[
4
],
shape
[
0
],
shape
[
1
],
shape
[
2
],
img
.
shape
[
4
])
rval
=
convGrad3D
(
img
,
node
.
op
.
subsample
,
W_shape
,
topgrad
)
copy_stack_trace
(
node
.
outputs
[
0
],
rval
)
rval
=
rval
.
dimshuffle
(
0
,
4
,
1
,
2
,
3
)
# need to flip the kernel if necessary (conv3D does not flip)
if
node
.
op
.
filter_flip
:
rval
=
rval
[:,
:,
::
-
1
,
::
-
1
,
::
-
1
]
rval
=
theano
.
tensor
.
patternbroadcast
(
rval
,
node
.
outputs
[
0
]
.
broadcastable
)
return
[
rval
]
@local_optimizer
([
AbstractConv2d_gradInputs
])
def
local_conv2d_gradinputs_cpu
(
node
):
if
not
isinstance
(
node
.
op
,
AbstractConv2d_gradInputs
):
...
...
@@ -366,6 +434,38 @@ def local_conv2d_gradinputs_cpu(node):
return
[
din
]
@local_optimizer
([
AbstractConv3d_gradInputs
])
def
local_conv3d_gradinputs_cpu
(
node
):
if
not
isinstance
(
node
.
op
,
AbstractConv3d_gradInputs
):
return
None
kern
,
topgrad
,
shape
=
node
.
inputs
if
((
not
isinstance
(
kern
.
type
,
TensorType
)
or
not
isinstance
(
topgrad
.
type
,
TensorType
))):
return
None
if
node
.
op
.
border_mode
not
in
[
'valid'
,
(
0
,
0
,
0
)]:
return
None
if
node
.
op
.
filter_dilation
!=
(
1
,
1
,
1
):
return
None
# need to flip the kernel if necessary (conv3D does not flip)
if
node
.
op
.
filter_flip
:
kern
=
kern
[:,
:,
::
-
1
,
::
-
1
,
::
-
1
]
# conv3D expects shape (batch, row, column, time, channel)
kern
=
kern
.
dimshuffle
(
0
,
2
,
3
,
4
,
1
)
topgrad
=
topgrad
.
dimshuffle
(
0
,
2
,
3
,
4
,
1
)
bias
=
theano
.
tensor
.
zeros_like
(
kern
[
0
,
0
,
0
,
0
,
:])
rval
=
convTransp3D
(
kern
,
bias
,
node
.
op
.
subsample
,
topgrad
,
shape
)
copy_stack_trace
(
node
.
outputs
[
0
],
rval
)
rval
=
rval
.
dimshuffle
(
0
,
4
,
1
,
2
,
3
)
rval
=
theano
.
tensor
.
patternbroadcast
(
rval
,
node
.
outputs
[
0
]
.
broadcastable
)
return
[
rval
]
# Register Cpu Optmization
conv_groupopt
=
theano
.
gof
.
optdb
.
LocalGroupDB
()
conv_groupopt
.
__name__
=
"conv_opts"
...
...
@@ -390,16 +490,30 @@ conv_groupopt.register('local_conv2d_gradweight_cpu',
conv_groupopt
.
register
(
'local_conv2d_gradinputs_cpu'
,
local_conv2d_gradinputs_cpu
,
40
,
'fast_compile'
,
'fast_run'
)
conv_groupopt
.
register
(
'local_conv3d_cpu'
,
local_conv3d_cpu
,
40
,
'fast_compile'
,
'fast_run'
)
conv_groupopt
.
register
(
'local_conv3d_gradweight_cpu'
,
local_conv3d_gradweight_cpu
,
40
,
'fast_compile'
,
'fast_run'
)
conv_groupopt
.
register
(
'local_conv3d_gradinputs_cpu'
,
local_conv3d_gradinputs_cpu
,
40
,
'fast_compile'
,
'fast_run'
)
# Verify that no AbstractConv are present in the graph
@local_optimizer
([
AbstractConv2d
,
AbstractConv2d_gradWeights
,
AbstractConv2d_gradInputs
])
AbstractConv2d_gradInputs
,
AbstractConv3d
,
AbstractConv3d_gradWeights
,
AbstractConv3d_gradInputs
])
def
local_abstractconv_check
(
node
):
if
isinstance
(
node
.
op
,
(
AbstractConv2d
,
AbstractConv2d_gradWeights
,
AbstractConv2d_gradInputs
)):
AbstractConv2d_gradInputs
,
AbstractConv3d
,
AbstractConv3d_gradWeights
,
AbstractConv3d_gradInputs
)):
raise
AssertionError
(
'
%
s Theano optimization failed: there is no implementation '
'available supporting the requested options. Did you exclude '
...
...
theano/tensor/nnet/tests/test_abstract_conv.py
浏览文件 @
289c3bd4
...
...
@@ -20,13 +20,14 @@ from theano.tensor.nnet.abstract_conv import bilinear_upsampling
from
theano.tensor.nnet.conv
import
ConvOp
from
theano.tensor.nnet.corr
import
(
CorrMM
,
CorrMM_gradWeights
,
CorrMM_gradInputs
)
from
theano.tensor.nnet.Conv3D
import
Conv3D
from
theano.tensor.nnet.ConvGrad3D
import
ConvGrad3D
from
theano.tensor.nnet.ConvTransp3D
import
ConvTransp3D
def
conv_corr
(
inputs
,
filters
,
border_mode
=
"valid"
,
subsample
=
(
1
,
1
),
conv_mode
=
'conv'
,
filter_dilation
=
(
1
,
1
)):
def
conv
2d
_corr
(
inputs
,
filters
,
border_mode
=
"valid"
,
subsample
=
(
1
,
1
),
conv_mode
=
'conv'
,
filter_dilation
=
(
1
,
1
)):
if
conv_mode
==
'conv'
:
filters
=
filters
[:,
:,
::
-
1
,
::
-
1
]
return
corr
.
CorrMM
(
border_mode
,
...
...
@@ -34,9 +35,9 @@ def conv_corr(inputs, filters, border_mode="valid",
filter_dilation
)(
inputs
,
filters
)
def
conv_corr_gw
(
inputs
,
topgrad
,
filters_shape
,
border_mode
=
"valid"
,
subsample
=
(
1
,
1
),
conv_mode
=
'conv'
,
filter_dilation
=
(
1
,
1
)):
def
conv
2d
_corr_gw
(
inputs
,
topgrad
,
filters_shape
,
border_mode
=
"valid"
,
subsample
=
(
1
,
1
),
conv_mode
=
'conv'
,
filter_dilation
=
(
1
,
1
)):
rval
=
corr
.
CorrMM_gradWeights
(
border_mode
,
subsample
,
filter_dilation
)(
inputs
,
topgrad
,
...
...
@@ -46,9 +47,9 @@ def conv_corr_gw(inputs, topgrad, filters_shape,
return
rval
def
conv_corr_gi
(
filters
,
topgrad
,
inputs_shape
,
border_mode
=
"valid"
,
subsample
=
(
1
,
1
),
conv_mode
=
'conv'
,
filter_dilation
=
(
1
,
1
)):
def
conv
2d
_corr_gi
(
filters
,
topgrad
,
inputs_shape
,
border_mode
=
"valid"
,
subsample
=
(
1
,
1
),
conv_mode
=
'conv'
,
filter_dilation
=
(
1
,
1
)):
if
conv_mode
==
'conv'
:
filters
=
filters
[:,
:,
::
-
1
,
::
-
1
]
return
corr
.
CorrMM_gradInputs
(
border_mode
,
...
...
@@ -58,6 +59,126 @@ def conv_corr_gi(filters, topgrad, inputs_shape,
inputs_shape
[
2
:])
def
_padding_3d_inputs_to_valid
(
inputs
,
filters_shape
,
border_mode
=
'valid'
):
# pad inputs to have valid convolution
if
border_mode
==
'valid'
:
border_mode
=
(
0
,
0
,
0
)
elif
border_mode
==
'full'
:
border_mode
=
tuple
(
f
-
1
for
f
in
filters_shape
[
2
:])
elif
not
isinstance
(
border_mode
,
tuple
):
raise
ValueError
(
'Unsupported border mode'
,
border_mode
)
if
border_mode
==
(
0
,
0
,
0
):
return
inputs
else
:
# add padding here, because Conv3D only supports valid convolution
i_shp
=
inputs
.
shape
pad
=
border_mode
inputs_padded
=
tensor
.
zeros
(
dtype
=
inputs
.
dtype
,
shape
=
(
i_shp
[
0
],
i_shp
[
1
],
i_shp
[
2
]
+
2
*
pad
[
0
],
i_shp
[
3
]
+
2
*
pad
[
1
],
i_shp
[
4
]
+
2
*
pad
[
2
]))
inputs_padded
=
tensor
.
set_subtensor
(
inputs_padded
[:,
:,
pad
[
0
]:
i_shp
[
2
]
+
pad
[
0
],
pad
[
1
]:
i_shp
[
3
]
+
pad
[
1
],
pad
[
2
]:
i_shp
[
4
]
+
pad
[
2
]],
inputs
)
return
inputs_padded
def
_padding_3d_inputs_shape_to_valid
(
inputs_shape
,
filters_shape
,
border_mode
=
'valid'
):
# pad inputs_shape to have valid convolution
if
border_mode
==
'valid'
:
border_mode
=
(
0
,
0
,
0
)
elif
border_mode
==
'full'
:
border_mode
=
tuple
(
f
-
1
for
f
in
filters_shape
[
2
:])
elif
not
isinstance
(
border_mode
,
tuple
):
raise
ValueError
(
'Unsupported border mode'
,
border_mode
)
return
(
inputs_shape
[
0
],
inputs_shape
[
1
],
inputs_shape
[
2
]
+
2
*
border_mode
[
0
],
inputs_shape
[
3
]
+
2
*
border_mode
[
1
],
inputs_shape
[
4
]
+
2
*
border_mode
[
2
])
def
_crop_3d_padded_inputs
(
inputs
,
filters_shape
,
border_mode
=
'valid'
):
# crop border from padded input
if
border_mode
==
'valid'
:
border_mode
=
(
0
,
0
,
0
)
elif
border_mode
==
'full'
:
border_mode
=
tuple
(
f
-
1
for
f
in
filters_shape
[
2
:])
elif
not
isinstance
(
border_mode
,
tuple
):
raise
ValueError
(
'Unsupported border mode'
,
border_mode
)
if
border_mode
==
(
0
,
0
,
0
):
return
inputs
else
:
# crop
i_shp
=
inputs
.
shape
pad
=
border_mode
return
inputs
[:,
:,
pad
[
0
]:
i_shp
[
2
]
-
pad
[
0
],
pad
[
1
]:
i_shp
[
3
]
-
pad
[
1
],
pad
[
2
]:
i_shp
[
4
]
-
pad
[
2
]]
def
conv3d_corr
(
inputs
,
filters
,
border_mode
=
"valid"
,
subsample
=
(
1
,
1
,
1
),
conv_mode
=
'conv'
,
filter_dilation
=
(
1
,
1
,
1
)):
assert
filter_dilation
==
(
1
,
1
,
1
)
inputs
=
_padding_3d_inputs_to_valid
(
inputs
,
filters
.
shape
,
border_mode
)
if
conv_mode
==
'conv'
:
filters
=
filters
[:,
:,
::
-
1
,
::
-
1
,
::
-
1
]
bias
=
tensor
.
zeros_like
(
filters
[:,
0
,
0
,
0
,
0
])
# Conv3D expects shape (batch, row, column, time, channel)
inputs
=
inputs
.
dimshuffle
(
0
,
2
,
3
,
4
,
1
)
filters
=
filters
.
dimshuffle
(
0
,
2
,
3
,
4
,
1
)
rval
=
Conv3D
()(
inputs
,
filters
,
bias
,
subsample
)
return
rval
.
dimshuffle
(
0
,
4
,
1
,
2
,
3
)
def
conv3d_corr_gw
(
inputs
,
topgrad
,
filters_shape
,
border_mode
=
"valid"
,
subsample
=
(
1
,
1
,
1
),
conv_mode
=
'conv'
,
filter_dilation
=
(
1
,
1
,
1
)):
assert
filter_dilation
==
(
1
,
1
,
1
)
inputs
=
_padding_3d_inputs_to_valid
(
inputs
,
filters_shape
,
border_mode
)
# Conv3D expects shape (batch, row, column, time, channel)
inputs
=
inputs
.
dimshuffle
(
0
,
2
,
3
,
4
,
1
)
topgrad
=
topgrad
.
dimshuffle
(
0
,
2
,
3
,
4
,
1
)
filters_shape
=
tuple
(
filters_shape
[
i
]
for
i
in
(
0
,
2
,
3
,
4
,
1
))
rval
=
ConvGrad3D
()(
inputs
,
subsample
,
filters_shape
,
topgrad
)
rval
=
rval
.
dimshuffle
(
0
,
4
,
1
,
2
,
3
)
if
conv_mode
==
'conv'
:
rval
=
rval
[:,
:,
::
-
1
,
::
-
1
,
::
-
1
]
return
rval
def
conv3d_corr_gi
(
filters
,
topgrad
,
inputs_shape
,
border_mode
=
"valid"
,
subsample
=
(
1
,
1
,
1
),
conv_mode
=
'conv'
,
filter_dilation
=
(
1
,
1
,
1
)):
assert
filter_dilation
==
(
1
,
1
,
1
)
inputs_shape
=
_padding_3d_inputs_shape_to_valid
(
inputs_shape
,
filters
.
shape
,
border_mode
)
if
conv_mode
==
'conv'
:
filters
=
filters
[:,
:,
::
-
1
,
::
-
1
,
::
-
1
]
# Conv3D expects shape (batch, row, column, time, channel)
filters_shuffled
=
filters
.
dimshuffle
(
0
,
2
,
3
,
4
,
1
)
topgrad_shuffled
=
topgrad
.
dimshuffle
(
0
,
2
,
3
,
4
,
1
)
inputs_shape_shuffled
=
tuple
(
inputs_shape
[
i
]
for
i
in
(
0
,
2
,
3
,
4
,
1
))
bias
=
tensor
.
zeros_like
(
filters
[
0
,
:,
0
,
0
,
0
])
rval
=
ConvTransp3D
()(
filters_shuffled
,
bias
,
subsample
,
topgrad_shuffled
,
inputs_shape_shuffled
[
1
:
4
])
rval
=
rval
.
dimshuffle
(
0
,
4
,
1
,
2
,
3
)
return
_crop_3d_padded_inputs
(
rval
,
filters
.
shape
,
border_mode
)
class
TestGetConvOutShape
(
unittest
.
TestCase
):
def
test_basic
(
self
):
image_shape
,
kernel_shape
=
(
3
,
2
,
12
,
9
),
(
4
,
2
,
5
,
6
)
...
...
@@ -95,34 +216,18 @@ class TestGetConvOutShape(unittest.TestCase):
self
.
assertTrue
(
test3_params
==
(
3
,
4
,
20
,
7
,
10
))
self
.
assertTrue
(
test4_params
==
(
3
,
4
,
6
,
4
,
10
))
class
BaseTestConv2d
:
@classmethod
def
setup_class
(
cls
):
if
theano
.
config
.
blas
.
ldflags
==
''
:
raise
SkipTest
(
"BLAS required for reference"
)
cls
.
inputs_shapes
=
[(
8
,
1
,
6
,
6
),
(
8
,
1
,
8
,
8
),
(
2
,
1
,
7
,
7
),
(
6
,
1
,
10
,
11
),
(
2
,
1
,
6
,
5
),
(
1
,
5
,
9
,
9
)]
cls
.
filters_shapes
=
[(
5
,
1
,
2
,
2
),
(
4
,
1
,
3
,
3
),
(
2
,
1
,
3
,
3
),
(
1
,
1
,
2
,
3
),
(
4
,
1
,
1
,
3
),
(
4
,
5
,
3
,
2
)]
cls
.
subsamples
=
[(
1
,
1
),
(
2
,
2
),
(
2
,
4
)]
cls
.
filters_dilations
=
[(
1
,
1
),
(
1
,
2
),
(
2
,
1
)]
cls
.
border_modes
=
[
"valid"
,
"half"
,
"full"
,
(
0
,
0
),
(
1
,
1
),
(
5
,
5
),
(
5
,
2
)]
cls
.
filter_flip
=
[
True
,
False
]
cls
.
provide_shape
=
[
True
,
False
]
cls
.
shared
=
staticmethod
(
theano
.
compile
.
shared
)
class
BaseTestConv
(
object
):
def
get_output_shape
(
self
,
inputs_shape
,
filters_shape
,
subsample
,
border_mode
,
filter_dilation
):
dil_filters
=
((
filters_shape
[
2
]
-
1
)
*
filter_dilation
[
0
]
+
1
,
(
filters_shape
[
3
]
-
1
)
*
filter_dilation
[
1
]
+
1
)
dil_filters
=
tuple
((
s
-
1
)
*
d
+
1
for
s
,
d
in
zip
(
filters_shape
[
2
:]
,
filter_dilation
)
)
if
border_mode
==
"valid"
:
border_mode
=
(
0
,
0
)
border_mode
=
(
0
,
)
*
(
len
(
inputs_shape
)
-
2
)
if
border_mode
==
"half"
:
border_mode
=
(
dil_filters
[
0
]
//
2
,
dil_filters
[
1
]
//
2
)
border_mode
=
tuple
(
d
//
2
for
d
in
dil_filters
)
if
border_mode
==
"full"
:
border_mode
=
(
dil_filters
[
0
]
-
1
,
dil_filters
[
1
]
-
1
)
border_mode
=
tuple
(
d
-
1
for
d
in
dil_filters
)
batch_size
=
inputs_shape
[
0
]
num_filters
=
filters_shape
[
0
]
return
((
batch_size
,
num_filters
,)
+
...
...
@@ -133,14 +238,24 @@ class BaseTestConv2d:
subsample
,
border_mode
,
filter_dilation
)))
def
run_fwd
(
self
,
inputs_shape
,
filters_shape
,
ref
=
conv_corr
,
subsample
=
(
1
,
1
),
verify_grad
=
True
,
mode
=
None
,
def
run_fwd
(
self
,
inputs_shape
,
filters_shape
,
conv_fn
,
conv_op
,
ref
,
subsample
=
None
,
verify_grad
=
True
,
mode
=
None
,
border_mode
=
'valid'
,
filter_flip
=
True
,
provide_shape
=
False
,
target_op
=
None
,
check_trace
=
False
,
filter_dilation
=
(
1
,
1
)):
check_trace
=
False
,
filter_dilation
=
None
):
if
subsample
is
None
:
subsample
=
(
1
,)
*
(
len
(
inputs_shape
)
-
2
)
if
filter_dilation
is
None
:
filter_dilation
=
(
1
,)
*
(
len
(
inputs_shape
)
-
2
)
inputs_val
=
numpy
.
random
.
random
(
inputs_shape
)
.
astype
(
'float32'
)
filters_val
=
numpy
.
random
.
random
(
filters_shape
)
.
astype
(
'float32'
)
# scale down values to prevent rounding errors
inputs_val
/=
10
filters_val
/=
10
inputs
=
self
.
shared
(
inputs_val
)
filters
=
self
.
shared
(
filters_val
)
...
...
@@ -160,13 +275,13 @@ class BaseTestConv2d:
subsample
=
subsample
,
conv_mode
=
conv_mode
,
filter_dilation
=
filter_dilation
)
c
=
conv
.
conv2d
(
inputs
,
filters
,
border_mode
=
border_mode
,
subsample
=
subsample
,
filter_flip
=
filter_flip
,
input_shape
=
imshp
,
filter_shape
=
kshp
,
filter_dilation
=
filter_dilation
)
c
=
conv
_fn
(
inputs
,
filters
,
border_mode
=
border_mode
,
subsample
=
subsample
,
filter_flip
=
filter_flip
,
input_shape
=
imshp
,
filter_shape
=
kshp
,
filter_dilation
=
filter_dilation
)
f_ref
=
theano
.
function
([],
c_ref
,
mode
=
'FAST_RUN'
)
f
=
theano
.
function
([],
c
,
mode
=
mode
)
...
...
@@ -181,19 +296,24 @@ class BaseTestConv2d:
res
=
numpy
.
array
(
f
())
utt
.
assert_allclose
(
res_ref
,
res
)
if
verify_grad
:
utt
.
verify_grad
(
conv
.
AbstractConv2d
(
border_mode
=
border_mode
,
imshp
=
imshp
,
kshp
=
kshp
,
subsample
=
subsample
,
filter_dilation
=
filter_dilation
),
utt
.
verify_grad
(
conv
_op
(
border_mode
=
border_mode
,
imshp
=
imshp
,
kshp
=
kshp
,
subsample
=
subsample
,
filter_dilation
=
filter_dilation
),
[
inputs_val
,
filters_val
],
mode
=
mode
)
def
run_gradweight
(
self
,
inputs_shape
,
filters_shape
,
output_shape
,
ref
=
conv_corr_gw
,
subsample
=
(
1
,
1
)
,
gradWeights_fn
,
ref
,
subsample
=
None
,
filter_flip
=
True
,
verify_grad
=
True
,
mode
=
None
,
border_mode
=
'valid'
,
provide_shape
=
False
,
target_op
=
None
,
check_trace
=
False
,
filter_dilation
=
(
1
,
1
)):
filter_dilation
=
None
):
if
subsample
is
None
:
subsample
=
(
1
,)
*
(
len
(
inputs_shape
)
-
2
)
if
filter_dilation
is
None
:
filter_dilation
=
(
1
,)
*
(
len
(
inputs_shape
)
-
2
)
inputs_val
=
numpy
.
random
.
random
(
inputs_shape
)
.
astype
(
'float32'
)
output_val
=
numpy
.
random
.
random
(
output_shape
)
.
astype
(
'float32'
)
...
...
@@ -210,12 +330,12 @@ class BaseTestConv2d:
conv_mode
=
'conv'
else
:
conv_mode
=
'cross'
c
=
conv
.
AbstractConv2d_gradWeights
(
border_mode
=
border_mode
,
filter_flip
=
filter_flip
,
subsample
=
subsample
,
imshp
=
imshp
,
kshp
=
kshp
,
filter_dilation
=
filter_dilation
)
c
=
c
(
inputs
,
output
,
filters_shape
[
-
2
:])
c
=
gradWeights_fn
(
border_mode
=
border_mode
,
filter_flip
=
filter_flip
,
subsample
=
subsample
,
imshp
=
imshp
,
kshp
=
kshp
,
filter_dilation
=
filter_dilation
)
c
=
c
(
inputs
,
output
,
filters_shape
[
2
:])
c_ref
=
ref
(
inputs
,
output
,
filters_shape
,
border_mode
=
border_mode
,
...
...
@@ -235,22 +355,28 @@ class BaseTestConv2d:
res
=
numpy
.
array
(
f
())
utt
.
assert_allclose
(
res_ref
,
res
)
def
abstract_conv
2d
_gradweight
(
inputs_val
,
output_val
):
conv_op
=
conv
.
AbstractConv2d_gradWeights
(
border_mode
=
border_mode
,
subsample
=
subsample
,
filter_dilation
=
filter_dilation
)
return
conv_op
(
inputs_val
,
output_val
,
filters_shape
[
-
2
:])
def
abstract_conv_gradweight
(
inputs_val
,
output_val
):
conv_op
=
gradWeights_fn
(
border_mode
=
border_mode
,
subsample
=
subsample
,
filter_dilation
=
filter_dilation
)
return
conv_op
(
inputs_val
,
output_val
,
filters_shape
[
2
:])
if
verify_grad
:
utt
.
verify_grad
(
abstract_conv
2d
_gradweight
,
utt
.
verify_grad
(
abstract_conv_gradweight
,
[
inputs_val
,
output_val
],
mode
=
mode
,
eps
=
1
)
def
run_gradinput
(
self
,
inputs_shape
,
filters_shape
,
output_shape
,
ref
=
conv_corr_gi
,
subsample
=
(
1
,
1
),
filter_flip
=
True
,
gradInputs_fn
,
ref
,
subsample
=
None
,
filter_flip
=
True
,
verify_grad
=
True
,
mode
=
None
,
border_mode
=
'valid'
,
provide_shape
=
False
,
target_op
=
None
,
check_trace
=
False
,
filter_dilation
=
(
1
,
1
)):
check_trace
=
False
,
filter_dilation
=
None
):
if
subsample
is
None
:
subsample
=
(
1
,)
*
(
len
(
inputs_shape
)
-
2
)
if
filter_dilation
is
None
:
filter_dilation
=
(
1
,)
*
(
len
(
inputs_shape
)
-
2
)
output_val
=
numpy
.
random
.
random
(
output_shape
)
.
astype
(
'float32'
)
filters_val
=
numpy
.
random
.
random
(
filters_shape
)
.
astype
(
'float32'
)
output
=
self
.
shared
(
output_val
)
...
...
@@ -266,12 +392,12 @@ class BaseTestConv2d:
conv_mode
=
'conv'
else
:
conv_mode
=
'cross'
c
=
conv
.
AbstractConv2d_gradInputs
(
border_mode
=
border_mode
,
subsample
=
subsample
,
filter_flip
=
filter_flip
,
imshp
=
imshp
,
kshp
=
kshp
,
filter_dilation
=
filter_dilation
)
c
=
c
(
filters
,
output
,
inputs_shape
[
-
2
:])
c
=
gradInputs_fn
(
border_mode
=
border_mode
,
subsample
=
subsample
,
filter_flip
=
filter_flip
,
imshp
=
imshp
,
kshp
=
kshp
,
filter_dilation
=
filter_dilation
)
c
=
c
(
filters
,
output
,
inputs_shape
[
2
:])
c_ref
=
ref
(
filters
,
output
,
inputs_shape
,
border_mode
=
border_mode
,
subsample
=
subsample
,
conv_mode
=
conv_mode
,
filter_dilation
=
filter_dilation
)
...
...
@@ -288,24 +414,24 @@ class BaseTestConv2d:
res
=
numpy
.
array
(
f
())
utt
.
assert_allclose
(
res_ref
,
res
)
def
abstract_conv
2d
_gradinputs
(
filters_val
,
output_val
):
conv_op
=
conv
.
AbstractConv2d_gradInputs
(
border_mode
=
border_mode
,
subsample
=
subsample
,
filter_dilation
=
filter_dilation
)
return
conv_op
(
filters_val
,
output_val
,
inputs_shape
[
-
2
:])
def
abstract_conv_gradinputs
(
filters_val
,
output_val
):
conv_op
=
gradInputs_fn
(
border_mode
=
border_mode
,
subsample
=
subsample
,
filter_dilation
=
filter_dilation
)
return
conv_op
(
filters_val
,
output_val
,
inputs_shape
[
2
:])
if
verify_grad
:
utt
.
verify_grad
(
abstract_conv
2d
_gradinputs
,
utt
.
verify_grad
(
abstract_conv_gradinputs
,
[
filters_val
,
output_val
],
mode
=
mode
,
eps
=
1
)
def
test_all
(
self
):
if
type
(
self
)
is
BaseTestConv
2d
:
if
type
(
self
)
is
BaseTestConv
:
raise
SkipTest
(
"base class"
)
ds
=
[
1
,
1
]
db
=
(
0
,
0
)
dflip
=
True
in
self
.
filter_flip
dprovide_shape
=
True
in
self
.
provide_shape
ds
=
self
.
default_subsamples
db
=
self
.
default_border_mode
dflip
=
self
.
default_
filter_flip
dprovide_shape
=
self
.
default_
provide_shape
for
(
i
,
f
)
in
zip
(
self
.
inputs_shapes
,
self
.
filters_shapes
):
for
provide_shape
in
self
.
provide_shape
:
yield
(
self
.
tcase
,
i
,
f
,
ds
,
db
,
dflip
,
provide_shape
)
...
...
@@ -318,6 +444,57 @@ class BaseTestConv2d:
yield
(
self
.
tcase
,
i
,
f
,
ds
,
db
,
flip
,
dprovide_shape
)
class
BaseTestConv2d
(
BaseTestConv
):
@classmethod
def
setup_class
(
cls
):
if
theano
.
config
.
blas
.
ldflags
==
''
:
raise
SkipTest
(
"BLAS required for reference"
)
cls
.
inputs_shapes
=
[(
8
,
1
,
6
,
6
),
(
8
,
1
,
8
,
8
),
(
2
,
1
,
7
,
7
),
(
6
,
1
,
10
,
11
),
(
2
,
1
,
6
,
5
),
(
1
,
5
,
9
,
9
)]
cls
.
filters_shapes
=
[(
5
,
1
,
2
,
2
),
(
4
,
1
,
3
,
3
),
(
2
,
1
,
3
,
3
),
(
1
,
1
,
2
,
3
),
(
4
,
1
,
1
,
3
),
(
4
,
5
,
3
,
2
)]
cls
.
subsamples
=
[(
1
,
1
),
(
2
,
2
),
(
2
,
4
)]
cls
.
default_subsamples
=
(
1
,
1
)
cls
.
filters_dilations
=
[(
1
,
1
),
(
1
,
2
),
(
2
,
1
)]
cls
.
border_modes
=
[
"valid"
,
"half"
,
"full"
,
(
0
,
0
),
(
1
,
1
),
(
5
,
5
),
(
5
,
2
)]
cls
.
default_border_mode
=
(
0
,
0
)
cls
.
filter_flip
=
[
True
,
False
]
cls
.
default_filter_flip
=
True
cls
.
provide_shape
=
[
True
,
False
]
cls
.
default_provide_shape
=
True
cls
.
shared
=
staticmethod
(
theano
.
compile
.
shared
)
def
run_fwd
(
self
,
inputs_shape
,
filters_shape
,
conv_fn
=
conv
.
conv2d
,
conv_op
=
conv
.
AbstractConv2d
,
ref
=
conv2d_corr
,
**
kwargs
):
super
(
BaseTestConv2d
,
self
)
.
run_fwd
(
inputs_shape
=
inputs_shape
,
filters_shape
=
filters_shape
,
conv_fn
=
conv_fn
,
conv_op
=
conv_op
,
ref
=
ref
,
**
kwargs
)
def
run_gradweight
(
self
,
inputs_shape
,
filters_shape
,
output_shape
,
gradWeights_fn
=
conv
.
AbstractConv2d_gradWeights
,
ref
=
conv2d_corr_gw
,
**
kwargs
):
super
(
BaseTestConv2d
,
self
)
.
run_gradweight
(
inputs_shape
=
inputs_shape
,
filters_shape
=
filters_shape
,
output_shape
=
output_shape
,
gradWeights_fn
=
gradWeights_fn
,
ref
=
ref
,
**
kwargs
)
def
run_gradinput
(
self
,
inputs_shape
,
filters_shape
,
output_shape
,
gradInputs_fn
=
conv
.
AbstractConv2d_gradInputs
,
ref
=
conv2d_corr_gi
,
**
kwargs
):
super
(
BaseTestConv2d
,
self
)
.
run_gradinput
(
inputs_shape
=
inputs_shape
,
filters_shape
=
filters_shape
,
output_shape
=
output_shape
,
gradInputs_fn
=
gradInputs_fn
,
ref
=
ref
,
**
kwargs
)
class
TestCorrConv2d
(
BaseTestConv2d
):
@classmethod
def
setup_class
(
cls
):
...
...
@@ -500,6 +677,192 @@ class TestCpuConv2d(BaseTestConv2d):
filter_dilation
=
fd
)
class
BaseTestConv3d
(
BaseTestConv
):
@classmethod
def
setup_class
(
cls
):
if
theano
.
config
.
blas
.
ldflags
==
''
:
raise
SkipTest
(
"BLAS required for reference"
)
cls
.
inputs_shapes
=
[(
2
,
1
,
6
,
6
,
6
),
(
2
,
2
,
7
,
5
,
6
)]
cls
.
filters_shapes
=
[(
3
,
1
,
2
,
2
,
2
),
(
1
,
2
,
2
,
3
,
1
)]
cls
.
subsamples
=
[(
1
,
1
,
1
),
(
2
,
2
,
2
),
(
1
,
2
,
3
)]
cls
.
default_subsamples
=
(
1
,
1
,
1
)
cls
.
filters_dilations
=
[(
1
,
1
,
1
),
(
1
,
2
,
1
),
(
2
,
1
,
2
)]
cls
.
border_modes
=
[
"valid"
,
"full"
,
(
0
,
0
,
0
),
(
2
,
2
,
3
)]
cls
.
default_border_mode
=
(
0
,
0
,
0
)
cls
.
filter_flip
=
[
True
,
False
]
cls
.
default_filter_flip
=
True
cls
.
provide_shape
=
[
True
,
False
]
cls
.
default_provide_shape
=
True
cls
.
shared
=
staticmethod
(
theano
.
compile
.
shared
)
def
run_fwd
(
self
,
inputs_shape
,
filters_shape
,
conv_fn
=
conv
.
conv3d
,
conv_op
=
conv
.
AbstractConv3d
,
ref
=
conv3d_corr
,
**
kwargs
):
super
(
BaseTestConv3d
,
self
)
.
run_fwd
(
inputs_shape
=
inputs_shape
,
filters_shape
=
filters_shape
,
conv_fn
=
conv_fn
,
conv_op
=
conv_op
,
ref
=
ref
,
**
kwargs
)
def
run_gradweight
(
self
,
inputs_shape
,
filters_shape
,
output_shape
,
gradWeights_fn
=
conv
.
AbstractConv3d_gradWeights
,
ref
=
conv3d_corr_gw
,
**
kwargs
):
super
(
BaseTestConv3d
,
self
)
.
run_gradweight
(
inputs_shape
=
inputs_shape
,
filters_shape
=
filters_shape
,
output_shape
=
output_shape
,
gradWeights_fn
=
gradWeights_fn
,
ref
=
ref
,
**
kwargs
)
def
run_gradinput
(
self
,
inputs_shape
,
filters_shape
,
output_shape
,
gradInputs_fn
=
conv
.
AbstractConv3d_gradInputs
,
ref
=
conv3d_corr_gi
,
**
kwargs
):
super
(
BaseTestConv3d
,
self
)
.
run_gradinput
(
inputs_shape
=
inputs_shape
,
filters_shape
=
filters_shape
,
output_shape
=
output_shape
,
gradInputs_fn
=
gradInputs_fn
,
ref
=
ref
,
**
kwargs
)
class
TestCorrConv3d
(
BaseTestConv3d
):
@classmethod
def
setup_class
(
cls
):
if
theano
.
config
.
blas
.
ldflags
==
""
:
raise
SkipTest
()
BaseTestConv3d
.
setup_class
()
def
tcase
(
self
,
i
,
f
,
s
,
b
,
flip
,
provide_shape
,
fd
=
(
1
,
1
,
1
)):
if
b
not
in
((
0
,
0
,
0
),
'valid'
):
raise
SkipTest
(
"Only border_mode valid is implemented for basic cpu Conv3D."
)
if
fd
!=
(
1
,
1
,
1
):
raise
SkipTest
(
"No dilation implementation for basic cpu Conv3D."
)
o
=
self
.
get_output_shape
(
i
,
f
,
s
,
b
,
fd
)
if
(
not
theano
.
config
.
blas
.
ldflags
or
not
theano
.
config
.
cxx
or
theano
.
config
.
mode
==
"FAST_COMPILE"
):
raise
SkipTest
(
"Need blas to test conv3d"
)
self
.
run_fwd
(
inputs_shape
=
i
,
filters_shape
=
f
,
subsample
=
s
,
verify_grad
=
True
,
provide_shape
=
provide_shape
,
border_mode
=
b
,
filter_flip
=
flip
,
target_op
=
Conv3D
,
check_trace
=
True
,
filter_dilation
=
fd
)
self
.
run_gradweight
(
inputs_shape
=
i
,
filters_shape
=
f
,
output_shape
=
o
,
subsample
=
s
,
verify_grad
=
True
,
provide_shape
=
provide_shape
,
border_mode
=
b
,
filter_flip
=
flip
,
target_op
=
ConvGrad3D
,
check_trace
=
True
,
filter_dilation
=
fd
)
self
.
run_gradinput
(
inputs_shape
=
i
,
filters_shape
=
f
,
output_shape
=
o
,
subsample
=
s
,
verify_grad
=
True
,
provide_shape
=
provide_shape
,
border_mode
=
b
,
filter_flip
=
flip
,
target_op
=
ConvTransp3D
,
check_trace
=
True
,
filter_dilation
=
fd
)
class
TestCpuConv3d
(
BaseTestConv3d
):
@classmethod
def
setup
(
cls
):
BaseTestConv3d
.
setup_class
()
# TODO check how conv_gemm works for conv3d
cls
.
mode
=
theano
.
compile
.
mode
.
get_default_mode
()
.
excluding
(
'conv_gemm'
)
cls
.
opt_err
=
theano
.
config
.
on_opt_error
theano
.
config
.
on_opt_error
=
'ignore'
@classmethod
def
tearDown
(
cls
):
theano
.
config
.
on_opt_error
=
cls
.
opt_err
def
tcase
(
self
,
i
,
f
,
s
,
b
,
flip
,
provide_shape
,
fd
=
(
1
,
1
,
1
)):
if
fd
!=
(
1
,
1
,
1
):
raise
SkipTest
(
"No dilation implementation for basic cpu Conv3D."
)
mode
=
self
.
mode
o
=
self
.
get_output_shape
(
i
,
f
,
s
,
b
,
fd
)
fwd_OK
=
True
gradweight_OK
=
True
gradinput_OK
=
True
if
b
not
in
((
0
,
0
,
0
),
'valid'
):
fwd_OK
=
False
gradweight_OK
=
False
gradinput_OK
=
False
if
fwd_OK
:
if
not
theano
.
config
.
blas
.
ldflags
:
raise
SkipTest
(
"Need blas to test conv3d"
)
self
.
run_fwd
(
inputs_shape
=
i
,
filters_shape
=
f
,
subsample
=
s
,
verify_grad
=
(
gradweight_OK
and
gradinput_OK
),
mode
=
mode
,
provide_shape
=
provide_shape
,
border_mode
=
b
,
filter_flip
=
flip
,
target_op
=
Conv3D
,
check_trace
=
True
,
filter_dilation
=
fd
)
else
:
assert_raises
(
AssertionError
,
self
.
run_fwd
,
inputs_shape
=
i
,
filters_shape
=
f
,
subsample
=
s
,
verify_grad
=
False
,
mode
=
mode
,
provide_shape
=
provide_shape
,
border_mode
=
b
,
filter_flip
=
flip
,
check_trace
=
True
,
filter_dilation
=
fd
)
if
gradweight_OK
:
if
not
theano
.
config
.
blas
.
ldflags
:
raise
SkipTest
(
"Need blas to test conv3d"
)
self
.
run_gradweight
(
inputs_shape
=
i
,
filters_shape
=
f
,
output_shape
=
o
,
subsample
=
s
,
verify_grad
=
False
,
mode
=
mode
,
provide_shape
=
provide_shape
,
border_mode
=
b
,
filter_flip
=
flip
,
target_op
=
ConvGrad3D
,
check_trace
=
True
,
filter_dilation
=
fd
)
else
:
assert_raises
(
AssertionError
,
self
.
run_gradweight
,
inputs_shape
=
i
,
filters_shape
=
f
,
output_shape
=
o
,
subsample
=
s
,
verify_grad
=
False
,
mode
=
mode
,
provide_shape
=
provide_shape
,
border_mode
=
b
,
filter_flip
=
flip
,
check_trace
=
True
,
filter_dilation
=
fd
)
if
gradinput_OK
:
if
not
theano
.
config
.
blas
.
ldflags
:
raise
SkipTest
(
"Need blas to test conv3d"
)
self
.
run_gradinput
(
inputs_shape
=
i
,
filters_shape
=
f
,
output_shape
=
o
,
subsample
=
s
,
verify_grad
=
False
,
mode
=
mode
,
provide_shape
=
provide_shape
,
border_mode
=
b
,
filter_flip
=
flip
,
target_op
=
ConvTransp3D
,
check_trace
=
True
,
filter_dilation
=
fd
)
else
:
assert_raises
(
AssertionError
,
self
.
run_gradinput
,
inputs_shape
=
i
,
filters_shape
=
f
,
output_shape
=
o
,
subsample
=
s
,
verify_grad
=
False
,
mode
=
mode
,
provide_shape
=
provide_shape
,
border_mode
=
b
,
filter_flip
=
flip
,
check_trace
=
True
,
filter_dilation
=
fd
)
def
test_constant_shapes
():
# Check that the `imshp` and `kshp` parameters of the AbstractConv Ops
# are rejected if not constant or None
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论