Commit 73b55c61 authored by carriepl

Merge pull request #2114 from ballasn/Corr3DMM

Add 3d correlation based on blas matrix multiplication
......@@ -123,6 +123,21 @@ TODO: Give examples on how to use these things! They are pretty complicated.
f = theano.function(..., mode=mode)
- :func:`GpuCorr3dMM <theano.sandbox.cuda.blas.GpuCorr3dMM>`
This is a GPU-only 3d correlation relying on a Toeplitz matrix
and gemm implementation (see :func:`GpuCorrMM <theano.sandbox.cuda.blas.GpuCorrMM>`)
It needs extra memory for the Toeplitz matrix, which is a 2D matrix of shape
``(no of channels * filter width * filter height * filter depth, output width * output height * output depth)``.
As it provides a gradient, you can use it as a replacement for nnet.conv3d.
Alternatively, you can use nnet.conv3d and allow Theano's graph optimizer
to replace it by the GEMM version by setting
``THEANO_FLAGS=optimizer_including=conv3d_gemm:convgrad3d_gemm:convtransp3d_gemm`` in your environment.
This is not enabled by default because it uses some extra memory, but the
overhead is small compared to conv3d_fft, there are no restrictions on
input or kernel shapes and strides are supported. If using it,
please see the warning about a bug in CUDA 5.0 to 6.0
in :func:`GpuCorrMM <theano.sandbox.cuda.blas.GpuCorrMM>`.
- :func:`conv3d2d <theano.tensor.nnet.conv3d2d.conv3d>`
Another conv3d implementation that uses the conv2d with data reshaping.
It is faster in some cases than conv3d, and works on the GPU.
......
差异被折叠。
......@@ -294,6 +294,7 @@ CudaNdarray* corrMM(CudaNdarray *const bottom,
"This could be a known bug in CUDA, please see the "
"GpuCorrMM() documentation.\n",
cudaGetErrorString(err));
Py_DECREF(col);
return NULL;
}
// Second, gemm
......@@ -311,6 +312,7 @@ CudaNdarray* corrMM(CudaNdarray *const bottom,
"This could be a known bug in CUDA, please see the "
"GpuCorrMM() documentation.\n",
cublasGetErrorString(status));
Py_DECREF(col);
return NULL;
}
}
......@@ -359,6 +361,7 @@ CudaNdarray* corrMM(CudaNdarray *const bottom,
"This could be a known bug in CUDA, please see the "
"GpuCorrMM() documentation.\n",
cudaGetErrorString(err));
Py_DECREF(col);
return NULL;
}
// Second, gemm
......@@ -379,6 +382,7 @@ CudaNdarray* corrMM(CudaNdarray *const bottom,
"This could be a known bug in CUDA, please see the "
"GpuCorrMM() documentation.\n",
cublasGetErrorString(status));
Py_DECREF(col);
return NULL;
}
}
......@@ -429,6 +433,7 @@ CudaNdarray* corrMM(CudaNdarray *const bottom,
"This could be a known bug in CUDA, please see the "
"GpuCorrMM() documentation.\n",
cublasGetErrorString(status));
Py_DECREF(col);
return NULL;
}
// col2im back to the data
......@@ -441,6 +446,7 @@ CudaNdarray* corrMM(CudaNdarray *const bottom,
"This could be a known bug in CUDA, please see the "
"GpuCorrMM() documentation.\n",
cudaGetErrorString(err));
Py_DECREF(col);
return NULL;
}
}
......
......@@ -26,7 +26,8 @@ from theano.sandbox.cuda.basic_ops import (
from theano.sandbox.cuda.type import CudaNdarrayType
from theano.sandbox.cuda.blas import (gpu_dot22, gpu_dot22scalar,
gpu_gemm_inplace, gpu_gemm_no_inplace, GpuConv,
GpuCorrMM, GpuCorrMM_gradInputs, GpuCorrMM_gradWeights)
GpuCorrMM, GpuCorrMM_gradInputs, GpuCorrMM_gradWeights,
GpuCorr3dMM, GpuCorr3dMM_gradInputs, GpuCorr3dMM_gradWeights)
from theano.sandbox.cuda.blas import gpu_gemv_inplace
from theano.sandbox.cuda.blas import gpu_gemv_no_inplace
from theano.sandbox.cuda.blas import gpu_ger_inplace
......@@ -1338,6 +1339,74 @@ def local_convtransp3d_fft(node):
gpu_optimizer.register("convtransp3d_fft", local_convtransp3d_fft)
@local_optimizer([Conv3D])
def local_conv3d_gemm(node):
    """Replace a Conv3D node by the GEMM-based GpuCorr3dMM.

    Conv3D uses a channels-last layout — signals are (b, 0, 1, t, c)
    and filters (oc, 0, 1, t, ic) — while GpuCorr3dMM expects the
    channel axis second, so both operands are dimshuffled in and the
    result dimshuffled back out.

    Returns a one-element list with the replacement variable, ``None``
    for non-Conv3D nodes, or ``False`` when the strides are not
    compile-time constants (the rewrite is skipped in those cases).
    """
    if not isinstance(node.op, Conv3D):
        return
    try:
        # GpuCorr3dMM needs the strides as Python constants.
        sx = tensor.get_scalar_constant_value(node.inputs[3][0])
        sy = tensor.get_scalar_constant_value(node.inputs[3][1])
        sz = tensor.get_scalar_constant_value(node.inputs[3][2])
    except tensor.NotScalarConstantError:
        return False
    # The guard above already established node.op is a Conv3D, so no
    # second isinstance check is needed here.
    # Shuffle input signal from (b, 0, 1, t, c) to (b, c, 0, 1, t).
    x = node.inputs[0]
    x = x.dimshuffle(0, 4, 1, 2, 3)
    # Shuffle filters from (oc, 0, 1, t, ic) to (oc, ic, 0, 1, t).
    f = node.inputs[1]
    f = f.dimshuffle(0, 4, 1, 2, 3)
    rval = GpuCorr3dMM(border_mode='valid', subsample=(sx, sy, sz))(x, f)
    # Shuffle the result back to channels-last, presumably
    # (b, oc, 0, 1, t) -> (b, 0, 1, t, oc), and add the bias
    # term (node.inputs[2]).
    return [rval.dimshuffle(0, 2, 3, 4, 1) + node.inputs[2]]
