提交 c0bea2c6 authored 作者: affanv14's avatar affanv14

change flops to support grouped convolution

上级 e9796582
......@@ -1495,20 +1495,17 @@ class BaseAbstractConv(Op):
def flops(self, inp, outp):
""" Useful with the hack in profiling to print the MFlops"""
if self.convdim == 2:
if self.num_groups > 1:
raise NotImplementedError(
'flops not implemented for grouped convolution')
# if the output shape is correct, then this gives the correct
# flops for any direction, sampling, padding, and border mode
inputs, filters = inp
outputs, = outp
assert inputs[1] == filters[1]
assert inputs[1] == (filters[1] * self.num_groups)
# nb mul and add by output pixel
flops = filters[2] * filters[3] * 2
# nb flops by output image
flops *= outputs[2] * outputs[3]
# nb patch multiplied
flops *= inputs[1] * filters[0] * inputs[0]
flops *= inputs[1] * filters[0] * inputs[0] / self.num_groups
return flops
else:
# TODO implement for convdim == 3
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论