提交 1bbe21ba authored 作者: Vikram's avatar Vikram

A few suggestions implemented

上级 717f2534
...@@ -138,14 +138,11 @@ def get_conv_shape_1axis(image_shape, kernel_shape, border_mode, ...@@ -138,14 +138,11 @@ def get_conv_shape_1axis(image_shape, kernel_shape, border_mode,
# In case of symbolic shape, we want to build the smallest graph # In case of symbolic shape, we want to build the smallest graph
# (image_shape + 2 * pad - dil_kernel_shape) // subsample + 1 # (image_shape + 2 * pad - dil_kernel_shape) // subsample + 1
if pad_l == 0 and pad_r == 0: out_shp = (image_shape - dil_kernel_shape)
out_shp = (image_shape - dil_kernel_shape) if pad_l > 0:
elif pad_l == 0: out_shp += pad_l
out_shp = (image_shape + pad_l - dil_kernel_shape) if pad_r > 0:
elif pad_r == 0: out_shp += pad_r
out_shp = (image_shape + pad_r - dil_kernel_shape)
else:
out_shp = (image_shape + pad_l + pad_r - dil_kernel_shape)
if subsample != 1: if subsample != 1:
out_shp = out_shp // subsample out_shp = out_shp // subsample
out_shp = out_shp + 1 out_shp = out_shp + 1
...@@ -259,14 +256,12 @@ def get_conv_gradweights_shape_1axis(image_shape, top_shape, border_mode, ...@@ -259,14 +256,12 @@ def get_conv_gradweights_shape_1axis(image_shape, top_shape, border_mode,
elif border_mode == "valid": elif border_mode == "valid":
kernel_shape = image_shape - top_shape kernel_shape = image_shape - top_shape
else: else:
if isinstance(border_mode, integer_types): if isinstance(border_mode, tuple):
if border_mode < 0:
raise ValueError("border_mode must be >= 0")
pad_l = pad_r = border_mode
elif isinstance(border_mode, tuple):
if min(border_mode) < 0:
raise ValueError("border_mode must be >= 0")
pad_l, pad_r = border_mode pad_l, pad_r = border_mode
else:
pad_l = pad_r = border_mode
if pad_l < 0 or pad_r < 0:
raise ValueError("border_mode must be >= 0")
kernel_shape = (image_shape + pad_l + pad_r - top_shape) kernel_shape = (image_shape + pad_l + pad_r - top_shape)
...@@ -393,14 +388,11 @@ def get_conv_gradinputs_shape_1axis(kernel_shape, top_shape, border_mode, ...@@ -393,14 +388,11 @@ def get_conv_gradinputs_shape_1axis(kernel_shape, top_shape, border_mode,
# In case of symbolic shape, we want to build the smallest graph # In case of symbolic shape, we want to build the smallest graph
# image_shape = (top_shape - 1) * s - 2 * pad + dil_kernel_shape + a # image_shape = (top_shape - 1) * s - 2 * pad + dil_kernel_shape + a
# where 0 <= a < subsample, but we have checked that subsample == 1 # where 0 <= a < subsample, but we have checked that subsample == 1
if pad_l == 0 and pad_r == 0: image_shape = (top_shape + dil_kernel_shape - 1)
image_shape = (top_shape + dil_kernel_shape - 1) if pad_l > 0:
elif pad_l == 0: image_shape -= pad_l
image_shape = (top_shape - pad_l + dil_kernel_shape - 1) if pad_r > 0:
elif pad_r == 0: image_shape -= pad_r
image_shape = (top_shape - pad_r + dil_kernel_shape - 1)
else:
image_shape = (top_shape - pad_l - pad_r + dil_kernel_shape - 1)
return image_shape return image_shape
...@@ -1505,22 +1497,23 @@ def conv3d_grad_wrt_weights(input, ...@@ -1505,22 +1497,23 @@ def conv3d_grad_wrt_weights(input,
return gradWeight_op(input, output_grad, filter_shape[-3:]) return gradWeight_op(input, output_grad, filter_shape[-3:])
def dilated_causal_conv(input, def causal_conv(input,
filters, filters,
filter_shape, filter_shape,
input_shape=None, input_shape=None,
subsample=1, subsample=1,
filter_flip=True, filter_flip=True,
filter_dilation=1, filter_dilation=1,
num_groups=1): num_groups=1,
unshared=False):
input = as_tensor_variable(input) input = as_tensor_variable(input)
filters = as_tensor_variable(filters) filters = as_tensor_variable(filters)
if input.ndim != 3: if input.ndim != 3:
raise ValueError('Input should be 3D for Dilated Causal convolution.') raise ValueError('Input should be 3D for causal convolution.')
if filters.ndim != 3: if filters.ndim != 3:
raise ValueError('Filters should be 3D for Dilated Causal convolution') raise ValueError('Filters should be 3D for causal convolution')
input = input.dimshuffle(0, 1, 2, 'x') input = input.dimshuffle(0, 1, 2, 'x')
filters = filters.dimshuffle(0, 1, 2, 'x') filters = filters.dimshuffle(0, 1, 2, 'x')
...@@ -1546,11 +1539,10 @@ def dilated_causal_conv(input, ...@@ -1546,11 +1539,10 @@ def dilated_causal_conv(input,
filter_flip=filter_flip, filter_flip=filter_flip,
filter_dilation=filter_dilation, filter_dilation=filter_dilation,
num_groups=num_groups, num_groups=num_groups,
unshared=False) unshared=unshared)
output = conv_op(input, filters) output = conv_op(input, filters)
shape = output.shape[:-1] return output[:, :, :, 0]
return output.reshape(shape)
def bilinear_kernel_2D(ratio, normalize=True): def bilinear_kernel_2D(ratio, normalize=True):
...@@ -1824,25 +1816,26 @@ class BaseAbstractConv(Op): ...@@ -1824,25 +1816,26 @@ class BaseAbstractConv(Op):
raise ValueError( raise ValueError(
'invalid border_mode {}, which must be a ' 'invalid border_mode {}, which must be a '
'non-negative integer'.format(border_mode)) 'non-negative integer'.format(border_mode))
border_mode = ((border_mode, border_mode),) * convdim border_mode = (border_mode,) * convdim
elif isinstance(border_mode, tuple): elif isinstance(border_mode, tuple):
if len(border_mode) != convdim: if len(border_mode) != convdim:
raise ValueError( raise ValueError(
'invalid border_mode {} which must be a ' 'invalid border_mode {}, which must be a '
'tuple of length {}'.format(border_mode, convdim)) 'tuple of length {}'.format(border_mode, convdim))
for mode in border_mode: for mode in border_mode:
if not((isinstance(mode, integer_types) and mode >= 0) or if not((isinstance(mode, integer_types) and mode >= 0) or
(isinstance(mode, tuple) and len(mode) == 2 and (isinstance(mode, tuple) and len(mode) == 2 and min(mode) >= 0 and
min(mode) >= 0)): all(isinstance(m, integer_types) for m in mode))):
raise ValueError( raise ValueError(
'invalid border mode {}. The tuple can only contain ' 'invalid border mode {}. The tuple can only contain integers '
'integers or tuples of length 2'.format(border_mode)) ' or tuples of integers of length 2'.format(border_mode))
elif border_mode not in ('valid', 'full', 'half'): elif border_mode not in ('valid', 'full', 'half'):
raise ValueError( raise ValueError(
'invalid border_mode {}, which must be either ' 'invalid border_mode {}, which must be either '
'"valid", "full", "half", an integer or a tuple ' '"valid", "full", "half", an integer or a tuple '
'of length {}'.format(border_mode, convdim)) 'of length {}'.format(border_mode, convdim))
if all(mode == (0, 0) or mode == 0 for mode in border_mode): if isinstance(border_mode, tuple) and \
all(mode == (0, 0) or mode == 0 for mode in border_mode):
border_mode = 'valid' border_mode = 'valid'
self.imshp = tuple(imshp) if imshp else (None,) * (2 + convdim) self.imshp = tuple(imshp) if imshp else (None,) * (2 + convdim)
...@@ -2113,9 +2106,9 @@ class AbstractConv(BaseAbstractConv): ...@@ -2113,9 +2106,9 @@ class AbstractConv(BaseAbstractConv):
for m in mode: for m in mode:
if isinstance(m, integer_types) and m >= 0: if isinstance(m, integer_types) and m >= 0:
border += ((m, m),) border += ((m, m),)
elif isinstance(m, tuple) and len(m) == 2 and \ elif isinstance(m, tuple) and len(m) == 2 and min(m) >= 0 and \
min(m) >= 0: all(isinstance(b, integer_types) for b in m):
border += ((int(m[0]), int(m[1])),) border += ((m[0], m[1]),)
else: else:
raise ValueError( raise ValueError(
'invalid border mode {}. The tuple can only contain ' 'invalid border mode {}. The tuple can only contain '
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论