提交 2a7b2c81 authored 作者: affanv14's avatar affanv14

make optimisers pass num_groups parameter

上级 5d27d984
...@@ -1684,7 +1684,8 @@ def local_abstractconv3d_gemm(node): ...@@ -1684,7 +1684,8 @@ def local_abstractconv3d_gemm(node):
border_mode = node.op.border_mode border_mode = node.op.border_mode
subsample = node.op.subsample subsample = node.op.subsample
filter_dilation = node.op.filter_dilation filter_dilation = node.op.filter_dilation
if ((border_mode == 'full') and (subsample == (1, 1, 1))): num_groups = node.op.num_groups
if ((border_mode == 'full') and (subsample == (1, 1, 1)) and num_groups == 1):
if not node.op.filter_flip: if not node.op.filter_flip:
kern = kern[:, :, ::-1, ::-1, ::-1] kern = kern[:, :, ::-1, ::-1, ::-1]
# need to dimshuffle the kernel for full convolution # need to dimshuffle the kernel for full convolution
...@@ -1701,8 +1702,9 @@ def local_abstractconv3d_gemm(node): ...@@ -1701,8 +1702,9 @@ def local_abstractconv3d_gemm(node):
# By default use GpuCorr3dMM # By default use GpuCorr3dMM
rval = GpuCorr3dMM(border_mode, rval = GpuCorr3dMM(border_mode,
subsample, subsample,
filter_dilation)(gpu_contiguous(img), filter_dilation,
gpu_contiguous(kern)) num_groups)(gpu_contiguous(img),
gpu_contiguous(kern))
# call GpuCorr3dMM_gradWeights if good # call GpuCorr3dMM_gradWeights if good
# (the latter is faster if batchsize * kernelHeight * kernelWidth * kernelDepth # (the latter is faster if batchsize * kernelHeight * kernelWidth * kernelDepth
...@@ -1714,7 +1716,8 @@ def local_abstractconv3d_gemm(node): ...@@ -1714,7 +1716,8 @@ def local_abstractconv3d_gemm(node):
(None not in node.op.imshp[-3:]) and (None not in node.op.imshp[-3:]) and
(node.op.kshp is not None) and (node.op.kshp is not None) and
(None not in node.op.kshp) and (None not in node.op.kshp) and
border_mode != "half"): border_mode != "half" and
num_groups == 1):
# we know the kernel and output size # we know the kernel and output size
prod1 = node.op.kshp[0] * node.op.kshp[1] * node.op.kshp[2] prod1 = node.op.kshp[0] * node.op.kshp[1] * node.op.kshp[2]
prod2 = ((node.op.imshp[-3] - node.op.kshp[0] + 1) * prod2 = ((node.op.imshp[-3] - node.op.kshp[0] + 1) *
...@@ -1906,7 +1909,8 @@ def local_abstractconv3d_gradweights_gemm(node): ...@@ -1906,7 +1909,8 @@ def local_abstractconv3d_gradweights_gemm(node):
rval = GpuCorr3dMM_gradWeights(border_mode=node.op.border_mode, rval = GpuCorr3dMM_gradWeights(border_mode=node.op.border_mode,
subsample=node.op.subsample, subsample=node.op.subsample,
filter_dilation=node.op.filter_dilation)( filter_dilation=node.op.filter_dilation,
num_groups=node.op.num_groups)(
gpu_contiguous(img), gpu_contiguous(topgrad), shape) gpu_contiguous(img), gpu_contiguous(topgrad), shape)
if node.op.filter_flip: if node.op.filter_flip:
rval = rval[:, :, ::-1, ::-1, ::-1] rval = rval[:, :, ::-1, ::-1, ::-1]
...@@ -1976,7 +1980,8 @@ def local_abstractconv3d_gradinputs_gemm(node): ...@@ -1976,7 +1980,8 @@ def local_abstractconv3d_gradinputs_gemm(node):
rval = GpuCorr3dMM_gradInputs(border_mode=node.op.border_mode, rval = GpuCorr3dMM_gradInputs(border_mode=node.op.border_mode,
subsample=node.op.subsample, subsample=node.op.subsample,
filter_dilation=node.op.filter_dilation)( filter_dilation=node.op.filter_dilation,
num_groups=node.op.num_groups)(
gpu_contiguous(kern), gpu_contiguous(topgrad), shape) gpu_contiguous(kern), gpu_contiguous(topgrad), shape)
return [rval] return [rval]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论