提交 2e0d1d1f authored 作者: affanv14's avatar affanv14

modify setstate to include num_groups

上级 c012bcdf
...@@ -542,6 +542,8 @@ class GpuDnnConv(DnnBase): ...@@ -542,6 +542,8 @@ class GpuDnnConv(DnnBase):
self.algo = config.dnn.conv.algo_fwd self.algo = config.dnn.conv.algo_fwd
if not hasattr(self, 'inplace'): if not hasattr(self, 'inplace'):
self.inplace = False self.inplace = False
if not hasattr(self, 'num_groups'):
self.num_groups = 1
def make_node(self, img, kern, output, desc, alpha=None, beta=None): def make_node(self, img, kern, output, desc, alpha=None, beta=None):
ctx_name = infer_context_name(img, kern, output) ctx_name = infer_context_name(img, kern, output)
...@@ -675,6 +677,8 @@ class GpuDnnConvGradW(DnnBase): ...@@ -675,6 +677,8 @@ class GpuDnnConvGradW(DnnBase):
self.inplace = False self.inplace = False
if not hasattr(self, 'algo'): if not hasattr(self, 'algo'):
self.algo = config.dnn.conv.algo_bwd_filter self.algo = config.dnn.conv.algo_bwd_filter
if not hasattr(self, 'num_groups'):
self.num_groups = 1
def grad(self, inp, grads): def grad(self, inp, grads):
img, top, output, desc, alpha, beta = inp img, top, output, desc, alpha, beta = inp
...@@ -806,6 +810,8 @@ class GpuDnnConvGradI(DnnBase): ...@@ -806,6 +810,8 @@ class GpuDnnConvGradI(DnnBase):
self.algo = config.dnn.conv.algo_bwd_data self.algo = config.dnn.conv.algo_bwd_data
if not hasattr(self, 'inplace'): if not hasattr(self, 'inplace'):
self.inplace = False self.inplace = False
if not hasattr(self, 'num_groups'):
self.num_groups = 1
def grad(self, inp, grads): def grad(self, inp, grads):
kerns, top, output, desc, alpha, beta = inp kerns, top, output, desc, alpha, beta = inp
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论