提交 3c9e4d40 authored 作者: affanv14's avatar affanv14

implement very basic version of separable convolutions

上级 743f7aa9
......@@ -35,6 +35,7 @@ import warnings
from .abstract_conv import conv2d as abstract_conv2d
from .abstract_conv import conv2d_grad_wrt_inputs
from .abstract_conv import conv3d
from .abstract_conv import separable_conv2d
def conv2d(input, filters, input_shape=None, filter_shape=None,
......
......@@ -544,6 +544,41 @@ def conv2d(input,
return conv_op(input, filters)
def separable_conv2d(input,
depthwise_filter,
pointwise_filter,
num_channels,
input_shape=None,
depthwise_filter_shape=None,
pointwise_filter_shape=None,
border_mode='valid',
subsample=(1, 1),
filter_flip=True,
filter_dilation=(1, 1)):
depthwise_op = conv2d(input=input,
filters=depthwise_filter,
input_shape=input_shape,
filter_shape=depthwise_filter_shape,
border_mode=border_mode,
subsample=subsample,
filter_flip=filter_flip,
filter_dilation=filter_dilation,
num_groups=num_channels)
#TODO: infer shape of output pf depthwise_op
pointwise_op = conv2d(input=depthwise_op,
filters=pointwise_filter,
input_shape=None,
filter_shape=pointwise_filter_shape,
border_mode='valid',
subsample=(1, 1),
filter_flip=filter_flip,
filter_dilation=(1, 1),
num_groups=1)
return pointwise_op
def conv3d(input,
filters,
input_shape=None,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论