提交 a1bd673d 作者: affanv14

check num_groups is 1 before trying alternative passes

上级 68dac70d
...@@ -978,7 +978,7 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), dilation=(1, 1), ...@@ -978,7 +978,7 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), dilation=(1, 1),
fgraph = getattr(img, 'fgraph', None) or getattr(kerns, 'fgraph', None) fgraph = getattr(img, 'fgraph', None) or getattr(kerns, 'fgraph', None)
ctx_name = infer_context_name(img, kerns) ctx_name = infer_context_name(img, kerns)
if (border_mode == 'valid' and subsample == (1, 1) and dilation == (1, 1) and if (border_mode == 'valid' and subsample == (1, 1) and dilation == (1, 1) and
direction_hint == 'bprop weights'): direction_hint == 'bprop weights' and num_groups == 1):
# Special case: We are asked to use GpuDnnConvGradW. We need to set # Special case: We are asked to use GpuDnnConvGradW. We need to set
# up a suitable 'fake' convolution to compute the gradient for. # up a suitable 'fake' convolution to compute the gradient for.
img = gpu_contiguous(img.dimshuffle(1, 0, 2, 3)) img = gpu_contiguous(img.dimshuffle(1, 0, 2, 3))
...@@ -999,7 +999,7 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), dilation=(1, 1), ...@@ -999,7 +999,7 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), dilation=(1, 1),
return as_gpuarray_variable(conv.dimshuffle(1, 0, 2, 3), ctx_name) return as_gpuarray_variable(conv.dimshuffle(1, 0, 2, 3), ctx_name)
elif (border_mode == 'full' and subsample == (1, 1) and dilation == (1, 1) and elif (border_mode == 'full' and subsample == (1, 1) and dilation == (1, 1) and
direction_hint != 'forward!'): direction_hint != 'forward!' and num_groups == 1):
# Special case: We can be faster by using GpuDnnConvGradI to compute # Special case: We can be faster by using GpuDnnConvGradI to compute
# the full convolution as the backward pass of a valid convolution. # the full convolution as the backward pass of a valid convolution.
# We just need to set up a suitable 'fake' valid convolution. # We just need to set up a suitable 'fake' valid convolution.
......
...@@ -1568,7 +1568,8 @@ def local_abstractconv_gemm(node): ...@@ -1568,7 +1568,8 @@ def local_abstractconv_gemm(node):
(None not in node.op.imshp[-2:]) and (None not in node.op.imshp[-2:]) and
(node.op.kshp is not None) and (node.op.kshp is not None) and
(None not in node.op.kshp) and (None not in node.op.kshp) and
border_mode != "half"): border_mode != "half" and
node.op.num_groups == 1):
# we know the kernel and output size # we know the kernel and output size
prod1 = node.op.kshp[0] * node.op.kshp[1] prod1 = node.op.kshp[0] * node.op.kshp[1]
prod2 = ((node.op.imshp[-2] - node.op.kshp[0] + 1) * prod2 = ((node.op.imshp[-2] - node.op.kshp[0] + 1) *
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论