提交 65651b47 authored 作者: affanv14's avatar affanv14

added gradinputs

上级 76714ad0
...@@ -2888,39 +2888,71 @@ def local_abstractconv_cudnn(node): ...@@ -2888,39 +2888,71 @@ def local_abstractconv_cudnn(node):
return local_abstractconv3d_cudnn_graph(node.op, ctx, node.inputs, node.outputs) return local_abstractconv3d_cudnn_graph(node.op, ctx, node.inputs, node.outputs)
@local_optimizer([AbstractConv2d, AbstractConv2d_gradWeights, AbstractConv2d_gradInputs])
def local_abstractconv_cudnn_alternative(node):
    """Rewrite an abstract 2D convolution node with an alternative cuDNN op.

    For ``AbstractConv2d`` the node is replaced by a ``dnn_conv`` call whose
    ``direction_hint`` is picked from the border mode ('full' ->
    'bprop inputs', 'valid' -> 'bprop weights').  For
    ``AbstractConv2d_gradInputs`` with ``border_mode == 'valid'``, unit
    subsampling and a single group, the node is replaced by a ``GpuDnnConv``
    built with ``border_mode='full'``, the first two kernel axes swapped and
    the convolution mode flipped.

    Parameters
    ----------
    node
        The apply node being considered; its ``op`` is inspected and its
        first two inputs are used as the operands of the replacement.

    Returns
    -------
    list or None
        A one-element list holding the replacement output, or ``None`` when
        this rewrite does not apply (unsupported op, configuration or
        environment).
    """
    op = node.op
    if not isinstance(op, (AbstractConv2d, AbstractConv2d_gradWeights,
                           AbstractConv2d_gradInputs)):
        return None

    # Dilated filters are only rewritten when the reported version is at
    # least 6000 (presumably the cuDNN version -- confirm against `version`).
    if version(raises=False) < 6000 and op.filter_dilation != (1, 1):
        return None

    inp1 = node.inputs[0]
    inp2 = node.inputs[1]

    if not dnn_available(inp1.type.context_name):
        return None

    border_mode = op.border_mode
    subsample = op.subsample
    filter_dilation = op.filter_dilation
    num_groups = op.num_groups
    precision = get_precision(None, [inp1, inp2])

    conv_mode = 'conv' if op.filter_flip else 'cross'

    if isinstance(op, AbstractConv2d):
        if border_mode == 'full':
            direction_hint = 'bprop inputs'
        elif border_mode == 'valid':
            direction_hint = 'bprop weights'
        else:
            return None

        rval = dnn_conv(inp1, inp2,
                        border_mode=border_mode,
                        subsample=subsample,
                        dilation=filter_dilation,
                        direction_hint=direction_hint,
                        conv_mode=conv_mode,
                        num_groups=num_groups)

    elif isinstance(op, AbstractConv2d_gradInputs):
        if not (border_mode == 'valid' and subsample == (1, 1) and
                num_groups == 1):
            # No alternative formulation for this configuration; decline
            # rather than reach the final return with `rval` unbound.
            return None

        # Swap the first two kernel axes and flip the convolution mode so
        # that a 'full' forward convolution of the top gradient with the
        # kernels produces the result.
        kerns = gpu_contiguous(inp1.dimshuffle(1, 0, 2, 3))
        topgrad = gpu_contiguous(inp2)
        ctx_name = infer_context_name(kerns, topgrad)
        conv_mode = 'cross' if conv_mode == 'conv' else 'conv'
        desc = GpuDnnConvDesc(border_mode='full',
                              subsample=subsample,
                              dilation=filter_dilation,
                              conv_mode=conv_mode,
                              precision=precision)(kerns.shape)

        tshape = [shape_i_op(i)(topgrad) for i in range(topgrad.ndim)]
        kshape = [shape_i_op(i)(kerns) for i in range(kerns.ndim)]
        shape = get_conv_output_shape(tshape,
                                      kshape,
                                      border_mode='full',
                                      subsample=subsample,
                                      filter_dilation=filter_dilation)
        shape = assert_conv_shape(shape)
        out = GpuAllocEmpty(dtype=topgrad.dtype, context_name=ctx_name)(*shape)
        rval = GpuDnnConv(algo=None, num_groups=num_groups)(topgrad, kerns, out, desc)

    else:
        # AbstractConv2d_gradWeights is tracked by the decorator but has no
        # alternative implementation here; previously this path fell through
        # to `return [rval]` with `rval` never assigned.
        return None

    return [rval]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论