提交 9041a214 authored 作者: affanv14's avatar affanv14

skip optimization if num_groups > 1

上级 1850ac36
......@@ -1640,8 +1640,9 @@ def local_abstractconv_gemm_alternative(node):
border_mode = node.op.border_mode
subsample = node.op.subsample
filter_dilation = node.op.filter_dilation
num_groups = node.op.num_groups
if border_mode == 'full' and subsample == (1, 1):
if border_mode == 'full' and subsample == (1, 1) and num_groups == 1:
if not node.op.filter_flip:
kern = kern[:, :, ::-1, ::-1]
......@@ -1651,7 +1652,8 @@ def local_abstractconv_gemm_alternative(node):
filter_dilation)(
gpu_contiguous(kern), gpu_contiguous(img))
elif border_mode == 'valid' and subsample == (1, 1) and filter_dilation == (1, 1):
elif (border_mode == 'valid' and subsample == (1, 1) and filter_dilation == (1, 1) and
num_groups == 1):
if node.op.filter_flip:
kern = kern[:, :, ::-1, ::-1]
......@@ -1771,8 +1773,10 @@ def local_abstractconv_gemm_gradweights_alt(node):
border_mode = node.op.border_mode
subsample = node.op.subsample
filter_dilation = node.op.filter_dilation
num_groups = node.op.num_groups
if border_mode == 'valid' and subsample == (1, 1) and filter_dilation == (1, 1):
if(border_mode == 'valid' and subsample == (1, 1) and filter_dilation == (1, 1) and
num_groups == 1):
rval = GpuCorrMM(border_mode,
subsample,
filter_dilation)(
......@@ -1842,8 +1846,9 @@ def local_abstractconv_gradinputs_gemm_alt(node):
border_mode = node.op.border_mode
subsample = node.op.subsample
filter_dilation = node.op.filter_dilation
num_groups = node.op.num_groups
if border_mode == 'valid' and subsample == (1, 1):
if border_mode == 'valid' and subsample == (1, 1) and num_groups == 1:
if not node.op.filter_flip:
kern = kern[:, :, ::-1, ::-1]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论