提交 a3192afc authored 作者: affanv14's avatar affanv14

add cudnn gradweights alternative

上级 b540b158
......@@ -3012,6 +3012,7 @@ def local_abstractconv3d_cudnn_alternative(node):
border_mode = node.op.border_mode
subsample = node.op.subsample
filter_dilation = node.op.filter_dilation
precision = get_precision(None, [inp1, inp2])
if node.op.filter_flip:
conv_mode = 'conv'
......@@ -3034,7 +3035,41 @@ def local_abstractconv3d_cudnn_alternative(node):
dilation=filter_dilation,
direction_hint=direction_hint,
conv_mode=conv_mode)
return rval
if isinstance(op, AbstractConv3d_gradWeights):
if(border_mode == 'valid' and subsample == (1, 1, 1) and
filter_dilation == (1, 1, 1)):
img = gpu_contiguous(inp1)
topgrad = gpu_contiguous(inp2)
ctx_name = infer_context_name(img, topgrad)
img = gpu_contiguous(img.dimshuffle(1, 0, 2, 3, 4))
topgrad = gpu_contiguous(topgrad.dimshuffle(1, 0, 2, 3, 4))
ishape = [shape_i_op(i)(img) for i in range(img.ndim)]
tshape = [shape_i_op(i)(topgrad) for i in range(topgrad.ndim)]
out_shp = get_conv_output_shape(ishape,
tshape,
border_mode=border_mode,
subsample=subsample,
filter_dilation=filter_dilation)
out_shp = assert_conv_shape(out_shp)
out = GpuAllocEmpty(dtype=img.dtype, context_name=ctx_name)(*out_shp)
desc = GpuDnnConvDesc(border_mode=border_mode,
subsample=subsample,
dilation=filter_dilation,
conv_mode='cross',
precision=precision)(out.shape)
conv = GpuDnnConv(algo=None)(img, topgrad, out, desc)
if conv_mode == 'conv':
conv = conv[:, :, ::-1, ::-1, ::-1]
rval = as_gpuarray_variable(conv.dimshuffle(1, 0, 2, 3, 4), ctx_name)
else:
return None
return [rval]
@local_optimizer([AbstractConv2d_gradWeights, AbstractConv3d_gradWeights])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论