提交 c0bea2c6 authored 作者: affanv14's avatar affanv14

change flops to support grouped convolution

上级 e9796582
...@@ -1495,20 +1495,17 @@ class BaseAbstractConv(Op): ...@@ -1495,20 +1495,17 @@ class BaseAbstractConv(Op):
def flops(self, inp, outp): def flops(self, inp, outp):
""" Useful with the hack in profiling to print the MFlops""" """ Useful with the hack in profiling to print the MFlops"""
if self.convdim == 2: if self.convdim == 2:
if self.num_groups > 1:
raise NotImplementedError(
'flops not implemented for grouped convolution')
# if the output shape is correct, then this gives the correct # if the output shape is correct, then this gives the correct
# flops for any direction, sampling, padding, and border mode # flops for any direction, sampling, padding, and border mode
inputs, filters = inp inputs, filters = inp
outputs, = outp outputs, = outp
assert inputs[1] == filters[1] assert inputs[1] == (filters[1] * self.num_groups)
# nb mul and add by output pixel # nb mul and add by output pixel
flops = filters[2] * filters[3] * 2 flops = filters[2] * filters[3] * 2
# nb flops by output image # nb flops by output image
flops *= outputs[2] * outputs[3] flops *= outputs[2] * outputs[3]
# nb patch multiplied # nb patch multiplied
flops *= inputs[1] * filters[0] * inputs[0] flops *= inputs[1] * filters[0] * inputs[0] / self.num_groups
return flops return flops
else: else:
# TODO implement for convdim == 3 # TODO implement for convdim == 3
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论