提交 0b2dac70 authored 作者: affanv14's avatar affanv14

infer shape of depthwise op to give as input to pointwise op

上级 3c9e4d40
......@@ -545,8 +545,8 @@ def conv2d(input,
def separable_conv2d(input,
depthwise_filter,
pointwise_filter,
depthwise_filters,
pointwise_filters,
num_channels,
input_shape=None,
depthwise_filter_shape=None,
......@@ -556,20 +556,22 @@ def separable_conv2d(input,
filter_flip=True,
filter_dilation=(1, 1)):
depthwise_op = conv2d(input=input,
filters=depthwise_filter,
input_shape=input_shape,
filter_shape=depthwise_filter_shape,
border_mode=border_mode,
subsample=subsample,
filter_flip=filter_flip,
filter_dilation=filter_dilation,
num_groups=num_channels)
input = as_tensor_variable(input)
depthwise_filters = as_tensor_variable(depthwise_filters)
conv_op = AbstractConv2d(imshp=input_shape,
kshp=depthwise_filter_shape,
border_mode=border_mode,
subsample=subsample,
filter_flip=filter_flip,
filter_dilation=filter_dilation,
num_groups=num_channels)
depthwise_op_shape = conv_op.infer_shape(None, [input_shape, depthwise_filter_shape])
depthwise_op = conv_op(input, depthwise_filters)
#TODO: infer shape of output pf depthwise_op
pointwise_op = conv2d(input=depthwise_op,
filters=pointwise_filter,
input_shape=None,
filters=pointwise_filters,
input_shape=depthwise_op_shape[0],
filter_shape=pointwise_filter_shape,
border_mode='valid',
subsample=(1, 1),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论