提交 18ef99c3 authored 作者: Frederic's avatar Frederic

pep8

上级 3180ec4d
...@@ -2440,8 +2440,10 @@ if True: ...@@ -2440,8 +2440,10 @@ if True:
) )
return [out.dimshuffle(0, 1)] return [out.dimshuffle(0, 1)]
### AbstractConv Optimizations
@local_optimizer([AbstractConv2d, AbstractConv2d_gradWeights, AbstractConv2d_gradInputs]) # AbstractConv Optimizations
@local_optimizer([AbstractConv2d, AbstractConv2d_gradWeights,
AbstractConv2d_gradInputs])
def local_abstractconv_cudnn(node): def local_abstractconv_cudnn(node):
inp1 = node.inputs[0] inp1 = node.inputs[0]
inp2 = node.inputs[1] inp2 = node.inputs[1]
...@@ -2466,20 +2468,21 @@ def local_abstractconv_cudnn(node): ...@@ -2466,20 +2468,21 @@ def local_abstractconv_cudnn(node):
border_mode=node.op.border_mode, border_mode=node.op.border_mode,
subsample=node.op.subsample, subsample=node.op.subsample,
direction_hint='forward', direction_hint='forward',
conv_mode = conv_mode) conv_mode=conv_mode)
return [rval] return [rval]
if (isinstance(node.op, AbstractConv2d_gradWeights)): if (isinstance(node.op, AbstractConv2d_gradWeights)):
shape = (inp2.shape[1], inp1.shape[1], node.inputs[2][0], node.inputs[2][1]) shape = (inp2.shape[1], inp1.shape[1],
node.inputs[2][0], node.inputs[2][1])
rval = dnn_gradweight(inp1, inp2, shape, rval = dnn_gradweight(inp1, inp2, shape,
border_mode=node.op.border_mode, border_mode=node.op.border_mode,
subsample=node.op.subsample, subsample=node.op.subsample,
conv_mode = conv_mode) conv_mode=conv_mode)
return [rval] return [rval]
if (isinstance(node.op, AbstractConv2d_gradInputs)): if (isinstance(node.op, AbstractConv2d_gradInputs)):
shape = (inp2.shape[0], inp1.shape[1], node.inputs[2][0], node.inputs[2][1]) shape = (inp2.shape[0], inp1.shape[1],
node.inputs[2][0], node.inputs[2][1])
rval = dnn_gradinput(inp1, inp2, shape, rval = dnn_gradinput(inp1, inp2, shape,
border_mode=node.op.border_mode, border_mode=node.op.border_mode,
subsample=node.op.subsample, subsample=node.op.subsample,
conv_mode = conv_mode) conv_mode=conv_mode)
return [rval] return [rval]
...@@ -2788,7 +2788,8 @@ abstractconv_groupopt.__name__ = "gpu_abstractconv_opts" ...@@ -2788,7 +2788,8 @@ abstractconv_groupopt.__name__ = "gpu_abstractconv_opts"
register_specialize_device(abstractconv_groupopt, 'gpu', 'fast_compile') register_specialize_device(abstractconv_groupopt, 'gpu', 'fast_compile')
# cuDNN is first, but only registered if cuDNN is available. # cuDNN is first, but only registered if cuDNN is available.
conv_groupopt.register('local_abstractconv_dnn', dnn.local_abstractconv_cudnn, 20, conv_groupopt.register('local_abstractconv_dnn',
dnn.local_abstractconv_cudnn, 20,
'conv_dnn', 'conv_dnn',
'gpu', 'fast_compile', 'fast_run', 'cudnn') 'gpu', 'fast_compile', 'fast_run', 'cudnn')
# The GEMM-based convolution comes last to catch all remaining cases. # The GEMM-based convolution comes last to catch all remaining cases.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论