提交 5655dee6 authored 作者: affanv14's avatar affanv14

add forward optimizers for meta-optimizer

上级 d30f18dd
...@@ -1653,6 +1653,29 @@ def local_abstractconv_gemm(node): ...@@ -1653,6 +1653,29 @@ def local_abstractconv_gemm(node):
return [rval] return [rval]
# CorrMM opt used for Meta-optimizer
@local_optimizer([AbstractConv2d])
def local_abstractconv_gemm_def(node):
if not isinstance(node.op, AbstractConv2d):
return None
img, kern = node.inputs
if (not isinstance(img.type, GpuArrayType) or
not isinstance(kern.type, GpuArrayType)):
return None
border_mode = node.op.border_mode
subsample = node.op.subsample
filter_dilation = node.op.filter_dilation
if node.op.filter_flip:
kern = kern[:, :, ::-1, ::-1]
rval = GpuCorrMM(border_mode,
subsample,
filter_dilation,
node.op.num_groups)(gpu_contiguous(img),
gpu_contiguous(kern))
return [rval]
@local_optimizer([AbstractConv2d]) @local_optimizer([AbstractConv2d])
def local_abstractconv_gemm_alt(node): def local_abstractconv_gemm_alt(node):
if not isinstance(node.op, AbstractConv2d): if not isinstance(node.op, AbstractConv2d):
...@@ -1768,6 +1791,30 @@ def local_abstractconv3d_gemm(node): ...@@ -1768,6 +1791,30 @@ def local_abstractconv3d_gemm(node):
return [rval] return [rval]
# Corr3dMM opt used for Meta-optimizer
@local_optimizer([AbstractConv3d])
def local_abstractconv3d_gemm_def(node):
if not isinstance(node.op, AbstractConv3d):
return None
img, kern = node.inputs
if (not isinstance(img.type, GpuArrayType) or
not isinstance(kern.type, GpuArrayType)):
return None
border_mode = node.op.border_mode
subsample = node.op.subsample
filter_dilation = node.op.filter_dilation
if node.op.filter_flip:
kern = kern[:, :, ::-1, ::-1, ::-1]
# By default use GpuCorr3dMM
rval = GpuCorr3dMM(border_mode,
subsample,
filter_dilation,
node.op.num_groups)(gpu_contiguous(img),
gpu_contiguous(kern))
return [rval]
@local_optimizer([AbstractConv3d]) @local_optimizer([AbstractConv3d])
def local_abstractconv3d_alt(node): def local_abstractconv3d_alt(node):
if not isinstance(node.op, AbstractConv3d): if not isinstance(node.op, AbstractConv3d):
...@@ -2745,9 +2792,9 @@ conv_metaopt.register(local_abstractconv_gw_cudnn, ...@@ -2745,9 +2792,9 @@ conv_metaopt.register(local_abstractconv_gw_cudnn,
['default', 'cudnn', 'conv_dnn']) ['default', 'cudnn', 'conv_dnn'])
conv_metaopt.register(local_abstractconv_gi_cudnn, conv_metaopt.register(local_abstractconv_gi_cudnn,
['default', 'cudnn', 'conv_dnn']) ['default', 'cudnn', 'conv_dnn'])
conv_metaopt.register(local_abstractconv_gemm, conv_metaopt.register(local_abstractconv_gemm_def,
['default', 'conv_gemm']) ['default', 'conv_gemm'])
conv_metaopt.register(local_abstractconv3d_gemm, conv_metaopt.register(local_abstractconv3d_gemm_def,
['default', 'conv_gemm']) ['default', 'conv_gemm'])
conv_metaopt.register(local_abstractconv_gradweights_gemm, conv_metaopt.register(local_abstractconv_gradweights_gemm,
['default', 'conv_gemm']) ['default', 'conv_gemm'])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论