Commit 289c3bd4 authored by Gijs van Tulder

Introduce AbstractConv3D and related changes.

Add abstract base convolution classes and reuse them for both 2D and 3D.
Parent d3fb7189
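The commit wires the new abstract 3D interface through the same optimization pipeline already used for 2D. A minimal usage sketch, assuming the conv3d helper exposed by abstract_conv mirrors conv2d's signature, with 5D inputs laid out as (batch, channels, dim1, dim2, dim3):

import numpy as np
import theano
import theano.tensor as T
# conv3d is the abstract interface introduced alongside AbstractConv3d
from theano.tensor.nnet.abstract_conv import conv3d

x = T.tensor5('x')  # (batch, channels, dim1, dim2, dim3)
w = T.tensor5('w')  # (out_channels, in_channels, dim1, dim2, dim3)

# Builds an AbstractConv3d node; the optimizations registered below
# later replace it with a concrete implementation (e.g. conv3D on CPU).
y = conv3d(x, w, border_mode='valid', subsample=(1, 1, 1))

f = theano.function([x, w], y)
out = f(np.random.rand(2, 3, 8, 8, 8).astype(theano.config.floatX),
        np.random.rand(4, 3, 3, 3, 3).astype(theano.config.floatX))
print(out.shape)  # (2, 4, 6, 6, 6) for a valid convolution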
@@ -87,7 +87,7 @@ from theano.tensor import slinalg
from theano.tensor.nnet.Conv3D import Conv3D
from theano.tests.breakpoint import PdbBreakpoint
from theano.tensor.nnet.abstract_conv import (BaseAbstractConv2d,
from theano.tensor.nnet.abstract_conv import (BaseAbstractConv,
AbstractConv2d,
AbstractConv2d_gradWeights,
AbstractConv2d_gradInputs)
@@ -2736,7 +2736,7 @@ def local_conv2d_gpu_conv(node):
if isinstance(node.op, GpuFromHost):
host_input = node.inputs[0]
if host_input.owner and isinstance(host_input.owner.op,
BaseAbstractConv2d):
BaseAbstractConv):
conv = host_input.owner.op
inps = list(host_input.owner.inputs)
@@ -2749,7 +2749,7 @@ def local_conv2d_gpu_conv(node):
out.tag.values_eq_approx = values_eq_approx_high_tol
return [out]
if isinstance(node.op, BaseAbstractConv2d):
if isinstance(node.op, BaseAbstractConv):
# conv(host_from_gpu) -> host_from_gpu(gpu_conv)
inp1 = node.inputs[0]
inp2 = node.inputs[1]
......
@@ -32,6 +32,7 @@ from .bn import batch_normalization
import warnings
from .abstract_conv import conv2d as abstract_conv2d
from .abstract_conv import conv3d as abstract_conv3d
def conv2d(input, filters, input_shape=None, filter_shape=None,
......
@@ -18,6 +18,9 @@ from theano.tensor.nnet.blocksparse import (
from theano.tensor.nnet.abstract_conv import (AbstractConv2d,
AbstractConv2d_gradWeights,
AbstractConv2d_gradInputs)
from theano.tensor.nnet.abstract_conv import (AbstractConv3d,
AbstractConv3d_gradWeights,
AbstractConv3d_gradInputs)
from theano.tensor.nnet.abstract_conv import get_conv_output_shape
from theano.tensor.opt import register_specialize_device
from theano.tensor import TensorType
@@ -25,6 +28,7 @@ from theano.tensor import opt
# Cpu implementation
from theano.tensor.nnet.conv import conv2d, ConvOp
from theano.tensor.nnet.Conv3D import conv3D
from theano.tensor.nnet.ConvGrad3D import convGrad3D
from theano.tensor.nnet.ConvTransp3D import convTransp3D
@@ -159,6 +163,37 @@ def local_conv2d_cpu(node):
return [rval]
@local_optimizer([AbstractConv3d])
def local_conv3d_cpu(node):
if not isinstance(node.op, AbstractConv3d):
return None
img, kern = node.inputs
if ((not isinstance(img.type, TensorType) or
not isinstance(kern.type, TensorType))):
return None
if node.op.border_mode not in ['valid', (0, 0, 0)]:
return None
if node.op.filter_dilation != (1, 1, 1):
return None
bias = theano.tensor.zeros_like(kern[:, 0, 0, 0, 0])
# need to flip the kernel if necessary (conv3D does not flip)
if node.op.filter_flip:
kern = kern[:, :, ::-1, ::-1, ::-1]
# conv3D expects shape (batch, row, column, time, channel)
img = img.dimshuffle(0, 2, 3, 4, 1)
kern = kern.dimshuffle(0, 2, 3, 4, 1)
rval = conv3D(img, kern, bias, node.op.subsample)
copy_stack_trace(node.outputs[0], rval)
rval = rval.dimshuffle(0, 4, 1, 2, 3)
return [rval]
@local_optimizer([AbstractConv2d_gradWeights])
def local_conv2d_gradweight_cpu(node):
if not isinstance(node.op, AbstractConv2d_gradWeights):
@@ -277,6 +312,39 @@ def local_conv2d_gradweight_cpu(node):
return [res]
@local_optimizer([AbstractConv3d_gradWeights])
def local_conv3d_gradweight_cpu(node):
if not isinstance(node.op, AbstractConv3d_gradWeights):
return None
img, topgrad, shape = node.inputs
if ((not isinstance(img.type, TensorType) or
not isinstance(topgrad.type, TensorType))):
return None
if node.op.border_mode not in ['valid', (0, 0, 0)]:
return None
if node.op.filter_dilation != (1, 1, 1):
return None
# conv3D expects shape (batch, row, column, time, channel)
img = img.dimshuffle(0, 2, 3, 4, 1)
topgrad = topgrad.dimshuffle(0, 2, 3, 4, 1)
W_shape = (topgrad.shape[4], shape[0], shape[1], shape[2], img.shape[4])
rval = convGrad3D(img, node.op.subsample, W_shape, topgrad)
copy_stack_trace(node.outputs[0], rval)
rval = rval.dimshuffle(0, 4, 1, 2, 3)
# need to flip the kernel if necessary (conv3D does not flip)
if node.op.filter_flip:
rval = rval[:, :, ::-1, ::-1, ::-1]
rval = theano.tensor.patternbroadcast(rval,
node.outputs[0].broadcastable)
return [rval]
@local_optimizer([AbstractConv2d_gradInputs])
def local_conv2d_gradinputs_cpu(node):
if not isinstance(node.op, AbstractConv2d_gradInputs):
@@ -366,6 +434,38 @@ def local_conv2d_gradinputs_cpu(node):
return [din]
@local_optimizer([AbstractConv3d_gradInputs])
def local_conv3d_gradinputs_cpu(node):
if not isinstance(node.op, AbstractConv3d_gradInputs):
return None
kern, topgrad, shape = node.inputs
if ((not isinstance(kern.type, TensorType) or
not isinstance(topgrad.type, TensorType))):
return None
if node.op.border_mode not in ['valid', (0, 0, 0)]:
return None
if node.op.filter_dilation != (1, 1, 1):
return None
# need to flip the kernel if necessary (conv3D does not flip)
if node.op.filter_flip:
kern = kern[:, :, ::-1, ::-1, ::-1]
# conv3D expects shape (batch, row, column, time, channel)
kern = kern.dimshuffle(0, 2, 3, 4, 1)
topgrad = topgrad.dimshuffle(0, 2, 3, 4, 1)
bias = theano.tensor.zeros_like(kern[0, 0, 0, 0, :])
rval = convTransp3D(kern, bias, node.op.subsample, topgrad, shape)
copy_stack_trace(node.outputs[0], rval)
rval = rval.dimshuffle(0, 4, 1, 2, 3)
rval = theano.tensor.patternbroadcast(rval,
node.outputs[0].broadcastable)
return [rval]
# Register Cpu Optimization
conv_groupopt = theano.gof.optdb.LocalGroupDB()
conv_groupopt.__name__ = "conv_opts"
@@ -390,16 +490,30 @@ conv_groupopt.register('local_conv2d_gradweight_cpu',
conv_groupopt.register('local_conv2d_gradinputs_cpu',
local_conv2d_gradinputs_cpu, 40,
'fast_compile', 'fast_run')
conv_groupopt.register('local_conv3d_cpu', local_conv3d_cpu, 40,
'fast_compile', 'fast_run')
conv_groupopt.register('local_conv3d_gradweight_cpu',
local_conv3d_gradweight_cpu, 40,
'fast_compile', 'fast_run')
conv_groupopt.register('local_conv3d_gradinputs_cpu',
local_conv3d_gradinputs_cpu, 40,
'fast_compile', 'fast_run')
# Verify that no AbstractConv are present in the graph
@local_optimizer([AbstractConv2d,
AbstractConv2d_gradWeights,
AbstractConv2d_gradInputs])
AbstractConv2d_gradInputs,
AbstractConv3d,
AbstractConv3d_gradWeights,
AbstractConv3d_gradInputs])
def local_abstractconv_check(node):
if isinstance(node.op, (AbstractConv2d,
AbstractConv2d_gradWeights,
AbstractConv2d_gradInputs)):
AbstractConv2d_gradInputs,
AbstractConv3d,
AbstractConv3d_gradWeights,
AbstractConv3d_gradInputs)):
raise AssertionError(
'%s Theano optimization failed: there is no implementation '
'available supporting the requested options. Did you exclude '
......
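For reference, the CPU lowering above hinges on a layout change: the abstract 3D ops use (batch, channel, dim1, dim2, dim3), while the legacy conv3D/convGrad3D/convTransp3D ops expect (batch, row, column, time, channel). A minimal sketch of the dimshuffle conversion used in those optimizers (illustrative only):

import numpy as np
import theano
import theano.tensor as T

img = T.tensor5('img')                       # abstract layout: (batch, channel, d1, d2, d3)
img_legacy = img.dimshuffle(0, 2, 3, 4, 1)   # legacy layout:   (batch, d1, d2, d3, channel)
# the legacy op's result is shuffled back with .dimshuffle(0, 4, 1, 2, 3)

f = theano.function([img], img_legacy)
print(f(np.zeros((2, 3, 4, 5, 6), dtype=theano.config.floatX)).shape)  # (2, 4, 5, 6, 3)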