提交 e11121cd authored 作者: affanv14's avatar affanv14 提交者: Arnaud Bergeron

add separable conv3d op

上级 2d3ab3b0
...@@ -662,6 +662,46 @@ def separable_conv2d(input, ...@@ -662,6 +662,46 @@ def separable_conv2d(input,
return pointwise_op return pointwise_op
def separable_conv3d(input,
depthwise_filters,
pointwise_filters,
num_channels,
input_shape=None,
depthwise_filter_shape=None,
pointwise_filter_shape=None,
border_mode='valid',
subsample=(1, 1, 1),
filter_flip=True,
filter_dilation=(1, 1, 1)):
input = as_tensor_variable(input)
depthwise_filters = as_tensor_variable(depthwise_filters)
conv_op = AbstractConv3d(imshp=input_shape,
kshp=depthwise_filter_shape,
border_mode=border_mode,
subsample=subsample,
filter_flip=filter_flip,
filter_dilation=filter_dilation,
num_groups=num_channels)
if input_shape is None or depthwise_filter_shape is None:
depthwise_op_shape = None
else:
depthwise_op_shape = conv_op.infer_shape(None, [input_shape, depthwise_filter_shape])[0]
depthwise_op = conv_op(input, depthwise_filters)
pointwise_op = conv3d(input=depthwise_op,
filters=pointwise_filters,
input_shape=depthwise_op_shape,
filter_shape=pointwise_filter_shape,
border_mode='valid',
subsample=(1, 1, 1),
filter_flip=filter_flip,
filter_dilation=(1, 1, 1),
num_groups=1)
return pointwise_op
def conv3d(input, def conv3d(input,
filters, filters,
input_shape=None, input_shape=None,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论