提交 0f4eb0e1 作者: affanv14

Add cuDNN alternative implementation for AbstractConv3d_gradInputs

上级 a3192afc
......@@ -3068,6 +3068,31 @@ def local_abstractconv3d_cudnn_alternative(node):
else:
return None
if isinstance(op, AbstractConv3d_gradInputs):
if border_mode == 'valid' and subsample == (1, 1, 1):
kerns = gpu_contiguous(inp1.dimshuffle(1, 0, 2, 3, 4))
topgrad = gpu_contiguous(inp2)
ctx_name = infer_context_name(kerns, topgrad)
conv_mode = 'cross' if conv_mode == 'conv' else 'conv'
desc = GpuDnnConvDesc(border_mode='full',
subsample=subsample,
dilation=filter_dilation,
conv_mode=conv_mode,
precision=precision)(kerns.shape)
tshape = [shape_i_op(i)(topgrad) for i in range(topgrad.ndim)]
kshape = [shape_i_op(i)(kerns) for i in range(kerns.ndim)]
shape = get_conv_output_shape(tshape,
kshape,
border_mode='full',
subsample=subsample,
filter_dilation=filter_dilation)
shape = assert_conv_shape(shape)
out = GpuAllocEmpty(dtype=topgrad.dtype, context_name=ctx_name)(*shape)
rval = GpuDnnConv(algo=None)(topgrad, kerns, out, desc)
else:
return None
return [rval]
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论