提交 211f0281 authored 作者: affanv14's avatar affanv14

change flops method to include num_groups

上级 2ef18283
......@@ -516,13 +516,13 @@ class BaseGpuCorrMM(CGpuKernelBase):
# flops for any direction, sampling, padding, and border mode
inputs, filters = inp
outputs, = outp
assert inputs[1] == filters[1]
assert inputs[1] == (filters[1] * self.num_groups)
# nb mul and add by output pixel
flops = filters[2] * filters[3] * 2
# nb flops by output image
flops *= outputs[2] * outputs[3]
# nb patch multiplied
flops *= inputs[1] * filters[0] * inputs[0]
flops *= inputs[1] * filters[0] * inputs[0] / self.num_groups
return flops
def c_headers(self):
......@@ -1129,13 +1129,13 @@ class BaseGpuCorr3dMM(CGpuKernelBase):
# flops for any direction, sampling, padding, and border mode
inputs, filters = inp
outputs, = outp
assert inputs[1] == filters[1]
assert inputs[1] == (filters[1] * self.num_groups)
# nb mul and add by output pixel
flops = filters[2] * filters[3] * filters[4] * 2
# nb flops by output image
flops *= outputs[2] * outputs[3] * outputs[4]
# nb patch multiplied
flops *= inputs[1] * filters[0] * inputs[0]
flops *= inputs[1] * filters[0] * inputs[0] / self.num_groups
return flops
def c_headers(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论