gpu_optimizer.register("conv3d_gemm", local_conv3d_gemm)
@local_optimizer([ConvGrad3D])
def local_convgrad3d_gemm(node):
    """Rewrite ConvGrad3D (gradient w.r.t. the filters) as
    GpuCorr3dMM_gradWeights.

    Fires only when the three stride entries are compile-time
    constants; otherwise returns ``False`` and leaves the node alone.
    """
    strides = node.inputs[1]
    try:
        subsample = tuple(tensor.get_scalar_constant_value(strides[i])
                          for i in range(3))
    except tensor.NotScalarConstantError:
        return False
    if not isinstance(node.op, ConvGrad3D):
        return
    # Move the channel axis next to the batch axis:
    # signal (b, 0, 1, t, c) -> (b, c, 0, 1, t).
    bottom = gpu_contiguous(node.inputs[0].dimshuffle(0, 4, 1, 2, 3))
    # Same move for the output gradient dCdH:
    # (b, 0, 1, t, oc) -> (b, oc, 0, 1, t).
    topgrad = gpu_contiguous(node.inputs[3].dimshuffle(0, 4, 1, 2, 3))
    grad_w = GpuCorr3dMM_gradWeights(subsample=subsample)(
        bottom, topgrad, shape=node.inputs[2][1:4])
    # Restore the channels-last filter layout expected by ConvGrad3D
    # consumers.
    return [grad_w.dimshuffle(0, 2, 3, 4, 1)]
gpu_optimizer.register("convgrad3d_gemm", local_convgrad3d_gemm)
@local_optimizer([ConvTransp3D])
def local_convtransp3d_gemm(node):
    """Rewrite ConvTransp3D as GpuCorr3dMM_gradInputs.

    Only applies for unit strides with compile-time constant values;
    returns ``False`` when the strides are symbolic, ``None`` when the
    node does not match.
    """
    strides = node.inputs[2]
    try:
        subsample = tuple(tensor.get_scalar_constant_value(strides[i])
                          for i in range(3))
    except tensor.NotScalarConstantError:
        return False
    if not (isinstance(node.op, ConvTransp3D) and subsample == (1, 1, 1)):
        return
    # Filters: move the trailing channel axis to position 1,
    # (oc, 0, 1, t, ic) -> (oc, ic, 0, 1, t).
    kern = gpu_contiguous(node.inputs[0].dimshuffle(0, 4, 1, 2, 3))
    # Output gradient dCdH: (b, 0, 1, t, oc) -> (b, oc, 0, 1, t).
    topgrad = gpu_contiguous(node.inputs[3].dimshuffle(0, 4, 1, 2, 3))
    rval = GpuCorr3dMM_gradInputs(subsample=subsample)(kern=kern,
                                                       topgrad=topgrad)
    # Back to channels-last, then add the bias term (node.inputs[1]).
    return [rval.dimshuffle(0, 2, 3, 4, 1) + node.inputs[1]]
gpu_optimizer.register("convtransp3d_gemm", local_convtransp3d_gemm)
import theano.tensor.signal.downsample as downsample
......
差异被折叠。
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论