提交 9041a214 authored 作者: affanv14's avatar affanv14

skip optimization if num_groups > 1

上级 1850ac36
...@@ -1640,8 +1640,9 @@ def local_abstractconv_gemm_alternative(node): ...@@ -1640,8 +1640,9 @@ def local_abstractconv_gemm_alternative(node):
border_mode = node.op.border_mode border_mode = node.op.border_mode
subsample = node.op.subsample subsample = node.op.subsample
filter_dilation = node.op.filter_dilation filter_dilation = node.op.filter_dilation
num_groups = node.op.num_groups
if border_mode == 'full' and subsample == (1, 1): if border_mode == 'full' and subsample == (1, 1) and num_groups == 1:
if not node.op.filter_flip: if not node.op.filter_flip:
kern = kern[:, :, ::-1, ::-1] kern = kern[:, :, ::-1, ::-1]
...@@ -1651,7 +1652,8 @@ def local_abstractconv_gemm_alternative(node): ...@@ -1651,7 +1652,8 @@ def local_abstractconv_gemm_alternative(node):
filter_dilation)( filter_dilation)(
gpu_contiguous(kern), gpu_contiguous(img)) gpu_contiguous(kern), gpu_contiguous(img))
elif border_mode == 'valid' and subsample == (1, 1) and filter_dilation == (1, 1): elif (border_mode == 'valid' and subsample == (1, 1) and filter_dilation == (1, 1) and
num_groups == 1):
if node.op.filter_flip: if node.op.filter_flip:
kern = kern[:, :, ::-1, ::-1] kern = kern[:, :, ::-1, ::-1]
...@@ -1771,8 +1773,10 @@ def local_abstractconv_gemm_gradweights_alt(node): ...@@ -1771,8 +1773,10 @@ def local_abstractconv_gemm_gradweights_alt(node):
border_mode = node.op.border_mode border_mode = node.op.border_mode
subsample = node.op.subsample subsample = node.op.subsample
filter_dilation = node.op.filter_dilation filter_dilation = node.op.filter_dilation
num_groups = node.op.num_groups
if border_mode == 'valid' and subsample == (1, 1) and filter_dilation == (1, 1): if(border_mode == 'valid' and subsample == (1, 1) and filter_dilation == (1, 1) and
num_groups == 1):
rval = GpuCorrMM(border_mode, rval = GpuCorrMM(border_mode,
subsample, subsample,
filter_dilation)( filter_dilation)(
...@@ -1842,8 +1846,9 @@ def local_abstractconv_gradinputs_gemm_alt(node): ...@@ -1842,8 +1846,9 @@ def local_abstractconv_gradinputs_gemm_alt(node):
border_mode = node.op.border_mode border_mode = node.op.border_mode
subsample = node.op.subsample subsample = node.op.subsample
filter_dilation = node.op.filter_dilation filter_dilation = node.op.filter_dilation
num_groups = node.op.num_groups
if border_mode == 'valid' and subsample == (1, 1): if border_mode == 'valid' and subsample == (1, 1) and num_groups == 1:
if not node.op.filter_flip: if not node.op.filter_flip:
kern = kern[:, :, ::-1, ::-1] kern = kern[:, :, ::-1, ::-1]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论