提交 c1df889d authored 作者: Gijs van Tulder's avatar Gijs van Tulder

Use Corr3dMM as the reference implementation for AbstractConv3d.

上级 401a4dbe
......@@ -133,10 +133,6 @@ class TestCorrMMConv3d(test_abstract_conv.BaseTestConv3d):
cls.mode = mode_with_gpu.excluding('cudnn')
def tcase(self, i, f, s, b, flip, provide_shape, fd=(1, 1, 1)):
if fd != (1, 1, 1):
# TODO
raise SkipTest("Dilation not supprted by the Conv3D reference implementation.")
mode = self.mode
o = self.get_output_shape(i, f, s, b, fd)
self.run_fwd(inputs_shape=i, filters_shape=f,
......
......@@ -9,7 +9,7 @@ import theano
from theano import tensor
from theano.gof.opt import check_stack_trace
from theano.tests import unittest_tools as utt
from theano.tensor.nnet import corr, abstract_conv as conv
from theano.tensor.nnet import corr, corr3d, abstract_conv as conv
from theano.tensor.nnet.abstract_conv import get_conv_output_shape
from theano.tensor.nnet.abstract_conv import AbstractConv2d
from theano.tensor.nnet.abstract_conv import AbstractConv2d_gradInputs
......@@ -61,101 +61,23 @@ def conv2d_corr_gi(filters, topgrad, inputs_shape,
inputs_shape[2:])
def _padding_3d_inputs_to_valid(inputs, filters_shape, border_mode='valid'):
# pad inputs to have valid convolution
if border_mode == 'valid':
border_mode = (0, 0, 0)
elif border_mode == 'full':
border_mode = tuple(f - 1 for f in filters_shape[2:])
elif not isinstance(border_mode, tuple):
raise ValueError('Unsupported border mode', border_mode)
if border_mode == (0, 0, 0):
return inputs
else:
# add padding here, because Conv3D only supports valid convolution
i_shp = inputs.shape
pad = border_mode
inputs_padded = tensor.zeros(dtype=inputs.dtype,
shape=(i_shp[0],
i_shp[1],
i_shp[2] + 2 * pad[0],
i_shp[3] + 2 * pad[1],
i_shp[4] + 2 * pad[2]))
inputs_padded = tensor.set_subtensor(inputs_padded[:, :,
pad[0]:i_shp[2] + pad[0],
pad[1]:i_shp[3] + pad[1],
pad[2]:i_shp[4] + pad[2]],
inputs)
return inputs_padded
def _padding_3d_inputs_shape_to_valid(inputs_shape, filters_shape, border_mode='valid'):
# pad inputs_shape to have valid convolution
if border_mode == 'valid':
border_mode = (0, 0, 0)
elif border_mode == 'full':
border_mode = tuple(f - 1 for f in filters_shape[2:])
elif not isinstance(border_mode, tuple):
raise ValueError('Unsupported border mode', border_mode)
return (inputs_shape[0],
inputs_shape[1],
inputs_shape[2] + 2 * border_mode[0],
inputs_shape[3] + 2 * border_mode[1],
inputs_shape[4] + 2 * border_mode[2])
def _crop_3d_padded_inputs(inputs, filters_shape, border_mode='valid'):
# crop border from padded input
if border_mode == 'valid':
border_mode = (0, 0, 0)
elif border_mode == 'full':
border_mode = tuple(f - 1 for f in filters_shape[2:])
elif not isinstance(border_mode, tuple):
raise ValueError('Unsupported border mode', border_mode)
if border_mode == (0, 0, 0):
return inputs
else:
# crop
i_shp = inputs.shape
pad = border_mode
return inputs[:, :,
pad[0]:i_shp[2] - pad[0],
pad[1]:i_shp[3] - pad[1],
pad[2]:i_shp[4] - pad[2]]
def conv3d_corr(inputs, filters, border_mode="valid",
subsample=(1, 1, 1), conv_mode='conv',
filter_dilation=(1, 1, 1)):
assert filter_dilation == (1, 1, 1)
inputs = _padding_3d_inputs_to_valid(inputs, filters.shape, border_mode)
if conv_mode == 'conv':
filters = filters[:, :, ::-1, ::-1, ::-1]
bias = tensor.zeros_like(filters[:, 0, 0, 0, 0])
# Conv3D expects shape (batch, row, column, time, channel)
inputs = inputs.dimshuffle(0, 2, 3, 4, 1)
filters = filters.dimshuffle(0, 2, 3, 4, 1)
rval = Conv3D()(inputs, filters, bias, subsample)
return rval.dimshuffle(0, 4, 1, 2, 3)
return corr3d.Corr3dMM(border_mode,
subsample,
filter_dilation)(inputs, filters)
def conv3d_corr_gw(inputs, topgrad, filters_shape,
border_mode="valid", subsample=(1, 1, 1),
conv_mode='conv', filter_dilation=(1, 1, 1)):
assert filter_dilation == (1, 1, 1)
inputs = _padding_3d_inputs_to_valid(inputs, filters_shape, border_mode)
# Conv3D expects shape (batch, row, column, time, channel)
inputs = inputs.dimshuffle(0, 2, 3, 4, 1)
topgrad = topgrad.dimshuffle(0, 2, 3, 4, 1)
filters_shape = tuple(filters_shape[i] for i in (0, 2, 3, 4, 1))
rval = ConvGrad3D()(inputs, subsample, filters_shape, topgrad)
rval = rval.dimshuffle(0, 4, 1, 2, 3)
rval = corr3d.Corr3dMM_gradWeights(border_mode,
subsample,
filter_dilation)(inputs, topgrad,
filters_shape[2:])
if conv_mode == 'conv':
rval = rval[:, :, ::-1, ::-1, ::-1]
return rval
......@@ -164,21 +86,13 @@ def conv3d_corr_gw(inputs, topgrad, filters_shape,
def conv3d_corr_gi(filters, topgrad, inputs_shape,
border_mode="valid", subsample=(1, 1, 1),
conv_mode='conv', filter_dilation=(1, 1, 1)):
assert filter_dilation == (1, 1, 1)
inputs_shape = _padding_3d_inputs_shape_to_valid(inputs_shape, filters.shape, border_mode)
if conv_mode == 'conv':
filters = filters[:, :, ::-1, ::-1, ::-1]
# Conv3D expects shape (batch, row, column, time, channel)
filters_shuffled = filters.dimshuffle(0, 2, 3, 4, 1)
topgrad_shuffled = topgrad.dimshuffle(0, 2, 3, 4, 1)
inputs_shape_shuffled = tuple(inputs_shape[i] for i in (0, 2, 3, 4, 1))
bias = tensor.zeros_like(filters[0, :, 0, 0, 0])
rval = ConvTransp3D()(filters_shuffled, bias, subsample, topgrad_shuffled,
inputs_shape_shuffled[1:4])
rval = rval.dimshuffle(0, 4, 1, 2, 3)
return _crop_3d_padded_inputs(rval, filters.shape, border_mode)
return corr3d.Corr3dMM_gradInputs(border_mode,
subsample,
filter_dilation)(filters,
topgrad,
inputs_shape[2:])
class TestGetConvOutShape(unittest.TestCase):
......@@ -689,7 +603,7 @@ class BaseTestConv3d(BaseTestConv):
cls.subsamples = [(1, 1, 1), (2, 2, 2), (1, 2, 3)]
cls.default_subsamples = (1, 1, 1)
cls.filters_dilations = [(1, 1, 1), (1, 2, 1), (2, 1, 2)]
cls.border_modes = ["valid", "full", (0, 0, 0), (2, 2, 3)]
cls.border_modes = ["valid", "half", "full", (0, 0, 0), (2, 2, 3)]
cls.default_border_mode = (0, 0, 0)
cls.filter_flip = [True, False]
cls.default_filter_flip = True
......@@ -737,8 +651,6 @@ class TestCorrConv3d(BaseTestConv3d):
def tcase(self, i, f, s, b, flip, provide_shape, fd=(1, 1, 1)):
o = self.get_output_shape(i, f, s, b, fd)
if fd != (1, 1, 1):
raise SkipTest("No reference implementation for 3D dilation.")
if (not theano.config.blas.ldflags or
not theano.config.cxx or
theano.config.mode == "FAST_COMPILE"):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论