提交 8da3449a authored 作者: f0k's avatar f0k

Fix GpuDnnConvGradI not being inserted automatically

上级 8cb9d50e
......@@ -1115,7 +1115,7 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
If border_mode is 'valid', subsample is (1,1) and direction_hint is
'bprop weights', it will use GpuDnnConvGradW.
If border_mode is 'full', subsample is (1,1) and direction_hint is
*not* 'forward!', it will use GpuDnnConvGradI.
'bprop inputs', it will use GpuDnnConvGradI.
This parameter is used internally by graph optimizers and may be
removed at any time without a deprecation period. You have been warned.
:param workmem: *deprecated*, use param algo instead
......@@ -1138,7 +1138,7 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
algo = workmem
# Ensure the value of direction_hint is supported
assert direction_hint in [None, 'bprop weights', 'forward']
assert direction_hint in [None, 'bprop weights', 'bprop inputs', 'forward']
fgraph = getattr(img, 'fgraph', None) or getattr(kerns, 'fgraph', None)
if (border_mode == 'valid' and subsample == (1, 1) and
......@@ -1161,12 +1161,10 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
return as_cuda_ndarray_variable(conv.dimshuffle(1, 0, 2, 3))
elif (border_mode == 'full' and subsample == (1, 1) and
direction_hint != 'forward!' and version() == -1):
# Special case: In CuDNN v1, we can be faster by using GpuDnnConvGradI
# to compute the full convolution as the backward pass of a valid
# convolution. We just need to set up a suitable 'fake' valid
# convolution.
img = gpu_contiguous(img) # cudnn v1 and v2 rc3 need contiguous data
direction_hint == 'bprop inputs'):
# Special case: We are asked to use GpuDnnConvGradI. We need to set
# up a suitable 'fake' convolution to compute the gradient for.
img = gpu_contiguous(img)
kerns = gpu_contiguous(kerns.dimshuffle(1, 0, 2, 3))
conv_mode = 'cross' if conv_mode == 'conv' else 'conv'
shape2 = shape_i(img, 2, fgraph) + shape_i(kerns, 2, fgraph) - 1
......@@ -2049,8 +2047,11 @@ if True:
direction_hint = node.op.direction_hint
if border_mode == 'full':
# for a full convolution, try using the forward pass instead
# of the backward pass wrt. inputs
direction_hint = 'forward!'
# of the backward pass wrt. inputs and vice versa
if direction_hint == 'bprop inputs':
direction_hint = 'forward'
else:
direction_hint = 'bprop inputs'
elif border_mode == 'valid':
# for a valid convolution, try using the backward pass wrt.
# weights instead of the forward pass and vice versa
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论