提交 9cfec67e authored 作者: affanv14's avatar affanv14

add alternative implementation for gradweights

上级 3248ce74
...@@ -1759,6 +1759,37 @@ def local_abstractconv_gradweights_gemm(node): ...@@ -1759,6 +1759,37 @@ def local_abstractconv_gradweights_gemm(node):
return [rval] return [rval]
@local_optimizer([AbstractConv2d_gradWeights])
def local_abstractconv_gemm_gradweights_alt(node):
if not isinstance(node.op, AbstractConv2d_gradWeights):
return None
img, topgrad, shape = node.inputs
if not isinstance(img.type, GpuArrayType) or \
not isinstance(topgrad.type, GpuArrayType):
return None
ctx = infer_context_name(img, topgrad)
border_mode = node.op.border_mode
subsample = node.op.subsample
filter_dilation = node.op.filter_dilation
if border_mode == 'valid' and subsample == (1, 1) and filter_dilation == (1, 1):
rval = GpuCorrMM(border_mode,
subsample,
filter_dilation)(
gpu_contiguous(img.dimshuffle(1, 0, 2, 3)),
gpu_contiguous(topgrad.dimshuffle(1, 0, 2, 3)))
if node.op.filter_flip:
rval = rval[:, :, ::-1, ::-1]
rval = rval.dimshuffle(1, 0, 2, 3)
rval = tensor.patternbroadcast(rval, node.outputs[0].broadcastable)
rval = as_gpuarray_variable(rval, context_name=ctx)
return [rval]
else:
return None
@local_optimizer([AbstractConv3d_gradWeights]) @local_optimizer([AbstractConv3d_gradWeights])
def local_abstractconv3d_gradweights_gemm(node): def local_abstractconv3d_gradweights_gemm(node):
if not isinstance(node.op, AbstractConv3d_gradWeights): if not isinstance(node.op, AbstractConv3d_gradWeights):
...@@ -2511,6 +2542,7 @@ if config.optimizer_excluding: ...@@ -2511,6 +2542,7 @@ if config.optimizer_excluding:
conv_metaopt.register(abstractconv_groupopt.query(*running_list).opts) conv_metaopt.register(abstractconv_groupopt.query(*running_list).opts)
conv_metaopt.register([local_abstractconv_gemm_alternative]) conv_metaopt.register([local_abstractconv_gemm_alternative])
conv_metaopt.register([local_abstractconv_gemm_gradweights_alt])
abstractconv_groupopt.register('conv_metaopt', conv_metaopt, 'conv_meta', position=0) abstractconv_groupopt.register('conv_metaopt', conv_metaopt, 'conv_meta', position=0)
# Register cuDNN batch normalization implementation # Register cuDNN batch normalization implementation
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论