提交 73b55c61 authored 作者: carriepl's avatar carriepl

Merge pull request #2114 from ballasn/Corr3DMM

Add 3d correlation based on blas matrix multiplication
...@@ -123,6 +123,21 @@ TODO: Give examples on how to use these things! They are pretty complicated. ...@@ -123,6 +123,21 @@ TODO: Give examples on how to use these things! They are pretty complicated.
f = theano.function(..., mode=mode)
- :func:`GpuCorr3dMM <theano.sandbox.cuda.blas.GpuCorr3dMM>`
This is a GPU-only 3d correlation relying on a Toeplitz matrix
and gemm implementation (see :func:`GpuCorrMM <theano.sandbox.cuda.blas.GpuCorrMM>`).
It needs extra memory for the Toeplitz matrix, which is a 2D matrix of shape
``(no of channels * filter width * filter height * filter depth, output width * output height * output depth)``.
As it provides a gradient, you can use it as a replacement for nnet.conv3d.
Alternatively, you can use nnet.conv3d and allow Theano's graph optimizer
to replace it by the GEMM version by setting
``THEANO_FLAGS=optimizer_including=conv3d_gemm:convgrad3d_gemm:convtransp3d_gemm`` in your environment.
This is not enabled by default because it uses some extra memory, but the
overhead is small compared to conv3d_fft, there are no restrictions on
input or kernel shapes and strides are supported. If using it,
please see the warning about a bug in CUDA 5.0 to 6.0
in :func:`GpuCorrMM <theano.sandbox.cuda.blas.GpuCorrMM>`.
- :func:`conv3d2d <theano.tensor.nnet.conv3d2d.conv3d>`
Another conv3d implementation that uses conv2d with data reshaping.
It is faster in some cases than conv3d, and works on the GPU.
......
差异被折叠。
...@@ -294,6 +294,7 @@ CudaNdarray* corrMM(CudaNdarray *const bottom, ...@@ -294,6 +294,7 @@ CudaNdarray* corrMM(CudaNdarray *const bottom,
"This could be a known bug in CUDA, please see the " "This could be a known bug in CUDA, please see the "
"GpuCorrMM() documentation.\n", "GpuCorrMM() documentation.\n",
cudaGetErrorString(err)); cudaGetErrorString(err));
Py_DECREF(col);
return NULL; return NULL;
} }
// Second, gemm // Second, gemm
...@@ -311,6 +312,7 @@ CudaNdarray* corrMM(CudaNdarray *const bottom, ...@@ -311,6 +312,7 @@ CudaNdarray* corrMM(CudaNdarray *const bottom,
"This could be a known bug in CUDA, please see the " "This could be a known bug in CUDA, please see the "
"GpuCorrMM() documentation.\n", "GpuCorrMM() documentation.\n",
cublasGetErrorString(status)); cublasGetErrorString(status));
Py_DECREF(col);
return NULL; return NULL;
} }
} }
...@@ -359,6 +361,7 @@ CudaNdarray* corrMM(CudaNdarray *const bottom, ...@@ -359,6 +361,7 @@ CudaNdarray* corrMM(CudaNdarray *const bottom,
"This could be a known bug in CUDA, please see the " "This could be a known bug in CUDA, please see the "
"GpuCorrMM() documentation.\n", "GpuCorrMM() documentation.\n",
cudaGetErrorString(err)); cudaGetErrorString(err));
Py_DECREF(col);
return NULL; return NULL;
} }
// Second, gemm // Second, gemm
...@@ -379,6 +382,7 @@ CudaNdarray* corrMM(CudaNdarray *const bottom, ...@@ -379,6 +382,7 @@ CudaNdarray* corrMM(CudaNdarray *const bottom,
"This could be a known bug in CUDA, please see the " "This could be a known bug in CUDA, please see the "
"GpuCorrMM() documentation.\n", "GpuCorrMM() documentation.\n",
cublasGetErrorString(status)); cublasGetErrorString(status));
Py_DECREF(col);
return NULL; return NULL;
} }
} }
...@@ -429,6 +433,7 @@ CudaNdarray* corrMM(CudaNdarray *const bottom, ...@@ -429,6 +433,7 @@ CudaNdarray* corrMM(CudaNdarray *const bottom,
"This could be a known bug in CUDA, please see the " "This could be a known bug in CUDA, please see the "
"GpuCorrMM() documentation.\n", "GpuCorrMM() documentation.\n",
cublasGetErrorString(status)); cublasGetErrorString(status));
Py_DECREF(col);
return NULL; return NULL;
} }
// col2im back to the data // col2im back to the data
...@@ -441,6 +446,7 @@ CudaNdarray* corrMM(CudaNdarray *const bottom, ...@@ -441,6 +446,7 @@ CudaNdarray* corrMM(CudaNdarray *const bottom,
"This could be a known bug in CUDA, please see the " "This could be a known bug in CUDA, please see the "
"GpuCorrMM() documentation.\n", "GpuCorrMM() documentation.\n",
cudaGetErrorString(err)); cudaGetErrorString(err));
Py_DECREF(col);
return NULL; return NULL;
} }
} }
......
...@@ -26,7 +26,8 @@ from theano.sandbox.cuda.basic_ops import ( ...@@ -26,7 +26,8 @@ from theano.sandbox.cuda.basic_ops import (
from theano.sandbox.cuda.type import CudaNdarrayType from theano.sandbox.cuda.type import CudaNdarrayType
from theano.sandbox.cuda.blas import (gpu_dot22, gpu_dot22scalar, from theano.sandbox.cuda.blas import (gpu_dot22, gpu_dot22scalar,
gpu_gemm_inplace, gpu_gemm_no_inplace, GpuConv, gpu_gemm_inplace, gpu_gemm_no_inplace, GpuConv,
GpuCorrMM, GpuCorrMM_gradInputs, GpuCorrMM_gradWeights) GpuCorrMM, GpuCorrMM_gradInputs, GpuCorrMM_gradWeights,
GpuCorr3dMM, GpuCorr3dMM_gradInputs, GpuCorr3dMM_gradWeights)
from theano.sandbox.cuda.blas import gpu_gemv_inplace from theano.sandbox.cuda.blas import gpu_gemv_inplace
from theano.sandbox.cuda.blas import gpu_gemv_no_inplace from theano.sandbox.cuda.blas import gpu_gemv_no_inplace
from theano.sandbox.cuda.blas import gpu_ger_inplace from theano.sandbox.cuda.blas import gpu_ger_inplace
...@@ -1338,6 +1339,74 @@ def local_convtransp3d_fft(node): ...@@ -1338,6 +1339,74 @@ def local_convtransp3d_fft(node):
gpu_optimizer.register("convtransp3d_fft", local_convtransp3d_fft) gpu_optimizer.register("convtransp3d_fft", local_convtransp3d_fft)
@local_optimizer([Conv3D])
def local_conv3d_gemm(node):
    """Replace a ``Conv3D`` node by the GEMM-based ``GpuCorr3dMM``.

    The rewrite only fires when the three stride values
    (``node.inputs[3]``) are compile-time constants, since they must be
    baked into the ``GpuCorr3dMM`` op as the ``subsample`` parameter.
    Returns the replacement output list, or ``None``/``False`` to leave
    the node untouched.
    """
    if not isinstance(node.op, Conv3D):
        return
    try:
        # Strides must be known constants to parametrize GpuCorr3dMM.
        sx = tensor.get_scalar_constant_value(node.inputs[3][0])
        sy = tensor.get_scalar_constant_value(node.inputs[3][1])
        sz = tensor.get_scalar_constant_value(node.inputs[3][2])
    except tensor.NotScalarConstantError:
        return False
    # NOTE: the former second `isinstance(node.op, Conv3D)` test was
    # redundant (the early guard above already ensures it) and is removed.
    # Shuffle input signal from (b, 0, 1, t, c) to (b, c, 0, 1, t)
    x = node.inputs[0]
    x = x.dimshuffle(0, 4, 1, 2, 3)
    # Shuffle filters from (oc, 0, 1, t, ic) to (oc, ic, 0, 1, t)
    f = node.inputs[1]
    f = f.dimshuffle(0, 4, 1, 2, 3)
    rval = GpuCorr3dMM(border_mode='valid', subsample=(sx, sy, sz))(x, f)
    # Shuffle the result back to the (b, 0, 1, t, oc) layout (the
    # correlation output presumably is (b, oc, 0, 1, t) -- TODO confirm)
    # and add the bias term (node.inputs[2]).
    return [rval.dimshuffle(0, 2, 3, 4, 1) + node.inputs[2]]
gpu_optimizer.register("conv3d_gemm", local_conv3d_gemm)
@local_optimizer([ConvGrad3D])
def local_convgrad3d_gemm(node):
    """Replace ``ConvGrad3D`` (gradient w.r.t. the filters) by the
    GEMM-based ``GpuCorr3dMM_gradWeights``.

    The rewrite only fires when the three stride values
    (``node.inputs[1]``) are compile-time constants.  Returns the
    replacement output list, or ``None``/``False`` to leave the node
    untouched.
    """
    # Early guard for consistency with local_conv3d_gemm (the decorator
    # already restricts dispatch, but the explicit check is kept).
    if not isinstance(node.op, ConvGrad3D):
        return
    try:
        # Strides must be known constants.
        sx = tensor.get_scalar_constant_value(node.inputs[1][0])
        sy = tensor.get_scalar_constant_value(node.inputs[1][1])
        sz = tensor.get_scalar_constant_value(node.inputs[1][2])
    except tensor.NotScalarConstantError:
        return False
    # Shuffle input signal from (b, 0, 1, t, c) to (b, c, 0, 1, t)
    x = node.inputs[0]
    x = gpu_contiguous(x.dimshuffle(0, 4, 1, 2, 3))
    # Shuffle dCdH from (b, 0, 1, t, oc) to (b, oc, 0, 1, t).
    # NOTE(review): the original comment claimed (oc, b, 0, 1, t), but
    # dimshuffle(0, 4, 1, 2, 3) keeps the batch axis first.
    f = node.inputs[3]
    f = gpu_contiguous(f.dimshuffle(0, 4, 1, 2, 3))
    rval = GpuCorr3dMM_gradWeights(subsample=(sx, sy, sz))(
        x, f, shape=node.inputs[2][1:4])
    # Shuffle the gradient back to the (oc, 0, 1, t, ic) filter layout
    # (presumably from (oc, ic, 0, 1, t) -- TODO confirm against
    # GpuCorr3dMM_gradWeights' output layout).
    return [rval.dimshuffle(0, 2, 3, 4, 1)]
gpu_optimizer.register("convgrad3d_gemm", local_convgrad3d_gemm)
@local_optimizer([ConvTransp3D])
def local_convtransp3d_gemm(node):
    """Replace ``ConvTransp3D`` by the GEMM-based
    ``GpuCorr3dMM_gradInputs``.

    The rewrite only fires when the three stride values
    (``node.inputs[2]``) are compile-time constants equal to 1.
    Returns the replacement output list, or ``None``/``False`` to leave
    the node untouched.
    """
    # Early guard for consistency with local_conv3d_gemm.
    if not isinstance(node.op, ConvTransp3D):
        return
    try:
        # Strides must be known constants.
        sx = tensor.get_scalar_constant_value(node.inputs[2][0])
        sy = tensor.get_scalar_constant_value(node.inputs[2][1])
        sz = tensor.get_scalar_constant_value(node.inputs[2][2])
    except tensor.NotScalarConstantError:
        return False
    if (sx, sy, sz) != (1, 1, 1):
        # NOTE(review): the rewrite is deliberately restricted to unit
        # strides; leave strided transposed convolutions untouched.
        return
    # Shuffle filters from (oc, 0, 1, t, ic) to (oc, ic, 0, 1, t).
    # NOTE(review): the original comment claimed (ic, oc, 0, 1, t), but
    # dimshuffle(0, 4, 1, 2, 3) keeps the first axis in place.
    x = node.inputs[0]
    x = gpu_contiguous(x.dimshuffle(0, 4, 1, 2, 3))
    # Shuffle dCdH from (b, 0, 1, t, oc) to (b, oc, 0, 1, t)
    f = node.inputs[3]
    f = gpu_contiguous(f.dimshuffle(0, 4, 1, 2, 3))
    rval = GpuCorr3dMM_gradInputs(subsample=(sx, sy, sz))(kern=x, topgrad=f)
    # Shuffle the result back to (b, 0, 1, t, ic) (presumably from
    # (b, ic, 0, 1, t) -- TODO confirm) and add the bias (node.inputs[1]).
    return [rval.dimshuffle(0, 2, 3, 4, 1) + node.inputs[1]]
gpu_optimizer.register("convtransp3d_gemm", local_convtransp3d_gemm)
import theano.tensor.signal.downsample as downsample import theano.tensor.signal.downsample as downsample
......
差异被折叠。
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论