提交 c1df889d authored 作者: Gijs van Tulder's avatar Gijs van Tulder

Use Corr3dMM as the reference implementation for AbstractConv3d.

上级 401a4dbe
...@@ -133,10 +133,6 @@ class TestCorrMMConv3d(test_abstract_conv.BaseTestConv3d): ...@@ -133,10 +133,6 @@ class TestCorrMMConv3d(test_abstract_conv.BaseTestConv3d):
cls.mode = mode_with_gpu.excluding('cudnn') cls.mode = mode_with_gpu.excluding('cudnn')
def tcase(self, i, f, s, b, flip, provide_shape, fd=(1, 1, 1)): def tcase(self, i, f, s, b, flip, provide_shape, fd=(1, 1, 1)):
if fd != (1, 1, 1):
# TODO
raise SkipTest("Dilation not supprted by the Conv3D reference implementation.")
mode = self.mode mode = self.mode
o = self.get_output_shape(i, f, s, b, fd) o = self.get_output_shape(i, f, s, b, fd)
self.run_fwd(inputs_shape=i, filters_shape=f, self.run_fwd(inputs_shape=i, filters_shape=f,
......
...@@ -9,7 +9,7 @@ import theano ...@@ -9,7 +9,7 @@ import theano
from theano import tensor from theano import tensor
from theano.gof.opt import check_stack_trace from theano.gof.opt import check_stack_trace
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano.tensor.nnet import corr, abstract_conv as conv from theano.tensor.nnet import corr, corr3d, abstract_conv as conv
from theano.tensor.nnet.abstract_conv import get_conv_output_shape from theano.tensor.nnet.abstract_conv import get_conv_output_shape
from theano.tensor.nnet.abstract_conv import AbstractConv2d from theano.tensor.nnet.abstract_conv import AbstractConv2d
from theano.tensor.nnet.abstract_conv import AbstractConv2d_gradInputs from theano.tensor.nnet.abstract_conv import AbstractConv2d_gradInputs
...@@ -61,101 +61,23 @@ def conv2d_corr_gi(filters, topgrad, inputs_shape, ...@@ -61,101 +61,23 @@ def conv2d_corr_gi(filters, topgrad, inputs_shape,
inputs_shape[2:]) inputs_shape[2:])
def _padding_3d_inputs_to_valid(inputs, filters_shape, border_mode='valid'):
# pad inputs to have valid convolution
if border_mode == 'valid':
border_mode = (0, 0, 0)
elif border_mode == 'full':
border_mode = tuple(f - 1 for f in filters_shape[2:])
elif not isinstance(border_mode, tuple):
raise ValueError('Unsupported border mode', border_mode)
if border_mode == (0, 0, 0):
return inputs
else:
# add padding here, because Conv3D only supports valid convolution
i_shp = inputs.shape
pad = border_mode
inputs_padded = tensor.zeros(dtype=inputs.dtype,
shape=(i_shp[0],
i_shp[1],
i_shp[2] + 2 * pad[0],
i_shp[3] + 2 * pad[1],
i_shp[4] + 2 * pad[2]))
inputs_padded = tensor.set_subtensor(inputs_padded[:, :,
pad[0]:i_shp[2] + pad[0],
pad[1]:i_shp[3] + pad[1],
pad[2]:i_shp[4] + pad[2]],
inputs)
return inputs_padded
def _padding_3d_inputs_shape_to_valid(inputs_shape, filters_shape, border_mode='valid'):
# pad inputs_shape to have valid convolution
if border_mode == 'valid':
border_mode = (0, 0, 0)
elif border_mode == 'full':
border_mode = tuple(f - 1 for f in filters_shape[2:])
elif not isinstance(border_mode, tuple):
raise ValueError('Unsupported border mode', border_mode)
return (inputs_shape[0],
inputs_shape[1],
inputs_shape[2] + 2 * border_mode[0],
inputs_shape[3] + 2 * border_mode[1],
inputs_shape[4] + 2 * border_mode[2])
def _crop_3d_padded_inputs(inputs, filters_shape, border_mode='valid'):
# crop border from padded input
if border_mode == 'valid':
border_mode = (0, 0, 0)
elif border_mode == 'full':
border_mode = tuple(f - 1 for f in filters_shape[2:])
elif not isinstance(border_mode, tuple):
raise ValueError('Unsupported border mode', border_mode)
if border_mode == (0, 0, 0):
return inputs
else:
# crop
i_shp = inputs.shape
pad = border_mode
return inputs[:, :,
pad[0]:i_shp[2] - pad[0],
pad[1]:i_shp[3] - pad[1],
pad[2]:i_shp[4] - pad[2]]
def conv3d_corr(inputs, filters, border_mode="valid", def conv3d_corr(inputs, filters, border_mode="valid",
subsample=(1, 1, 1), conv_mode='conv', subsample=(1, 1, 1), conv_mode='conv',
filter_dilation=(1, 1, 1)): filter_dilation=(1, 1, 1)):
assert filter_dilation == (1, 1, 1)
inputs = _padding_3d_inputs_to_valid(inputs, filters.shape, border_mode)
if conv_mode == 'conv': if conv_mode == 'conv':
filters = filters[:, :, ::-1, ::-1, ::-1] filters = filters[:, :, ::-1, ::-1, ::-1]
return corr3d.Corr3dMM(border_mode,
bias = tensor.zeros_like(filters[:, 0, 0, 0, 0]) subsample,
# Conv3D expects shape (batch, row, column, time, channel) filter_dilation)(inputs, filters)
inputs = inputs.dimshuffle(0, 2, 3, 4, 1)
filters = filters.dimshuffle(0, 2, 3, 4, 1)
rval = Conv3D()(inputs, filters, bias, subsample)
return rval.dimshuffle(0, 4, 1, 2, 3)
def conv3d_corr_gw(inputs, topgrad, filters_shape, def conv3d_corr_gw(inputs, topgrad, filters_shape,
border_mode="valid", subsample=(1, 1, 1), border_mode="valid", subsample=(1, 1, 1),
conv_mode='conv', filter_dilation=(1, 1, 1)): conv_mode='conv', filter_dilation=(1, 1, 1)):
assert filter_dilation == (1, 1, 1) rval = corr3d.Corr3dMM_gradWeights(border_mode,
subsample,
inputs = _padding_3d_inputs_to_valid(inputs, filters_shape, border_mode) filter_dilation)(inputs, topgrad,
filters_shape[2:])
# Conv3D expects shape (batch, row, column, time, channel)
inputs = inputs.dimshuffle(0, 2, 3, 4, 1)
topgrad = topgrad.dimshuffle(0, 2, 3, 4, 1)
filters_shape = tuple(filters_shape[i] for i in (0, 2, 3, 4, 1))
rval = ConvGrad3D()(inputs, subsample, filters_shape, topgrad)
rval = rval.dimshuffle(0, 4, 1, 2, 3)
if conv_mode == 'conv': if conv_mode == 'conv':
rval = rval[:, :, ::-1, ::-1, ::-1] rval = rval[:, :, ::-1, ::-1, ::-1]
return rval return rval
...@@ -164,21 +86,13 @@ def conv3d_corr_gw(inputs, topgrad, filters_shape, ...@@ -164,21 +86,13 @@ def conv3d_corr_gw(inputs, topgrad, filters_shape,
def conv3d_corr_gi(filters, topgrad, inputs_shape, def conv3d_corr_gi(filters, topgrad, inputs_shape,
border_mode="valid", subsample=(1, 1, 1), border_mode="valid", subsample=(1, 1, 1),
conv_mode='conv', filter_dilation=(1, 1, 1)): conv_mode='conv', filter_dilation=(1, 1, 1)):
assert filter_dilation == (1, 1, 1)
inputs_shape = _padding_3d_inputs_shape_to_valid(inputs_shape, filters.shape, border_mode)
if conv_mode == 'conv': if conv_mode == 'conv':
filters = filters[:, :, ::-1, ::-1, ::-1] filters = filters[:, :, ::-1, ::-1, ::-1]
return corr3d.Corr3dMM_gradInputs(border_mode,
# Conv3D expects shape (batch, row, column, time, channel) subsample,
filters_shuffled = filters.dimshuffle(0, 2, 3, 4, 1) filter_dilation)(filters,
topgrad_shuffled = topgrad.dimshuffle(0, 2, 3, 4, 1) topgrad,
inputs_shape_shuffled = tuple(inputs_shape[i] for i in (0, 2, 3, 4, 1)) inputs_shape[2:])
bias = tensor.zeros_like(filters[0, :, 0, 0, 0])
rval = ConvTransp3D()(filters_shuffled, bias, subsample, topgrad_shuffled,
inputs_shape_shuffled[1:4])
rval = rval.dimshuffle(0, 4, 1, 2, 3)
return _crop_3d_padded_inputs(rval, filters.shape, border_mode)
class TestGetConvOutShape(unittest.TestCase): class TestGetConvOutShape(unittest.TestCase):
...@@ -689,7 +603,7 @@ class BaseTestConv3d(BaseTestConv): ...@@ -689,7 +603,7 @@ class BaseTestConv3d(BaseTestConv):
cls.subsamples = [(1, 1, 1), (2, 2, 2), (1, 2, 3)] cls.subsamples = [(1, 1, 1), (2, 2, 2), (1, 2, 3)]
cls.default_subsamples = (1, 1, 1) cls.default_subsamples = (1, 1, 1)
cls.filters_dilations = [(1, 1, 1), (1, 2, 1), (2, 1, 2)] cls.filters_dilations = [(1, 1, 1), (1, 2, 1), (2, 1, 2)]
cls.border_modes = ["valid", "full", (0, 0, 0), (2, 2, 3)] cls.border_modes = ["valid", "half", "full", (0, 0, 0), (2, 2, 3)]
cls.default_border_mode = (0, 0, 0) cls.default_border_mode = (0, 0, 0)
cls.filter_flip = [True, False] cls.filter_flip = [True, False]
cls.default_filter_flip = True cls.default_filter_flip = True
...@@ -737,8 +651,6 @@ class TestCorrConv3d(BaseTestConv3d): ...@@ -737,8 +651,6 @@ class TestCorrConv3d(BaseTestConv3d):
def tcase(self, i, f, s, b, flip, provide_shape, fd=(1, 1, 1)): def tcase(self, i, f, s, b, flip, provide_shape, fd=(1, 1, 1)):
o = self.get_output_shape(i, f, s, b, fd) o = self.get_output_shape(i, f, s, b, fd)
if fd != (1, 1, 1):
raise SkipTest("No reference implementation for 3D dilation.")
if (not theano.config.blas.ldflags or if (not theano.config.blas.ldflags or
not theano.config.cxx or not theano.config.cxx or
theano.config.mode == "FAST_COMPILE"): theano.config.mode == "FAST_COMPILE"):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论