提交 affc0f7f authored 作者: Nicolas Ballas's avatar Nicolas Ballas

Add optimization for the GpuCorr3dMM

上级 ae70df06
......@@ -26,7 +26,8 @@ from theano.sandbox.cuda.basic_ops import (
from theano.sandbox.cuda.type import CudaNdarrayType
from theano.sandbox.cuda.blas import (gpu_dot22, gpu_dot22scalar,
gpu_gemm_inplace, gpu_gemm_no_inplace, GpuConv,
GpuCorrMM, GpuCorrMM_gradInputs, GpuCorrMM_gradWeights)
GpuCorrMM, GpuCorrMM_gradInputs, GpuCorrMM_gradWeights,
GpuCorr3dMM, GpuCorr3dMM_gradInputs, GpuCorr3dMM_gradWeights)
from theano.sandbox.cuda.blas import gpu_gemv_inplace
from theano.sandbox.cuda.blas import gpu_gemv_no_inplace
from theano.sandbox.cuda.blas import gpu_ger_inplace
......@@ -1338,6 +1339,77 @@ def local_convtransp3d_fft(node):
gpu_optimizer.register("convtransp3d_fft", local_convtransp3d_fft)
@local_optimizer([Conv3D])
def local_conv3d_gemm(node):
if not isinstance(node.op, Conv3D):
return
try:
sx = tensor.get_scalar_constant_value(node.inputs[3][0])
sy = tensor.get_scalar_constant_value(node.inputs[3][1])
sz = tensor.get_scalar_constant_value(node.inputs[3][2])
except tensor.NotScalarConstantError:
return False
if isinstance(node.op, Conv3D):
# Shuffle inputs signal from (b, 0, 1, t, c) to (b, c, 0, 1, t)
x = node.inputs[0]
x = x.dimshuffle(0, 4, 1, 2, 3)
# Shuffle filters from (oc, 0, 1, t, ic) to (oc, ic, 0, 1, t)
f = node.inputs[1]
f = f.dimshuffle(0, 4, 1, 2, 3)
rval = GpuCorr3dMM(border_mode='valid', subsample=(sx, sy, sz))(x, f)
# Shuffle from (oc, c, 0, 1, t) to (oc, 0, 1, t, c)
return [rval.dimshuffle(0, 2, 3, 4, 1) + node.inputs[2]]
gpu_optimizer.register("conv3d_gemm", local_conv3d_gemm)
@local_optimizer([ConvGrad3D])
def local_convgrad3d_gemm(node):
try:
sx = tensor.get_scalar_constant_value(node.inputs[1][0])
sy = tensor.get_scalar_constant_value(node.inputs[1][1])
sz = tensor.get_scalar_constant_value(node.inputs[1][2])
except tensor.NotScalarConstantError:
return False
if isinstance(node.op, ConvGrad3D):
# Shuffle inputs signal from (b, 0, 1, t, c) to (b, c, 0, 1, t)
x = node.inputs[0]
x = gpu_contiguous(x.dimshuffle(0, 4, 1, 2, 3))
# Shuffle dCdH from (b, 0, 1, t, oc) to (oc, b, 0, 1, t)
f = node.input[3]
f = gpu_contiguous(f.dimshuffle(0, 4, 1, 2, 3))
f = node.inputs[3]
f = f.dimshuffle(4, 0, 1, 2, 3)
rval = Gpucorr3dMM_gradWeights(subsample=(sx, sy, sz))(x, f,
shape=node.inputs[2])
# Shuffle from (ic, oc, 0, 1, t) to (oc, 0, 1, t, ic)
return [rval.dimshuffle(0, 2, 3, 4, 1)]
gpu_optimizer.register("convgrad3d_gemm", local_convgrad3d_gemm)
@local_optimizer([ConvTransp3D])
def local_convtransp3d_gemm(node):
try:
sx = tensor.get_scalar_constant_value(node.inputs[2][0])
sy = tensor.get_scalar_constant_value(node.inputs[2][1])
sz = tensor.get_scalar_constant_value(node.inputs[2][2])
except tensor.NotScalarConstantError:
return False
if isinstance(node.op, ConvTransp3D) and (sx, sy, sz) == (1, 1, 1):
# Shuffle filters from (oc, 0, 1, t, ic) to (ic, oc, 0, 1, t)
x = node.inputs[0]
x = gpu_contiguous(x.dimshuffle(0, 4, 1, 2, 3))
# Shuffle dCdH from (b, 0, 1, t, oc) to (b, oc, 0, 1, t)
f = node.inputs[3]
f = gpu_contiguous(f.dimshuffle(0, 4, 1, 2, 3))
# filter flip
rval = GpuCorr3DMM(border_mode='full', subsample=(sx, sy, sz))(f, x)
# Shuffle from (ic, b, 0, 1, t) to (b, 0, 1, t, ic)
return [rval.dimshuffle(0, 2, 3, 4, 1) + node.inputs[1]]
gpu_optimizer.register("convtransp3d_gemm", local_convtransp3d_gemm)
import theano.tensor.signal.downsample as downsample
......
......@@ -33,7 +33,7 @@ class TestCorr3DMM(unittest.TestCase):
b=bias, d=subsample)
conv = GpuCorr3dMM(border_mode = "valid",
subsample=subsample)(inputs.dimshuffle(0, 4, 1, 2, 3),
filters.dimshuffle(0, 4, 1, 2, 3))
filters.dimshuffle(0, 4, 1, 2, 3))
conv = conv.dimshuffle(0, 2, 3, 4, 1)
f_ref = theano.function([], conv_ref)
......@@ -67,7 +67,6 @@ class TestCorr3DMM(unittest.TestCase):
def run_gradweight(self, inputs_shape, filters_shape, dCdH_shape,
subsample=(1, 1, 1)):
inputs_val = numpy.random.random(inputs_shape).astype('float32')
dCdH_val = numpy.random.random(dCdH_shape).astype('float32')
filters_val = numpy.random.random(filters_shape).astype('float32')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论