提交 0b2dac70 authored 作者: affanv14's avatar affanv14

infer shape of depthwise op to give as input to pointwise op

上级 3c9e4d40
...@@ -545,8 +545,8 @@ def conv2d(input, ...@@ -545,8 +545,8 @@ def conv2d(input,
def separable_conv2d(input, def separable_conv2d(input,
depthwise_filter, depthwise_filters,
pointwise_filter, pointwise_filters,
num_channels, num_channels,
input_shape=None, input_shape=None,
depthwise_filter_shape=None, depthwise_filter_shape=None,
...@@ -556,20 +556,22 @@ def separable_conv2d(input, ...@@ -556,20 +556,22 @@ def separable_conv2d(input,
filter_flip=True, filter_flip=True,
filter_dilation=(1, 1)): filter_dilation=(1, 1)):
depthwise_op = conv2d(input=input, input = as_tensor_variable(input)
filters=depthwise_filter, depthwise_filters = as_tensor_variable(depthwise_filters)
input_shape=input_shape, conv_op = AbstractConv2d(imshp=input_shape,
filter_shape=depthwise_filter_shape, kshp=depthwise_filter_shape,
border_mode=border_mode, border_mode=border_mode,
subsample=subsample, subsample=subsample,
filter_flip=filter_flip, filter_flip=filter_flip,
filter_dilation=filter_dilation, filter_dilation=filter_dilation,
num_groups=num_channels) num_groups=num_channels)
#TODO: infer shape of output pf depthwise_op depthwise_op_shape = conv_op.infer_shape(None, [input_shape, depthwise_filter_shape])
depthwise_op = conv_op(input, depthwise_filters)
pointwise_op = conv2d(input=depthwise_op, pointwise_op = conv2d(input=depthwise_op,
filters=pointwise_filter, filters=pointwise_filters,
input_shape=None, input_shape=depthwise_op_shape[0],
filter_shape=pointwise_filter_shape, filter_shape=pointwise_filter_shape,
border_mode='valid', border_mode='valid',
subsample=(1, 1), subsample=(1, 1),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论