提交 8b25d693 authored 作者: João Victor Risso's avatar João Victor Risso

Move local_abstractconv_cudnn_graph function after spatial transformer implementation

上级 525c21c2
...@@ -2738,101 +2738,6 @@ def dnn_batch_normalization_test(inputs, gamma, beta, mean, var, ...@@ -2738,101 +2738,6 @@ def dnn_batch_normalization_test(inputs, gamma, beta, mean, var,
return result return result
def local_abstractconv_cudnn_graph(op, context_name, inputs, outputs):
    """Map an abstract 2d convolution op onto its cuDNN implementation.

    Returns a one-element list holding the cuDNN-based replacement output,
    or None when the op cannot be handled (not a 2d convolution op,
    dilation unsupported by the installed cuDNN, or cuDNN unavailable on
    the input's context).
    """
    conv2d_ops = (AbstractConv2d,
                  AbstractConv2d_gradWeights,
                  AbstractConv2d_gradInputs)
    if not isinstance(op, conv2d_ops):
        return
    # Dilated convolution needs cuDNN v6 or newer.
    if version(raises=False) < 6000 and op.filter_dilation != (1, 1):
        return None
    img, kern = inputs[0], inputs[1]
    if not dnn_available(img.type.context_name):
        return
    conv_mode = 'conv' if op.filter_flip else 'cross'
    if isinstance(op, AbstractConv2d):
        result = dnn_conv(img, kern,
                          border_mode=op.border_mode,
                          subsample=op.subsample,
                          dilation=op.filter_dilation,
                          direction_hint='forward!',
                          conv_mode=conv_mode,
                          num_groups=op.num_groups)
    elif isinstance(op, AbstractConv2d_gradWeights):
        # inputs[2] carries the spatial dimensions of the filter shape.
        out_shape = (kern.shape[1], img.shape[1],
                     inputs[2][0], inputs[2][1])
        result = dnn_gradweight(img, kern, out_shape,
                                border_mode=op.border_mode,
                                subsample=op.subsample,
                                dilation=op.filter_dilation,
                                conv_mode=conv_mode,
                                num_groups=op.num_groups)
    else:
        # AbstractConv2d_gradInputs: inputs[2] carries the spatial
        # dimensions of the image shape.
        out_shape = (kern.shape[0], img.shape[1],
                     inputs[2][0], inputs[2][1])
        result = dnn_gradinput(img, kern, out_shape,
                               border_mode=op.border_mode,
                               subsample=op.subsample,
                               dilation=op.filter_dilation,
                               conv_mode=conv_mode,
                               num_groups=op.num_groups)
    return [result]
def local_abstractconv3d_cudnn_graph(op, context_name, inputs, outputs):
    """Map an abstract 3d convolution op onto its cuDNN implementation.

    Returns a one-element list holding the cuDNN-based replacement output,
    or None when the op cannot be handled (not a 3d convolution op,
    dilation unsupported by the installed cuDNN, or cuDNN unavailable on
    the input's context).
    """
    conv3d_ops = (AbstractConv3d,
                  AbstractConv3d_gradWeights,
                  AbstractConv3d_gradInputs)
    if not isinstance(op, conv3d_ops):
        return
    # Dilated convolution needs cuDNN v6 or newer.
    if version(raises=False) < 6000 and op.filter_dilation != (1, 1, 1):
        return None
    img, kern = inputs[0], inputs[1]
    if not dnn_available(img.type.context_name):
        return
    conv_mode = 'conv' if op.filter_flip else 'cross'
    if isinstance(op, AbstractConv3d):
        result = dnn_conv3d(img, kern,
                            border_mode=op.border_mode,
                            subsample=op.subsample,
                            dilation=op.filter_dilation,
                            direction_hint='forward!',
                            conv_mode=conv_mode)
    elif isinstance(op, AbstractConv3d_gradWeights):
        # inputs[2] carries the spatial dimensions of the filter shape.
        out_shape = (kern.shape[1], img.shape[1],
                     inputs[2][0], inputs[2][1], inputs[2][2])
        result = dnn_gradweight3d(img, kern, out_shape,
                                  border_mode=op.border_mode,
                                  subsample=op.subsample,
                                  dilation=op.filter_dilation,
                                  conv_mode=conv_mode)
    else:
        # AbstractConv3d_gradInputs: inputs[2] carries the spatial
        # dimensions of the image shape.
        out_shape = (kern.shape[0], img.shape[1],
                     inputs[2][0], inputs[2][1], inputs[2][2])
        result = dnn_gradinput3d(img, kern, out_shape,
                                 border_mode=op.border_mode,
                                 subsample=op.subsample,
                                 dilation=op.filter_dilation,
                                 conv_mode=conv_mode)
    return [result]
class GpuDnnTransformerDesc(COp): class GpuDnnTransformerDesc(COp):
""" """
This Op builds a spatial transformer descriptor for use in spatial transformer network This Op builds a spatial transformer descriptor for use in spatial transformer network
...@@ -3116,6 +3021,101 @@ def dnn_spatialtf(img, theta, scale_width=1, scale_height=1, precision=theano.co ...@@ -3116,6 +3021,101 @@ def dnn_spatialtf(img, theta, scale_width=1, scale_height=1, precision=theano.co
return sampler return sampler
def local_abstractconv_cudnn_graph(op, context_name, inputs, outputs):
    """Map an abstract 2d convolution op onto its cuDNN implementation.

    Returns a one-element list holding the cuDNN-based replacement output,
    or None when the op cannot be handled (not a 2d convolution op,
    dilation unsupported by the installed cuDNN, or cuDNN unavailable on
    the input's context).
    """
    conv2d_ops = (AbstractConv2d,
                  AbstractConv2d_gradWeights,
                  AbstractConv2d_gradInputs)
    if not isinstance(op, conv2d_ops):
        return
    # Dilated convolution needs cuDNN v6 or newer.
    if version(raises=False) < 6000 and op.filter_dilation != (1, 1):
        return None
    img, kern = inputs[0], inputs[1]
    if not dnn_available(img.type.context_name):
        return
    conv_mode = 'conv' if op.filter_flip else 'cross'
    if isinstance(op, AbstractConv2d):
        result = dnn_conv(img, kern,
                          border_mode=op.border_mode,
                          subsample=op.subsample,
                          dilation=op.filter_dilation,
                          direction_hint='forward!',
                          conv_mode=conv_mode,
                          num_groups=op.num_groups)
    elif isinstance(op, AbstractConv2d_gradWeights):
        # inputs[2] carries the spatial dimensions of the filter shape.
        out_shape = (kern.shape[1], img.shape[1],
                     inputs[2][0], inputs[2][1])
        result = dnn_gradweight(img, kern, out_shape,
                                border_mode=op.border_mode,
                                subsample=op.subsample,
                                dilation=op.filter_dilation,
                                conv_mode=conv_mode,
                                num_groups=op.num_groups)
    else:
        # AbstractConv2d_gradInputs: inputs[2] carries the spatial
        # dimensions of the image shape.
        out_shape = (kern.shape[0], img.shape[1],
                     inputs[2][0], inputs[2][1])
        result = dnn_gradinput(img, kern, out_shape,
                               border_mode=op.border_mode,
                               subsample=op.subsample,
                               dilation=op.filter_dilation,
                               conv_mode=conv_mode,
                               num_groups=op.num_groups)
    return [result]
def local_abstractconv3d_cudnn_graph(op, context_name, inputs, outputs):
    """Map an abstract 3d convolution op onto its cuDNN implementation.

    Returns a one-element list holding the cuDNN-based replacement output,
    or None when the op cannot be handled (not a 3d convolution op,
    dilation unsupported by the installed cuDNN, or cuDNN unavailable on
    the input's context).
    """
    conv3d_ops = (AbstractConv3d,
                  AbstractConv3d_gradWeights,
                  AbstractConv3d_gradInputs)
    if not isinstance(op, conv3d_ops):
        return
    # Dilated convolution needs cuDNN v6 or newer.
    if version(raises=False) < 6000 and op.filter_dilation != (1, 1, 1):
        return None
    img, kern = inputs[0], inputs[1]
    if not dnn_available(img.type.context_name):
        return
    conv_mode = 'conv' if op.filter_flip else 'cross'
    if isinstance(op, AbstractConv3d):
        result = dnn_conv3d(img, kern,
                            border_mode=op.border_mode,
                            subsample=op.subsample,
                            dilation=op.filter_dilation,
                            direction_hint='forward!',
                            conv_mode=conv_mode)
    elif isinstance(op, AbstractConv3d_gradWeights):
        # inputs[2] carries the spatial dimensions of the filter shape.
        out_shape = (kern.shape[1], img.shape[1],
                     inputs[2][0], inputs[2][1], inputs[2][2])
        result = dnn_gradweight3d(img, kern, out_shape,
                                  border_mode=op.border_mode,
                                  subsample=op.subsample,
                                  dilation=op.filter_dilation,
                                  conv_mode=conv_mode)
    else:
        # AbstractConv3d_gradInputs: inputs[2] carries the spatial
        # dimensions of the image shape.
        out_shape = (kern.shape[0], img.shape[1],
                     inputs[2][0], inputs[2][1], inputs[2][2])
        result = dnn_gradinput3d(img, kern, out_shape,
                                 border_mode=op.border_mode,
                                 subsample=op.subsample,
                                 dilation=op.filter_dilation,
                                 conv_mode=conv_mode)
    return [result]
@local_optimizer([AbstractConv2d, AbstractConv3d]) @local_optimizer([AbstractConv2d, AbstractConv3d])
def local_abstractconv_cudnn(node): def local_abstractconv_cudnn(node):
ctx = infer_context_name(*node.inputs) ctx = infer_context_name(*node.inputs)
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论