提交 4982e94d 作者: affanv14

Add num_groups checks to the alternative 3D convolution optimizers

上级 9592125c
...@@ -3266,6 +3266,7 @@ def local_abstractconv3d_cudnn_alt(node): ...@@ -3266,6 +3266,7 @@ def local_abstractconv3d_cudnn_alt(node):
border_mode = node.op.border_mode border_mode = node.op.border_mode
subsample = node.op.subsample subsample = node.op.subsample
filter_dilation = node.op.filter_dilation filter_dilation = node.op.filter_dilation
num_groups = node.op.num_groups
precision = get_precision(None, [inp1, inp2]) precision = get_precision(None, [inp1, inp2])
if node.op.filter_flip: if node.op.filter_flip:
...@@ -3274,7 +3275,7 @@ def local_abstractconv3d_cudnn_alt(node): ...@@ -3274,7 +3275,7 @@ def local_abstractconv3d_cudnn_alt(node):
conv_mode = 'cross' conv_mode = 'cross'
if isinstance(op, AbstractConv3d): if isinstance(op, AbstractConv3d):
if border_mode == 'half' or subsample != (1, 1, 1): if border_mode == 'half' or subsample != (1, 1, 1) or num_groups > 1:
return None return None
if border_mode == 'full': if border_mode == 'full':
direction_hint = 'bprop inputs' direction_hint = 'bprop inputs'
...@@ -3292,7 +3293,7 @@ def local_abstractconv3d_cudnn_alt(node): ...@@ -3292,7 +3293,7 @@ def local_abstractconv3d_cudnn_alt(node):
elif isinstance(op, AbstractConv3d_gradWeights): elif isinstance(op, AbstractConv3d_gradWeights):
if(border_mode == 'valid' and subsample == (1, 1, 1) and if(border_mode == 'valid' and subsample == (1, 1, 1) and
filter_dilation == (1, 1, 1)): filter_dilation == (1, 1, 1) and num_groups == 1):
img = gpu_contiguous(inp1) img = gpu_contiguous(inp1)
topgrad = gpu_contiguous(inp2) topgrad = gpu_contiguous(inp2)
ctx_name = infer_context_name(img, topgrad) ctx_name = infer_context_name(img, topgrad)
...@@ -3323,7 +3324,7 @@ def local_abstractconv3d_cudnn_alt(node): ...@@ -3323,7 +3324,7 @@ def local_abstractconv3d_cudnn_alt(node):
return None return None
elif isinstance(op, AbstractConv3d_gradInputs): elif isinstance(op, AbstractConv3d_gradInputs):
if border_mode == 'valid' and subsample == (1, 1, 1): if border_mode == 'valid' and subsample == (1, 1, 1) and num_groups == 1:
kerns = gpu_contiguous(inp1.dimshuffle(1, 0, 2, 3, 4)) kerns = gpu_contiguous(inp1.dimshuffle(1, 0, 2, 3, 4))
topgrad = gpu_contiguous(inp2) topgrad = gpu_contiguous(inp2)
ctx_name = infer_context_name(kerns, topgrad) ctx_name = infer_context_name(kerns, topgrad)
......
...@@ -1842,8 +1842,10 @@ def local_abstractconv3d_alt(node): ...@@ -1842,8 +1842,10 @@ def local_abstractconv3d_alt(node):
border_mode = node.op.border_mode border_mode = node.op.border_mode
subsample = node.op.subsample subsample = node.op.subsample
filter_dilation = node.op.filter_dilation filter_dilation = node.op.filter_dilation
num_groups = node.op.num_groups
if ((border_mode == 'full') and (subsample == (1, 1, 1))): if((border_mode == 'full') and (subsample == (1, 1, 1)) and
(num_groups == 1)):
if not node.op.filter_flip: if not node.op.filter_flip:
kern = kern[:, :, ::-1, ::-1, ::-1] kern = kern[:, :, ::-1, ::-1, ::-1]
kern = kern.dimshuffle(1, 0, 2, 3, 4) kern = kern.dimshuffle(1, 0, 2, 3, 4)
...@@ -1853,7 +1855,7 @@ def local_abstractconv3d_alt(node): ...@@ -1853,7 +1855,7 @@ def local_abstractconv3d_alt(node):
gpu_contiguous(kern), gpu_contiguous(img)) gpu_contiguous(kern), gpu_contiguous(img))
elif(subsample == (1, 1, 1) and filter_dilation == (1, 1, 1) and elif(subsample == (1, 1, 1) and filter_dilation == (1, 1, 1) and
border_mode == 'valid'): border_mode == 'valid' and num_groups == 1):
if node.op.filter_flip: if node.op.filter_flip:
kern = kern[:, :, ::-1, ::-1, ::-1] kern = kern[:, :, ::-1, ::-1, ::-1]
rval = GpuCorr3dMM_gradWeights(border_mode, rval = GpuCorr3dMM_gradWeights(border_mode,
...@@ -1881,8 +1883,10 @@ def local_abstractconv3d2d(node): ...@@ -1881,8 +1883,10 @@ def local_abstractconv3d2d(node):
border_mode = node.op.border_mode border_mode = node.op.border_mode
subsample = node.op.subsample subsample = node.op.subsample
filter_dilation = node.op.filter_dilation filter_dilation = node.op.filter_dilation
num_groups = node.op.num_groups
if subsample == (1, 1, 1) and filter_dilation == (1, 1, 1): if(subsample == (1, 1, 1) and filter_dilation == (1, 1, 1) and
num_groups == 1):
reorder_array = [0, 2, 1, 3, 4] reorder_array = [0, 2, 1, 3, 4]
rval = conv3d2d.conv3d(gpu_contiguous(img.dimshuffle(*reorder_array)), rval = conv3d2d.conv3d(gpu_contiguous(img.dimshuffle(*reorder_array)),
gpu_contiguous(kern.dimshuffle(*reorder_array)), gpu_contiguous(kern.dimshuffle(*reorder_array)),
...@@ -1968,8 +1972,10 @@ def local_abstractconv3d_gemm_gradweights_alt(node): ...@@ -1968,8 +1972,10 @@ def local_abstractconv3d_gemm_gradweights_alt(node):
border_mode = node.op.border_mode border_mode = node.op.border_mode
subsample = node.op.subsample subsample = node.op.subsample
filter_dilation = node.op.filter_dilation filter_dilation = node.op.filter_dilation
num_groups = node.op.num_groups
if border_mode == 'valid' and subsample == (1, 1, 1) and filter_dilation == (1, 1, 1): if(border_mode == 'valid' and subsample == (1, 1, 1) and
filter_dilation == (1, 1, 1) and num_groups == 1):
rval = GpuCorr3dMM(border_mode, rval = GpuCorr3dMM(border_mode,
subsample, subsample,
filter_dilation)( filter_dilation)(
...@@ -2091,8 +2097,10 @@ def local_abstractconv3d_gradinputs_gemm_alt(node): ...@@ -2091,8 +2097,10 @@ def local_abstractconv3d_gradinputs_gemm_alt(node):
border_mode = node.op.border_mode border_mode = node.op.border_mode
subsample = node.op.subsample subsample = node.op.subsample
filter_dilation = node.op.filter_dilation filter_dilation = node.op.filter_dilation
num_groups = node.op.num_groups
if border_mode == 'valid' and subsample == (1, 1, 1): if(border_mode == 'valid' and subsample == (1, 1, 1) and
num_groups == 1):
if not node.op.filter_flip: if not node.op.filter_flip:
kern = kern[:, :, ::-1, ::-1, ::-1] kern = kern[:, :, ::-1, ::-1, ::-1]
rval = GpuCorr3dMM(border_mode='full', rval = GpuCorr3dMM(border_mode='full',
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论