提交 717f2534 authored 作者: Vikram's avatar Vikram

Dilated causal wrapper fn. float32 error fixed

上级 e514f4d3
......@@ -1505,6 +1505,54 @@ def conv3d_grad_wrt_weights(input,
return gradWeight_op(input, output_grad, filter_shape[-3:])
def dilated_causal_conv(input,
filters,
filter_shape,
input_shape=None,
subsample=1,
filter_flip=True,
filter_dilation=1,
num_groups=1):
input = as_tensor_variable(input)
filters = as_tensor_variable(filters)
if input.ndim != 3:
raise ValueError('Input should be 3D for Dilated Causal convolution.')
if filters.ndim != 3:
raise ValueError('Filters should be 3D for Dilated Causal convolution')
input = input.dimshuffle(0, 1, 2, 'x')
filters = filters.dimshuffle(0, 1, 2, 'x')
if input_shape is not None:
assert(len(input_shape) == 3)
input_shape = tuple(input_shape)
input_shape += (1,)
assert(len(filter_shape) == 3)
filter_shape = tuple(filter_shape)
filter_shape += (1,)
left_pad = filter_dilation * (filter_shape[2] - 1)
subsample = (subsample, 1)
filter_dilation = (filter_dilation, 1)
conv_op = AbstractConv2d(imshp=input_shape,
kshp=filter_shape,
border_mode=((left_pad, 0), 0),
subsample=subsample,
filter_flip=filter_flip,
filter_dilation=filter_dilation,
num_groups=num_groups,
unshared=False)
output = conv_op(input, filters)
shape = output.shape[:-1]
return output.reshape(shape)
def bilinear_kernel_2D(ratio, normalize=True):
"""Compute 2D kernel for bilinear upsampling
......
......@@ -24,6 +24,7 @@ from theano.tensor.nnet.abstract_conv import bilinear_kernel_1D
from theano.tensor.nnet.abstract_conv import bilinear_kernel_2D
from theano.tensor.nnet.abstract_conv import bilinear_upsampling
from theano.tensor.nnet.abstract_conv import separable_conv2d, separable_conv3d
from theano.tensor.nnet.abstract_conv import dilated_causal_conv
from theano.tensor.nnet.corr import (CorrMM, CorrMM_gradWeights,
CorrMM_gradInputs)
from theano.tensor.nnet.corr3d import (Corr3dMM, Corr3dMM_gradWeights,
......@@ -1909,8 +1910,8 @@ class TestAsymmetricPadding(unittest.TestCase):
imshp = (3, 2, 4, 4)
kshp = (4, 2, 2, 2)
topshp = (3, 4, 6, 5)
pad = ((1, 2), (1, 1))
topshp = (3, 4, 6, 6)
pad = ((1, 2), (2, 1))
def test_fwd(self):
img_sym = theano.tensor.tensor4('img')
......@@ -1936,7 +1937,7 @@ class TestAsymmetricPadding(unittest.TestCase):
self.imshp[2] + self.pad[0][0] + self.pad[0][1],
self.imshp[3] + self.pad[1][0] + self.pad[1][1])
exp_img = np.zeros(exp_imshp)
exp_img = np.zeros(exp_imshp, dtype=theano.config.floatX)
exp_img[:, :, self.pad[0][0]:self.imshp[2] + self.pad[0][0],
self.pad[1][0]:self.imshp[3] + self.pad[1][0]] = img
ref_output = ref_func(exp_img, kern)
......@@ -1969,7 +1970,7 @@ class TestAsymmetricPadding(unittest.TestCase):
self.imshp[2] + self.pad[0][0] + self.pad[0][1],
self.imshp[3] + self.pad[1][0] + self.pad[1][1])
exp_img = np.zeros(exp_imshp)
exp_img = np.zeros(exp_imshp, dtype=theano.config.floatX)
exp_img[:, :, self.pad[0][0]:self.imshp[2] + self.pad[0][0],
self.pad[1][0]:self.imshp[3] + self.pad[1][0]] = img
ref_output = ref_func(exp_img, top)
......@@ -2014,3 +2015,31 @@ class TestAsymmetricPadding(unittest.TestCase):
return asymmetric_conv_op(filters_val, output_val, tensor.as_tensor_variable(self.imshp[-2:]))
utt.verify_grad(conv_gradinputs, [kern, top], mode=self.mode, eps=1)
class TestDilatedCausalConv(unittest.TestCase):
mode = theano.compile.mode.Mode(optimizer='None')
imshp = (3, 2, 5)
kshp = (2, 2, 3)
topshp = (3, 2, 5)
def test_interface(self):
img_sym = theano.tensor.tensor3('img')
kern_sym = theano.tensor.tensor3('kern')
img = np.random.random(self.imshp).astype(theano.config.floatX)
kern = np.random.random(self.kshp).astype(theano.config.floatX)
sym_out = dilated_causal_conv(img_sym, kern_sym, self.kshp, filter_dilation=1)
causal_func = theano.function([img_sym, kern_sym], sym_out, mode=self.mode)
output = causal_func(img, kern)
assert output.shape == self.topshp
# def causal_conv(inputs_val, filters_val):
# return dilated_causal_conv(inputs_val, filters_val, self.kshp, filter_dilation=1)
# utt.verify_grad(causal_conv, [img, kern], mode=self.mode, eps=1)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论