Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
73b55c61
提交
73b55c61
authored
10月 07, 2014
作者:
carriepl
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #2114 from ballasn/Corr3DMM
Add 3d correlation based on blas matrix multiplication
上级
31a6c527
5f50150f
隐藏空白字符变更
内嵌
并排
正在显示
6 个修改的文件
包含
1322 行增加
和
4 行删除
+1322
-4
conv.txt
doc/library/tensor/nnet/conv.txt
+15
-0
blas.py
theano/sandbox/cuda/blas.py
+507
-3
corr3d_gemm.cu
theano/sandbox/cuda/corr3d_gemm.cu
+486
-0
corr_gemm.cu
theano/sandbox/cuda/corr_gemm.cu
+6
-0
opt.py
theano/sandbox/cuda/opt.py
+70
-1
test_gemmcorr3d.py
theano/sandbox/cuda/tests/test_gemmcorr3d.py
+238
-0
没有找到文件。
doc/library/tensor/nnet/conv.txt
浏览文件 @
73b55c61
...
@@ -123,6 +123,21 @@ TODO: Give examples on how to use these things! They are pretty complicated.
...
@@ -123,6 +123,21 @@ TODO: Give examples on how to use these things! They are pretty complicated.
f = theano.function(..., mode=mode)
f = theano.function(..., mode=mode)
- :func:`GpuCorr3dMM <theano.sandbox.cuda.blas.GpuCorr3dMM>`
This is a GPU-only 3d correlation relying on a Toeplitz matrix
and gemm implementation (see :func:`GpuCorrMM <theano.sandbox.cuda.blas.GpuCorrMM>`)
It needs extra memory for the Toeplitz matrix, which is a 2D matrix of shape
``(no of channels * filter width * filter height * filter depth, output width * output height * output depth)``.
As it provides a gradient, you can use it as a replacement for nnet.conv3d.
Alternatively, you can use nnet.conv3d and allow Theano's graph optimizer
to replace it by the GEMM version by setting
``THEANO_FLAGS=optimizer_including=conv3d_gemm:convgrad3d_gemm:convtransp3d_gemm`` in your environment.
This is not enabled by default because it uses some extra memory, but the
overhead is small compared to conv3d_fft, there are no restrictions on
input or kernel shapes and strides are supported. If using it,
please see the warning about a bug in CUDA 5.0 to 6.0
in :func:`GpuCorrMM <theano.sandbox.cuda.blas.GpuCorrMM>`.
- :func:`conv3d2d <theano.tensor.nnet.conv3d2d.conv3d>`
- :func:`conv3d2d <theano.tensor.nnet.conv3d2d.conv3d>`
Another conv3d implementation that uses the conv2d with data reshaping.
Another conv3d implementation that uses the conv2d with data reshaping.
It is faster in some cases than conv3d, and works on the GPU.
It is faster in some cases than conv3d, and works on the GPU.
...
...
theano/sandbox/cuda/blas.py
浏览文件 @
73b55c61
...
@@ -9,6 +9,7 @@ from theano.sandbox.cuda.type import CudaNdarrayType
...
@@ -9,6 +9,7 @@ from theano.sandbox.cuda.type import CudaNdarrayType
from
theano.sandbox.cuda
import
GpuOp
from
theano.sandbox.cuda
import
GpuOp
from
theano.sandbox.cuda.basic_ops
import
(
as_cuda_ndarray_variable
,
from
theano.sandbox.cuda.basic_ops
import
(
as_cuda_ndarray_variable
,
gpu_contiguous
)
gpu_contiguous
)
from
theano.tensor
import
as_tensor_variable
class
GpuDot22
(
GpuOp
):
class
GpuDot22
(
GpuOp
):
...
@@ -525,8 +526,6 @@ class BaseGpuCorrMM(GpuOp):
...
@@ -525,8 +526,6 @@ class BaseGpuCorrMM(GpuOp):
and
self
.
pad
==
other
.
pad
and
self
.
pad
==
other
.
pad
def
__hash__
(
self
):
def
__hash__
(
self
):
# don't use hash(self.version) as hash(-1)==-2 and
# hash(-2)==-2 in python!
return
hash
(
type
(
self
))
\
return
hash
(
type
(
self
))
\
^
hash
(
self
.
border_mode
)
\
^
hash
(
self
.
border_mode
)
\
^
hash
(
self
.
subsample
)
\
^
hash
(
self
.
subsample
)
\
...
@@ -564,7 +563,7 @@ class BaseGpuCorrMM(GpuOp):
...
@@ -564,7 +563,7 @@ class BaseGpuCorrMM(GpuOp):
def
c_support_code_apply
(
self
,
node
,
nodename
):
def
c_support_code_apply
(
self
,
node
,
nodename
):
# REMEMBER TO RAISE c_code_cache_version when changing any of
# REMEMBER TO RAISE c_code_cache_version when changing any of
# these files
# these files
files
=
[
'co
nv
_gemm.cu'
]
files
=
[
'co
rr
_gemm.cu'
]
codes
=
[
open
(
os
.
path
.
join
(
os
.
path
.
split
(
__file__
)[
0
],
f
))
.
read
()
codes
=
[
open
(
os
.
path
.
join
(
os
.
path
.
split
(
__file__
)[
0
],
f
))
.
read
()
for
f
in
files
]
for
f
in
files
]
return
reduce
(
str
.
__add__
,
codes
)
return
reduce
(
str
.
__add__
,
codes
)
...
@@ -960,6 +959,511 @@ class GpuCorrMM_gradInputs(BaseGpuCorrMM):
...
@@ -960,6 +959,511 @@ class GpuCorrMM_gradInputs(BaseGpuCorrMM):
return
[[
1
],
[
1
],
[
0
],
[
0
]]
# no connection to height, width
return
[[
1
],
[
1
],
[
0
],
[
0
]]
# no connection to height, width
class BaseGpuCorr3dMM(GpuOp):
    """Base class for `GpuCorr3dMM`, `GpuCorr3dMM_gradWeights` and
    `GpuCorr3dMM_gradInputs`. Cannot be used directly.

    Holds the common configuration (border_mode, subsample, pad) and the
    C-code generator shared by the forward and both backprop Ops.
    """

    def __init__(self, border_mode="valid", subsample=(1, 1, 1), pad=(0, 0, 0)):
        # Only "valid" correlation is implemented; "full" is obtained via
        # pad="full" instead.
        if border_mode != "valid":
            raise ValueError("border_mode must be 'valid'")
        self.border_mode = border_mode
        if len(subsample) != 3:
            raise ValueError("subsample must have three elements")
        self.subsample = subsample
        # pad is either the strings "half"/"full" (resolved at C level from
        # the kernel shape) or an explicit (padH, padW, padD) triple.
        if (pad not in ("half", "full")) and (len(pad) != 3):
            raise ValueError("pad must be 'half', 'full', or have three elements")
        self.pad = pad

    def __eq__(self, other):
        # Two Ops are interchangeable iff they are the same class with the
        # same configuration.
        return type(self) == type(other) \
            and self.border_mode == other.border_mode \
            and self.subsample == other.subsample \
            and self.pad == other.pad

    def __hash__(self):
        # Consistent with __eq__: combine the hash of every compared field.
        return hash(type(self)) \
            ^ hash(self.border_mode) \
            ^ hash(self.subsample) \
            ^ hash(self.pad)

    def __str__(self):
        return '%s{%s, %s, pad=%r}' % (self.__class__.__name__,
                                       self.border_mode,
                                       str(self.subsample),
                                       self.pad)

    def flops(self, inp, outp):
        """ Useful with the hack in profilemode to print the MFlops"""
        # if the output shape is correct, then this gives the correct
        # flops for any direction, sampling, padding, and border mode
        inputs, filters = inp
        outputs, = outp
        assert inputs[1] == filters[1]
        # nb mul and add by output pixel
        flops = filters[2] * filters[3] * filters[4] * 2
        # nb flops by output image
        flops *= outputs[2] * outputs[3] * outputs[4]
        # nb patch multiplied
        flops *= inputs[1] * filters[0] * inputs[0]
        return flops

    def c_headers(self):
        return ['cuda_ndarray.cuh', '<stdio.h>']

    def c_code_cache_version(self):
        # raise this whenever modifying any of the support_code_files
        return (0, 23)

    def c_support_code_apply(self, node, nodename):
        # REMEMBER TO RAISE c_code_cache_version when changing any of
        # these files
        files = ['corr3d_gemm.cu']
        codes = [open(os.path.join(os.path.split(__file__)[0], f)).read()
                 for f in files]
        return reduce(str.__add__, codes)

    def c_code_helper(self, bottom, weights, top, direction, sub,
                      height=None, width=None, depth=None):
        """
        This generates the C code for GpuCorr3dMM (direction="forward"),
        GpuCorr3dMM_gradWeights (direction="backprop weights"), and
        GpuCorr3dMM_gradInputs (direction="backprop inputs").
        Depending on the direction, one of bottom, weights, top will
        receive the output, while the other two serve as inputs.

        :param bottom: Variable name of the input images in the forward pass,
            or the gradient of the input images in backprop wrt. inputs
        :param weights: Variable name of the filters in the forward pass,
            or the gradient of the filters in backprop wrt. weights
        :param top: Variable name of the output images / feature maps in the
            forward pass, or the gradient of the outputs in the backprop passes
        :param direction: "forward" to correlate bottom with weights and store
            results in top,
            "backprop weights" to do a valid convolution of bottom with top
            (swapping the first two dimensions) and store results in weights,
            and "backprop inputs" to do a full convolution of top with weights
            (swapping the first two dimensions) and store results in bottom.
        :param sub: Dictionary of substitutions useable to help generating the
            C code.
        :param height: If self.subsample[0] != 1, a variable giving the height
            of the filters for direction="backprop weights" or the height of
            the input images for direction="backprop inputs".
            If self.pad == 'half', a variable giving the height of the filters
            for direction="backprop weights".  Ignored otherwise.
        :param width: If self.subsample[1] != 1, a variable giving the width
            of the filters for direction="backprop weights" or the width of
            the input images for direction="backprop inputs".
            If self.pad == 'half', a variable giving the width of the filters
            for direction="backprop weights".  Ignored otherwise.
        :param depth: If self.subsample[2] != 1, a variable giving the depth
            of the filters for direction="backprop weights" or the depth of
            the input images for direction="backprop inputs".
            If self.pad == 'half', a variable giving the depth of the filters
            for direction="backprop weights".  Ignored otherwise.
        """
        if self.border_mode != "valid":
            raise ValueError("mode must be 'valid'")
        dH, dW, dD = self.subsample
        # Encode "half"/"full" padding as the sentinels -1/-2; the C code
        # resolves them once the kernel shape is known.
        if self.pad == "half":
            padH = padW = padD = -1
        elif self.pad == "full":
            padH = padW = padD = -2
        else:
            padH, padW, padD = self.pad
        if direction == "forward":
            direction = 0
            out = top
        elif direction == "backprop weights":
            direction = 1
            out = weights
        elif direction == "backprop inputs":
            direction = 2
            out = bottom
        else:
            raise ValueError("direction must be one of 'forward', "
                             "'backprop weights', 'backprop inputs'")
        # When subsampling, we cannot unambiguously infer the height and width
        # of bottom and weights from top, so we require them to be given.
        # Similarly, when pad="half", we cannot infer the weight size.
        if ((direction != 0) and (dH != 1)) or ((direction == 1) and (padH == -1)):
            if not height:
                raise ValueError("height must be given for backprop with vertical sampling or pad='half'")
            height = '(*(npy_int*)(PyArray_DATA(%s)))' % height
        else:
            height = 'NULL'
        if ((direction != 0) and (dW != 1)) or ((direction == 1) and (padW == -1)):
            if not width:
                raise ValueError("width must be given for backprop with horizontal sampling or pad='half'")
            width = '(*(npy_int*)(PyArray_DATA(%s)))' % width
        else:
            width = 'NULL'
        if ((direction != 0) and (dD != 1)) or ((direction == 1) and (padD == -1)):
            if not depth:
                # Fixed: message previously said "horizontal sampling",
                # copy-pasted from the width branch above.
                raise ValueError("depth must be given for backprop with depth sampling or pad='half'")
            depth = '(*(npy_int*)(PyArray_DATA(%s)))' % depth
        else:
            depth = 'NULL'
        sub = sub.copy()
        sub.update(locals())
        return """
    // Mandatory args
    int direction = %(direction)s;  // forward, bprop weights, bprop inputs

    // Optional args
    int dH = %(dH)s;
    int dW = %(dW)s;
    int dD = %(dD)s;
    int padH = %(padH)s;
    int padW = %(padW)s;
    int padD = %(padD)s;

    CudaNdarray * bottom = %(bottom)s;
    CudaNdarray * weights = %(weights)s;
    CudaNdarray * top = %(top)s;
    CudaNdarray * out2 = NULL;

    // Obtain or infer kernel width and height
    // (we need to know it early to be able to handle auto-padding)
    int kH, kW, kD;
    if (direction != 1)
    {
        // weight is an input variable, we can just read its shape
        kH = CudaNdarray_HOST_DIMS(weights)[2];
        kW = CudaNdarray_HOST_DIMS(weights)[3];
        kD = CudaNdarray_HOST_DIMS(weights)[4];
    }
    else
    {
        if ((dH != 1) || (padH == -1))
        {
            // vertical subsampling or half padding, kernel height is specified
            kH = %(height)s;
        }
        else if (padH == -2)
        {
            // vertical full padding, we can infer the kernel height
            kH = 2 - CudaNdarray_HOST_DIMS(bottom)[2] + (CudaNdarray_HOST_DIMS(top)[2] - 1) * dH;
        }
        else
        {
            // explicit padding, we can infer the kernel height
            kH = CudaNdarray_HOST_DIMS(bottom)[2] + 2*padH - (CudaNdarray_HOST_DIMS(top)[2] - 1) * dH;
        }
        if ((dW != 1) || (padW == -1))
        {
            kW = %(width)s;
        }
        else if (padW == -2)
        {
            kW = 2 - CudaNdarray_HOST_DIMS(bottom)[3] + (CudaNdarray_HOST_DIMS(top)[3] - 1) * dW;
        }
        else
        {
            kW = CudaNdarray_HOST_DIMS(bottom)[3] + 2*padW - (CudaNdarray_HOST_DIMS(top)[3] - 1) * dW;
        }
        if ((dD != 1) || (padD == -1))
        {
            kD = %(depth)s;
        }
        else if (padD == -2)
        {
            kD = 2 - CudaNdarray_HOST_DIMS(bottom)[4] + (CudaNdarray_HOST_DIMS(top)[4] - 1) * dD;
        }
        else
        {
            kD = CudaNdarray_HOST_DIMS(bottom)[4] + 2*padD - (CudaNdarray_HOST_DIMS(top)[4] - 1) * dD;
        }
    }

    // Auto-padding if requested
    if (padH == -1)
    {  // vertical half padding
        padH = kH / 2;
    }
    else if (padH == -2)
    {  // vertical full padding
        padH = kH - 1;
    }
    else if (padH < 0)
    {
        PyErr_SetString(PyExc_ValueError, "BaseGpuCorr3dMM: padH must be >= -2");
        %(fail)s
    }
    if (padW == -1) {  // horizontal half padding
        padW = kW / 2;
    }
    else if (padW == -2) {  // horizontal full padding
        padW = kW - 1;
    }
    else if (padW < 0)
    {
        PyErr_SetString(PyExc_ValueError, "BaseGpuCorr3dMM: padW must be >= -2");
        %(fail)s
    }
    if (padD == -1)
    {  // depthwise half padding (comment fixed: was "horizontal")
        padD = kD / 2;
    }
    else if (padD == -2)
    {  // depthwise full padding (comment fixed: was "horizontal")
        padD = kD - 1;
    }
    else if (padD < 0)
    {
        PyErr_SetString(PyExc_ValueError, "BaseGpuCorr3dMM: padD must be >= -2");
        %(fail)s
    }

    // Infer output shape
    int out_dim[5];
    switch(direction) {
    case 0:  // forward pass
        // output is top: (batchsize, num_filters, height, width, depth)
        // height and width: top = (bottom + 2*pad - weight) / sample + 1
        out_dim[0] = CudaNdarray_HOST_DIMS(bottom)[0];
        out_dim[1] = CudaNdarray_HOST_DIMS(weights)[0];
        out_dim[2] = (CudaNdarray_HOST_DIMS(bottom)[2] + 2*padH - CudaNdarray_HOST_DIMS(weights)[2]) / dH + 1;
        out_dim[3] = (CudaNdarray_HOST_DIMS(bottom)[3] + 2*padW - CudaNdarray_HOST_DIMS(weights)[3]) / dW + 1;
        out_dim[4] = (CudaNdarray_HOST_DIMS(bottom)[4] + 2*padD - CudaNdarray_HOST_DIMS(weights)[4]) / dD + 1;
        break;
    case 1:  // backprop wrt. weights
        // output is weights: (num_filters, num_channels, height, width, depth)
        // height, width and depth: weights = bottom + 2*pad - (top-1) * sample
        out_dim[0] = CudaNdarray_HOST_DIMS(top)[1];
        out_dim[1] = CudaNdarray_HOST_DIMS(bottom)[1];
        out_dim[2] = kH;  // already inferred further above
        out_dim[3] = kW;  // how convenient
        out_dim[4] = kD;
        break;
    case 2:  // backprop wrt. inputs
        // output is bottom: (batchsize, num_channels, height, width, depth)
        // height, width and depth: bottom = (top-1) * sample + weights - 2*pad
        out_dim[0] = CudaNdarray_HOST_DIMS(top)[0];
        out_dim[1] = CudaNdarray_HOST_DIMS(weights)[1];
        out_dim[2] = (dH != 1) ? %(height)s : (CudaNdarray_HOST_DIMS(top)[2] - 1) * dH + CudaNdarray_HOST_DIMS(weights)[2] - 2*padH;
        out_dim[3] = (dW != 1) ? %(width)s : (CudaNdarray_HOST_DIMS(top)[3] - 1) * dW + CudaNdarray_HOST_DIMS(weights)[3] - 2*padW;
        out_dim[4] = (dD != 1) ? %(depth)s : (CudaNdarray_HOST_DIMS(top)[4] - 1) * dD + CudaNdarray_HOST_DIMS(weights)[4] - 2*padD;
        break;
    default:
        PyErr_SetString(PyExc_ValueError, "BaseGpuCorr3dMM: direction must be 0, 1, or 2\\n");
        %(fail)s
    }

    // Prepare output array
    if (!(%(out)s
          && %(out)s->nd == 5
          && CudaNdarray_is_c_contiguous(%(out)s)
          && CudaNdarray_HOST_DIMS(%(out)s)[0] == out_dim[0]
          && CudaNdarray_HOST_DIMS(%(out)s)[1] == out_dim[1]
          && CudaNdarray_HOST_DIMS(%(out)s)[2] == out_dim[2]
          && CudaNdarray_HOST_DIMS(%(out)s)[3] == out_dim[3]
          && CudaNdarray_HOST_DIMS(%(out)s)[4] == out_dim[4]))
    {
        Py_XDECREF(%(out)s);
        %(out)s = (CudaNdarray*)CudaNdarray_NewDims(5, out_dim);
        if (NULL == %(out)s)
        {
            // Error message fixed: previously typo'd "BaseGpuCorr3dM"
            PyErr_Format(PyExc_RuntimeError,
                         "BaseGpuCorr3dMM: Failed to allocate output of %%d x %%d x %%d x %%d x %%d",
                         out_dim[0], out_dim[1], out_dim[2], out_dim[3], out_dim[4]);
            %(fail)s
        }
    }

    // Call CUDA code
    out2 = corr3dMM(%(bottom)s, %(weights)s, %(top)s, direction, dH, dW, dD, padH, padW, padD);
    if (out2==NULL){
        %(fail)s
    }
    assert (out2 == %(out)s);
""" % sub
class GpuCorr3dMM(BaseGpuCorr3dMM):
    """GPU correlation implementation using Matrix Multiplication.

    :warning: For 700 series Nvidia GPUs of compute capability 3.5 and CUDA 5.0
        to 6.0, there is a bug in CUBLAS' matrix multiplication function that
        can make GpuCorrMM or its gradients crash for some input and filter
        shapes. So if you have a Tesla K20, Tesla K40, Quadro K6000, GeForce GT
        640 (DDR5), GeForce GTX 780 (or Ti), GeForce GTX TITAN (or Black or Z)
        and experience a crash, switching to CUDA 6.5 or CUDA 4.2 should fix it.
        If this is not possible, changing the input or filter shapes (e.g., the
        batchsize or number of filters) may also work around the CUBLAS bug.
    """

    def __init__(self, border_mode="valid", subsample=(1, 1, 1), pad=(0, 0, 0)):
        """
        :param border_mode: currently supports "valid" only; "full" can be
            simulated by setting `pad="full"` (at the cost of performance), or
            by using `GpuCorrMM_gradInputs`
        :param subsample: the subsample operation applied to each output image.
            Should be a tuple with 3 elements.
            `(sv, sh, sl)` is equivalent to
            `GpuCorrMM(...)(...)[:,:,::sv, ::sh, ::sl]`, but faster.
            Set to `(1, 1, 1)` to disable subsampling.
        :param pad: the width of a border of implicit zeros to pad the input
            image with. Should be a tuple with 3 elements giving the numbers of
            rows and columns to pad on each side, or "half" to set the padding
            to `(kernel_rows // 2, kernel_columns // 2, kernel_depth // 2)`,
            or "full" to set the padding to
            `(kernel_rows - 1, kernel_columns - 1, kernel_depth - 1)`
            at runtime.  Set to `(0, 0, 0)` to disable padding.

        :note: Currently, the Op requires the inputs, filters and outputs to be
            C-contiguous. Use :func:`gpu_contiguous
            <theano.sandbox.cuda.basic_ops.gpu_contiguous>` on these arguments
            if needed.
        """
        super(GpuCorr3dMM, self).__init__(border_mode, subsample, pad)

    def make_node(self, img, kern):
        # Both inputs must be 5D: (batch, channels, height, width, depth) for
        # img, (filters, channels, height, width, depth) for kern.
        img = as_cuda_ndarray_variable(img)
        kern = as_cuda_ndarray_variable(kern)
        if img.type.ndim != 5:
            raise TypeError('img must be 5D tensor')
        if kern.type.ndim != 5:
            raise TypeError('kern must be 5D tensor')
        # Output broadcastability: batch axis from img, filter axis from kern;
        # the three spatial axes are never broadcastable.
        broadcastable = [img.type.broadcastable[0], kern.type.broadcastable[0],
                         False, False, False]
        return Apply(self, [img, kern], [CudaNdarrayType(broadcastable)()])

    def c_code(self, node, nodename, inp, out_, sub):
        # Delegate to the shared generator with direction="forward": bottom
        # and weights are inputs, top receives the result.
        bottom, weights = inp
        top, = out_
        direction = "forward"
        return super(GpuCorr3dMM, self).c_code_helper(bottom, weights, top,
                                                      direction, sub)

    def grad(self, inp, grads):
        bottom, weights = inp
        top, = grads
        # The backprop Ops require C-contiguous gradients.
        top = gpu_contiguous(top)
        # The spatial input/filter shapes are passed explicitly so the
        # backprop Ops can disambiguate shapes under subsampling/"half" pad.
        d_bottom = GpuCorr3dMM_gradInputs(self.border_mode,
                                          self.subsample,
                                          self.pad)(weights, top,
                                                    bottom.shape[-3:])
        d_weights = GpuCorr3dMM_gradWeights(self.border_mode,
                                            self.subsample,
                                            self.pad)(bottom, top,
                                                      weights.shape[-3:])
        return d_bottom, d_weights
class GpuCorr3dMM_gradWeights(BaseGpuCorr3dMM):
    """Gradient wrt. filters for `GpuCorr3dMM`.

    :note: You will not want to use this directly, but rely on Theano's
        automatic differentiation or graph optimization to use it as needed."""

    def __init__(self, border_mode="valid", subsample=(1, 1, 1), pad=(0, 0, 0)):
        super(GpuCorr3dMM_gradWeights, self).__init__(border_mode, subsample,
                                                      pad)

    def make_node(self, img, topgrad, shape=None):
        # img: forward-pass inputs; topgrad: gradient wrt. the forward output.
        img = as_cuda_ndarray_variable(img)
        topgrad = as_cuda_ndarray_variable(topgrad)
        if shape is not None:
            shape = as_tensor_variable(shape)
        if img.type.ndim != 5:
            raise TypeError('img must be 5D tensor')
        if topgrad.type.ndim != 5:
            raise TypeError('topgrad must be 5D tensor')
        # With subsampling or "half" padding the filter shape cannot be
        # inferred from img and topgrad alone, so it must be supplied.
        if self.subsample != (1, 1, 1) or self.pad == "half":
            if shape is None:
                raise ValueError('shape must be given if subsample != (1, 1, 1), or pad == "half"')
            height_width_depth = [shape[0], shape[1], shape[2]]
        else:
            height_width_depth = []
        # Output is the filter gradient: (filters, channels, h, w, d).
        broadcastable = [topgrad.type.broadcastable[1],
                         img.type.broadcastable[1],
                         False, False, False]
        return Apply(self, [img, topgrad] + height_width_depth,
                     [CudaNdarrayType(broadcastable)()])

    def c_code(self, node, nodename, inp, out_, sub):
        bottom, top = inp[:2]
        # inp[2:] is the optional (height, width, depth) triple; fall back to
        # Nones when it was not given in make_node.
        height, width, depth = inp[2:] or (None, None, None)
        weights, = out_
        direction = "backprop weights"
        return super(GpuCorr3dMM_gradWeights, self).c_code_helper(bottom,
                                                                  weights,
                                                                  top,
                                                                  direction,
                                                                  sub,
                                                                  height,
                                                                  width,
                                                                  depth)

    def grad(self, inp, grads):
        bottom, top = inp[:2]
        weights, = grads
        weights = gpu_contiguous(weights)
        d_bottom = GpuCorr3dMM_gradInputs(self.border_mode,
                                          self.subsample,
                                          self.pad)(weights, top,
                                                    bottom.shape[-3:])
        d_top = GpuCorr3dMM(self.border_mode,
                            self.subsample,
                            self.pad)(bottom, weights)
        # The shape inputs (if present) carry no gradient.
        d_height_width_depth = ((theano.gradient.DisconnectedType()(),) * 3
                                if len(inp) == 5 else ())
        return (d_bottom, d_top) + d_height_width_depth

    def connection_pattern(self, node):
        if node.nin == 2:
            return [[1], [1]]
        else:
            return [[1], [1], [0], [0], [0]]  # no connection to height, width, depth
class GpuCorr3dMM_gradInputs(BaseGpuCorr3dMM):
    """Gradient wrt. inputs for `GpuCorr3dMM`.

    :note: You will not want to use this directly, but rely on Theano's
        automatic differentiation or graph optimization to use it as needed."""

    def __init__(self, border_mode="valid", subsample=(1, 1, 1), pad=(0, 0, 0)):
        super(GpuCorr3dMM_gradInputs, self).__init__(border_mode, subsample,
                                                     pad)

    def make_node(self, kern, topgrad, shape=None):
        # kern: forward-pass filters; topgrad: gradient wrt. forward output.
        # NOTE(review): unlike GpuCorr3dMM_gradWeights.make_node, `shape` is
        # not wrapped in as_tensor_variable here and pad == "half" is not
        # checked — confirm against the 2D GpuCorrMM_gradInputs behavior.
        kern = as_cuda_ndarray_variable(kern)
        topgrad = as_cuda_ndarray_variable(topgrad)
        if kern.type.ndim != 5:
            raise TypeError('kern must be 5D tensor')
        if topgrad.type.ndim != 5:
            raise TypeError('topgrad must be 5D tensor')
        # With subsampling, the input image shape cannot be inferred from
        # topgrad and kern alone, so it must be supplied.
        if self.subsample != (1, 1, 1) and shape is None:
            raise ValueError('shape must be given if subsample != (1, 1, 1)')
        height_width_depth = ([shape[0], shape[1], shape[2]]
                              if self.subsample != (1, 1, 1) else [])
        # Output is the image gradient: (batch, channels, h, w, d).
        broadcastable = [topgrad.type.broadcastable[0],
                         kern.type.broadcastable[1],
                         False, False, False]
        return Apply(self, [kern, topgrad] + height_width_depth,
                     [CudaNdarrayType(broadcastable)()])

    def c_code(self, node, nodename, inp, out_, sub):
        weights, top = inp[:2]
        # inp[2:] is the optional (height, width, depth) triple; fall back to
        # Nones when it was not given in make_node.
        height, width, depth = inp[2:] or (None, None, None)
        bottom, = out_
        direction = "backprop inputs"
        return super(GpuCorr3dMM_gradInputs, self).c_code_helper(bottom,
                                                                 weights,
                                                                 top,
                                                                 direction,
                                                                 sub,
                                                                 height,
                                                                 width,
                                                                 depth)

    def grad(self, inp, grads):
        weights, top = inp[:2]
        bottom, = grads
        bottom = gpu_contiguous(bottom)
        d_weights = GpuCorr3dMM_gradWeights(self.border_mode,
                                            self.subsample,
                                            self.pad)(bottom, top,
                                                      weights.shape[-3:])
        d_top = GpuCorr3dMM(self.border_mode,
                            self.subsample,
                            self.pad)(bottom, weights)
        # The shape inputs (if present) carry no gradient.
        d_height_width_depth = ((theano.gradient.DisconnectedType()(),) * 3
                                if len(inp) == 5 else ())
        return (d_weights, d_top) + d_height_width_depth

    def connection_pattern(self, node):
        if node.nin == 2:
            return [[1], [1]]
        else:
            return [[1], [1], [0], [0], [0]]  # no connection to height, width, depth
##
##
# Not really a BLAS operation, but whatever.
# Not really a BLAS operation, but whatever.
#
#
...
...
theano/sandbox/cuda/corr3d_gemm.cu
0 → 100644
浏览文件 @
73b55c61
// This uses a lot of code from Caffe (http://caffe.berkeleyvision.org/);
// sources are clearly marked. Below we reproduce the original license of
// the Caffe software.
/*
Copyright (c) 2014, The Regents of the University of California (Regents)
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
// Work around a glibc/CUDA atomics conflict in older toolchains.
#undef _GLIBCXX_ATOMIC_BUILTINS

// (borrowed from Caffe: https://github.com/BVLC/caffe/blob/master/src/caffe/caffe_common.hpp)
// CUDA: grid stride looping — each thread handles indices
// i, i + total_threads, i + 2*total_threads, ... up to n.
#define CUDA_KERNEL_LOOP(i, n) \
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
         i < (n); \
         i += blockDim.x * gridDim.x)

// CUDA: thread number configuration.
// Use 1024 threads per block, which requires cuda sm_2x or above,
// or fall back to attempt compatibility (best of luck to you).
#if __CUDA_ARCH__ >= 200
const int CUDA_NUM_THREADS = 1024;
#else
const int CUDA_NUM_THREADS = 512;
#endif

// CUDA: number of blocks for threads.
// Ceiling division: smallest block count covering N work items.
inline int GET_BLOCKS(const int N) {
    return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
}
// (Adapted from Caffe: https://github.com/BVLC/caffe/blob/master/src/caffe/util/im2col.cu)
// Kernels for fast unfold + copy
// (Adapted from Caffe: https://github.com/BVLC/caffe/blob/master/src/caffe/util/im2col.cu)
// Kernels for fast unfold + copy.
// Each thread handles one (channel, h_out, w_out, d_out) output position and
// copies the corresponding kernel_h x kernel_w x kernel_d input patch into one
// column of data_col (zero-filling where the patch extends into the padding).
__global__ void im3d2col_kernel(const int n, const float* data_im,
                                const int height, const int width, const int depth,
                                const int kernel_h, const int kernel_w, const int kernel_d,
                                const int pad_h, const int pad_w, const int pad_d,
                                const int stride_h, const int stride_w, const int stride_d,
                                const int height_col, const int width_col, const int depth_col,
                                float* data_col)
{
    CUDA_KERNEL_LOOP(index, n)
    {
        // Decompose the flat index into (channel_in, h_out, w_out, d_out);
        // depth varies fastest, matching C-contiguous layout.
        int d_out = index % depth_col;
        int w_index = index / depth_col;
        int w_out = w_index % width_col;
        int h_index = w_index / width_col;
        int h_out = h_index % height_col;
        int channel_in = h_index / height_col;
        //channel_in = 1;
        // One output "row block" per (channel, kernel offset) combination.
        int channel_out = channel_in * kernel_h * kernel_w * kernel_d;
        // Top-left-front corner of the patch in the (virtually padded) image.
        int h_in = h_out * stride_h - pad_h;
        int w_in = w_out * stride_w - pad_w;
        int d_in = d_out * stride_d - pad_d;
        float* data_col_ptr = data_col;
        data_col_ptr += channel_out * (height_col * width_col * depth_col) +
            h_out * (width_col * depth_col) + w_out * depth_col + d_out;
        const float* data_im_ptr = data_im;
        data_im_ptr += channel_in * (height * width * depth) +
            h_in * (width * depth) + w_in * depth + d_in;
        for (int i = 0; i < kernel_h; ++i)
        {
            int h = h_in + i;
            for (int j = 0; j < kernel_w; ++j)
            {
                int w = w_in + j;
                for (int k = 0; k < kernel_d; ++k)
                {
                    int d = d_in + k;
                    // Out-of-bounds positions correspond to zero padding.
                    *data_col_ptr = (h >= 0 && w >= 0 && d >= 0 &&
                                     h < height && w < width && d < depth) ?
                        data_im_ptr[i * (width * depth) + j *depth + k] : 0;
                    // Advance one full column stride to the next kernel offset.
                    data_col_ptr += height_col * width_col * depth_col;
                }
            }
        }
    }
}
// Host-side launcher for im3d2col_kernel: unfolds a (channels, height, width,
// depth) image into a 2D column matrix data_col of shape
// (channels * kernel_h * kernel_w * kernel_d, height_col * width_col * depth_col).
void im3d2col(const float* data_im, const int channels,
              const int height, const int width, const int depth,
              const int kernel_h, const int kernel_w, const int kernel_d,
              const int pad_h, const int pad_w, const int pad_d,
              const int stride_h, const int stride_w, const int stride_d,
              float* data_col)
{
    // We are going to launch channels * height_col * width_col * depth_col kernels, each
    // kernel responsible for copying a single-channel grid.
    int height_col = (height + 2 * pad_h - kernel_h) / stride_h + 1;
    int width_col = (width + 2 * pad_w - kernel_w) / stride_w + 1;
    int depth_col = (depth + 2 * pad_d - kernel_d) / stride_d + 1;
    int num_kernels = channels * height_col * width_col * depth_col;
    im3d2col_kernel<<<GET_BLOCKS(num_kernels),
                      CUDA_NUM_THREADS>>>(num_kernels, data_im,
                                          height, width, depth,
                                          kernel_h, kernel_w, kernel_d,
                                          pad_h, pad_w, pad_d,
                                          stride_h, stride_w, stride_d,
                                          height_col, width_col, depth_col,
                                          data_col);
}
// Inverse of im3d2col: each thread owns one image element (c, h, w, d) and
// accumulates every column entry that the unfold copied it into. Summing on
// the image side avoids atomic operations.
__global__ void col2im3d_kernel(const int n, const float* data_col,
                                const int height, const int width, const int depth,
                                const int channels,
                                const int patch_h, const int patch_w, const int patch_d,
                                const int pad_h, const int pad_w, const int pad_d,
                                const int stride_h, const int stride_w, const int stride_d,
                                const int height_col, const int width_col, const int depth_col,
                                float* data_im)
{
    CUDA_KERNEL_LOOP(index, n)
    {
        float val = 0;
        // Position of this element in the padded coordinate system.
        int d = index % depth + pad_d;
        int w_index = index / depth;
        int w = w_index % width + pad_w;
        int h_index = w_index / width;
        int h = h_index % height + pad_h;
        int c = h_index / height;
        // compute the start and end of the output:
        // range of output positions whose patch covers this element.
        int d_col_start = (d < patch_d) ? 0 : (d - patch_d) / stride_d + 1;
        int d_col_end = min(d / stride_d + 1, depth_col);
        int w_col_start = (w < patch_w) ? 0 : (w - patch_w) / stride_w + 1;
        int w_col_end = min(w / stride_w + 1, width_col);
        int h_col_start = (h < patch_h) ? 0 : (h - patch_h) / stride_h + 1;
        int h_col_end = min(h / stride_h + 1, height_col);
        // Closed-form base offset and per-axis strides into data_col, so the
        // inner loops are simple fused multiply-adds.
        int offset =
            (c * patch_h * patch_w * patch_d + h * patch_w * patch_d + w * patch_d + d) * height_col * width_col * depth_col;
        int coeff_h_col = (1 - stride_h * patch_w * patch_d * height_col) * width_col * depth_col;
        int coeff_w_col = (1 - stride_w * patch_d * height_col * width_col) * depth_col;
        int coeff_d_col = (1 - stride_d * height_col * width_col * depth_col);
        for (int d_col = d_col_start; d_col < d_col_end; ++d_col)
            for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
                for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
                    val += data_col[offset + h_col * coeff_h_col + w_col * coeff_w_col + d_col * coeff_d_col];
                }
            }
        data_im[index] = val;
    }
}
// Host-side launcher for col2im3d_kernel: folds a column matrix data_col back
// into a (channels, height, width, depth) image, summing overlapping patches.
void col2im3d(const float* data_col, const int channels,
              const int height, const int width, const int depth,
              const int patch_h, const int patch_w, const int patch_d,
              const int pad_h, const int pad_w, const int pad_d,
              const int stride_h, const int stride_w, const int stride_d,
              float* data_im)
{
    int height_col = (height + 2 * pad_h - patch_h) / stride_h + 1;
    int width_col = (width + 2 * pad_w - patch_w) / stride_w + 1;
    int depth_col = (depth + 2 * pad_d - patch_d) / stride_d + 1;
    int num_kernels = channels * height * width * depth;
    // To avoid involving atomic operations, we will launch one kernel per
    // bottom dimension, and then in the kernel add up the top dimensions.
    col2im3d_kernel<<<GET_BLOCKS(num_kernels),
                      CUDA_NUM_THREADS>>>(num_kernels, data_col,
                                          height, width, depth, channels,
                                          patch_h, patch_w, patch_d,
                                          pad_h, pad_w, pad_d,
                                          stride_h, stride_w, stride_d,
                                          height_col, width_col, depth_col,
                                          data_im);
}
// Theano op code
// Authors: Arjun Jain, Frédéric Bastien, Jan Schlüter, Nicolas Ballas
// Reference code: https://github.com/BVLC/caffe/blob/master/src/caffe/layers/conv_layer.cu
// and https://github.com/torch/cunn/blob/master/SpatialConvolutionMM.cu
// Adaptation for 3d
//
// Batched 3D correlation computed as im3d2col + cublasSgemm (plus
// col2im3d for the gradient wrt. inputs).  The `direction` argument
// selects which of the three arrays is computed from the other two:
//   direction == 0: forward correlation, writes into `top`
//   direction == 1: gradient wrt. weights, writes into `weight`
//   direction == 2: gradient wrt. inputs, writes into `bottom`
// dH/dW/dD are the subsampling strides and padH/padW/padD the implicit
// zero-padding, along height/width/depth respectively.  All three arrays
// must be 5D and C-contiguous.
// Returns the array that was written into, or NULL with a Python
// exception set on failure.  The returned array's refcount is NOT
// changed here; output allocation and refcounting are handled by the
// caller (BaseGpuCorr3dMM.c_code_helper()).
CudaNdarray* corr3dMM(CudaNdarray *const bottom,
                      CudaNdarray *const weight,
                      CudaNdarray *const top,
                      const int direction,
                      const int dH = 1,
                      const int dW = 1,
                      const int dD = 1,
                      const int padH = 0,
                      const int padW = 0,
                      const int padD = 0)
{
    // Validate ranks and memory layout up front: the im3d2col/col2im3d
    // kernels and the GEMM calls below assume dense C-contiguous storage.
    if (bottom->nd != 5)
    {
        PyErr_SetString(PyExc_ValueError, "GpuCorr3dMM requires bottom of 5D");
        return NULL;
    }
    if (!CudaNdarray_is_c_contiguous(bottom))
    {
        PyErr_Format(PyExc_ValueError,
                "GpuCorr3dMM requires bottom to be C-contiguous, "
                "but strides are: %d %d %d %d %d\n",
                CudaNdarray_HOST_STRIDES(bottom)[0],
                CudaNdarray_HOST_STRIDES(bottom)[1],
                CudaNdarray_HOST_STRIDES(bottom)[2],
                CudaNdarray_HOST_STRIDES(bottom)[3],
                CudaNdarray_HOST_STRIDES(bottom)[4]);
        return NULL;
    }
    if (weight->nd != 5)
    {
        PyErr_SetString(PyExc_ValueError, "GpuCorr3dMM requires weight of 5D");
        return NULL;
    }
    if (!CudaNdarray_is_c_contiguous(weight))
    {
        PyErr_Format(PyExc_ValueError,
                "GpuCorr3dMM requires weight to be C-contiguous, "
                "but strides are: %d %d %d %d %d\n",
                CudaNdarray_HOST_STRIDES(weight)[0],
                CudaNdarray_HOST_STRIDES(weight)[1],
                CudaNdarray_HOST_STRIDES(weight)[2],
                CudaNdarray_HOST_STRIDES(weight)[3],
                CudaNdarray_HOST_STRIDES(weight)[4]);
        return NULL;
    }
    if (top->nd != 5)
    {
        PyErr_SetString(PyExc_ValueError, "GpuCorr3dMM requires top of 5D");
        return NULL;
    }
    if (!CudaNdarray_is_c_contiguous(top))
    {
        PyErr_Format(PyExc_ValueError,
                "GpuCorr3dMM requires top to be C-contiguous, "
                "but strides are: %d %d %d %d %d\n",
                CudaNdarray_HOST_STRIDES(top)[0],
                CudaNdarray_HOST_STRIDES(top)[1],
                CudaNdarray_HOST_STRIDES(top)[2],
                CudaNdarray_HOST_STRIDES(top)[3],
                CudaNdarray_HOST_STRIDES(top)[4]);
        return NULL;
    }
    // Extract some shape information for later and check shape consistency
    // bottom: (batchSize, nChannels, bottomHeight, bottomWidth, bottomDepth)
    const int batchSize = CudaNdarray_HOST_DIMS(bottom)[0];
    const int nChannels = CudaNdarray_HOST_DIMS(bottom)[1];
    const int bottomHeight = CudaNdarray_HOST_DIMS(bottom)[2];
    const int bottomWidth = CudaNdarray_HOST_DIMS(bottom)[3];
    const int bottomDepth = CudaNdarray_HOST_DIMS(bottom)[4];
    // weights: (nFilters, nChannels, rows, columns, depth)
    const int nFilters = CudaNdarray_HOST_DIMS(weight)[0];
    const int kH = CudaNdarray_HOST_DIMS(weight)[2];
    const int kW = CudaNdarray_HOST_DIMS(weight)[3];
    const int kD = CudaNdarray_HOST_DIMS(weight)[4];
    if (nChannels != CudaNdarray_HOST_DIMS(weight)[1])
    {
        PyErr_SetString(PyExc_ValueError,
                "GpuCorr3dMM images and kernel must have the same stack size\n");
        return NULL;
    }
    // top: (batchSize, nFilters, topHeight, topWidth, topDepth)
    // Expected output size of a strided "valid" correlation with padding.
    const int topHeight = int((bottomHeight + 2*padH - kH) / dH) + 1;
    const int topWidth = int((bottomWidth + 2*padW - kW) / dW) + 1;
    const int topDepth = int((bottomDepth + 2*padD - kD) / dD) + 1;
    if (batchSize != CudaNdarray_HOST_DIMS(top)[0] ||
            nFilters != CudaNdarray_HOST_DIMS(top)[1] ||
            topHeight != CudaNdarray_HOST_DIMS(top)[2] ||
            topWidth != CudaNdarray_HOST_DIMS(top)[3] ||
            topDepth != CudaNdarray_HOST_DIMS(top)[4])
    {
        PyErr_Format(PyExc_ValueError,
                "GpuCorr3dMM shape inconsistency:\n"
                "  bottom shape: %d %d %d %d %d\n"
                "  weight shape: %d %d %d %d %d\n"
                "  top shape: %d %d %d %d %d (expected %d %d %d %d %d)\n",
                batchSize, nChannels, bottomHeight, bottomWidth, bottomDepth,
                nFilters, nChannels, kH, kW, kD,
                CudaNdarray_HOST_DIMS(top)[0], CudaNdarray_HOST_DIMS(top)[1],
                CudaNdarray_HOST_DIMS(top)[2], CudaNdarray_HOST_DIMS(top)[3],
                CudaNdarray_HOST_DIMS(top)[4],
                batchSize, nFilters, topHeight, topWidth, topDepth);
        return NULL;
    }
    // Create temporary columns: the Toeplitz workspace for one batch item,
    // of shape (nChannels * kernel volume, output volume).
    int col_dim[2];
    col_dim[0] = nChannels * kW * kH * kD;
    col_dim[1] = topHeight * topWidth * topDepth;
    CudaNdarray* col = (CudaNdarray*) CudaNdarray_NewDims(2, col_dim);
    if (0 == col)
    {
        PyErr_Format(PyExc_RuntimeError,
                "GpuCorr3dMM failed to allocate working memory of %d x %d\n",
                col_dim[0], col_dim[1]);
        return NULL;
    }
    // Define some useful variables
    const int bottom_stride = CudaNdarray_HOST_STRIDES(bottom)[0];
    const int top_stride = CudaNdarray_HOST_STRIDES(top)[0];
    const int K_ = col_dim[0];
    const int N_ = col_dim[1];
    const int M_ = nFilters;
    const float one = 1.0f;
    const float zero = 0.0f;
    // Aliased to whichever of bottom/weight/top is written; stays NULL on
    // an invalid direction so we never return an uninitialized pointer.
    CudaNdarray *output = NULL;
    if (direction == 0)
    {   // forward pass
        output = top;
        // valid correlation: im3d2col, then gemm
        // Iterate over batch
        for (int n = 0; n < batchSize; n++)
        {
            // First, im3d2col
            im3d2col(bottom->devdata + n * bottom_stride,
                     nChannels,
                     bottomHeight, bottomWidth, bottomDepth,
                     kH, kW, kD,
                     padH, padW, padD,
                     dH, dW, dD,
                     col->devdata);
            cudaError_t err = cudaGetLastError();
            if (err != cudaSuccess)
            {
                PyErr_Format(PyExc_RuntimeError,
                        "GpuCorr3dMM encountered a CUDA error in im2col: %s\n"
                        "This could be a known bug in CUDA, please see the "
                        "GpuCorr3dMM() documentation.\n",
                        cudaGetErrorString(err));
                Py_DECREF(col);
                return NULL;
            }
            // Second, gemm: top[n] = weight * col (operand order is swapped
            // relative to row-major notation to match cuBLAS conventions)
            cublasStatus_t status = cublasSgemm(handle,
                    CUBLAS_OP_N, CUBLAS_OP_N,
                    N_, M_, K_,
                    &one,
                    col->devdata, N_,
                    weight->devdata, K_,
                    &zero,
                    top->devdata + n * top_stride, N_);
            if (status != CUBLAS_STATUS_SUCCESS)
            {
                PyErr_Format(PyExc_RuntimeError,
                        "GpuCorr3dMM encountered a CUBLAS error: %s\n"
                        "This could be a known bug in CUDA, please see the "
                        "GpuCorr3dMM() documentation.\n",
                        cublasGetErrorString(status));
                Py_DECREF(col);
                return NULL;
            }
        }
    }
    else if (direction == 1)
    {
        // backprop wrt. weights
        output = weight;
        // valid convolution: im3d2col, then gemm
        // Iterate over batch
        for (int n = 0; n < batchSize; n++)
        {
            // First, im3d2col
            im3d2col(bottom->devdata + n * bottom_stride, nChannels,
                     bottomHeight, bottomWidth, bottomDepth,
                     kH, kW, kD,
                     padH, padW, padD,
                     dH, dW, dD,
                     col->devdata);
            cudaError_t err = cudaGetLastError();
            if (err != cudaSuccess)
            {
                PyErr_Format(PyExc_RuntimeError,
                        "GpuCorr3dMM encountered a CUDA error in im2col: %s\n"
                        "This could be a known bug in CUDA, please see the "
                        "GpuCorr3dMM() documentation.\n",
                        cudaGetErrorString(err));
                Py_DECREF(col);
                return NULL;
            }
            // Second, gemm
            // Note that we accumulate into weight. We do so by setting beta = 0
            // for the first iteration and beta = 1 for subsequent ones. (This
            // is faster than setting weight to all zeros before the loop.)
            cublasStatus_t status = cublasSgemm(handle,
                    CUBLAS_OP_T, CUBLAS_OP_N,
                    K_, M_, N_,
                    &one,
                    col->devdata, N_,
                    top->devdata + n * top_stride, N_,
                    (n == 0) ? &zero : &one,
                    weight->devdata, K_);
            if (status != CUBLAS_STATUS_SUCCESS)
            {
                PyErr_Format(PyExc_RuntimeError,
                        "GpuCorr3dMM encountered a CUBLAS error: %s\n"
                        "This could be a known bug in CUDA, please see the "
                        "GpuCorr3dMM() documentation.\n",
                        cublasGetErrorString(status));
                Py_DECREF(col);
                return NULL;
            }
        }
    }
    else if (direction == 2)
    {
        // backprop wrt. inputs
        output = bottom;
        // full convolution: gemm, then col2im3d
        // Iterate over batch
        for (int n = 0; n < batchSize; n++)
        {
            // gemm into columns
            cublasStatus_t status = cublasSgemm(handle,
                    CUBLAS_OP_N, CUBLAS_OP_T,
                    N_, K_, M_,
                    &one,
                    top->devdata + n * top_stride, N_,
                    weight->devdata, K_,
                    &zero,
                    col->devdata, N_);
            if (status != CUBLAS_STATUS_SUCCESS)
            {
                PyErr_Format(PyExc_RuntimeError,
                        "GpuCorr3dMM encountered a CUBLAS error: %s\n"
                        "This could be a known bug in CUDA, please see the "
                        "GpuCorr3dMM() documentation.\n",
                        cublasGetErrorString(status));
                Py_DECREF(col);
                return NULL;
            }
            // col2im3d back to the data
            col2im3d(col->devdata, nChannels,
                     bottomHeight, bottomWidth, bottomDepth,
                     kH, kW, kD,
                     padH, padW, padD,
                     dH, dW, dD, bottom->devdata + n * bottom_stride);
            cudaError_t err = cudaGetLastError();
            if (err != cudaSuccess)
            {
                PyErr_Format(PyExc_RuntimeError,
                        "GpuCorr3dMM encountered a CUDA error in col2im: %s\n"
                        "This could be a known bug in CUDA, please see the "
                        "GpuCorr3dMM() documentation.\n",
                        cudaGetErrorString(err));
                Py_DECREF(col);
                return NULL;
            }
        }
    }
    else
    {
        // Previously an invalid direction silently returned an
        // uninitialized pointer; fail loudly instead.
        PyErr_Format(PyExc_ValueError,
                "GpuCorr3dMM: direction must be 0, 1 or 2, got %d\n",
                direction);
        Py_DECREF(col);
        return NULL;
    }
    // Free temporary columns
    Py_DECREF(col);
    // Note that we don't change the refcount of the output matrix here. Output
    // allocation and refcounting is done in BaseGpuCorr3dMM.c_code_helper();
    // in here output is just aliased to one of bottom, weights, or top.
    return output;
}
theano/sandbox/cuda/conv_gemm.cu → theano/sandbox/cuda/corr_gemm.cu
浏览文件 @
73b55c61
...
@@ -294,6 +294,7 @@ CudaNdarray* corrMM(CudaNdarray *const bottom,
...
@@ -294,6 +294,7 @@ CudaNdarray* corrMM(CudaNdarray *const bottom,
"This could be a known bug in CUDA, please see the "
"This could be a known bug in CUDA, please see the "
"GpuCorrMM() documentation.\n",
"GpuCorrMM() documentation.\n",
cudaGetErrorString(err));
cudaGetErrorString(err));
Py_DECREF(col);
return NULL;
return NULL;
}
}
// Second, gemm
// Second, gemm
...
@@ -311,6 +312,7 @@ CudaNdarray* corrMM(CudaNdarray *const bottom,
...
@@ -311,6 +312,7 @@ CudaNdarray* corrMM(CudaNdarray *const bottom,
"This could be a known bug in CUDA, please see the "
"This could be a known bug in CUDA, please see the "
"GpuCorrMM() documentation.\n",
"GpuCorrMM() documentation.\n",
cublasGetErrorString(status));
cublasGetErrorString(status));
Py_DECREF(col);
return NULL;
return NULL;
}
}
}
}
...
@@ -359,6 +361,7 @@ CudaNdarray* corrMM(CudaNdarray *const bottom,
...
@@ -359,6 +361,7 @@ CudaNdarray* corrMM(CudaNdarray *const bottom,
"This could be a known bug in CUDA, please see the "
"This could be a known bug in CUDA, please see the "
"GpuCorrMM() documentation.\n",
"GpuCorrMM() documentation.\n",
cudaGetErrorString(err));
cudaGetErrorString(err));
Py_DECREF(col);
return NULL;
return NULL;
}
}
// Second, gemm
// Second, gemm
...
@@ -379,6 +382,7 @@ CudaNdarray* corrMM(CudaNdarray *const bottom,
...
@@ -379,6 +382,7 @@ CudaNdarray* corrMM(CudaNdarray *const bottom,
"This could be a known bug in CUDA, please see the "
"This could be a known bug in CUDA, please see the "
"GpuCorrMM() documentation.\n",
"GpuCorrMM() documentation.\n",
cublasGetErrorString(status));
cublasGetErrorString(status));
Py_DECREF(col);
return NULL;
return NULL;
}
}
}
}
...
@@ -429,6 +433,7 @@ CudaNdarray* corrMM(CudaNdarray *const bottom,
...
@@ -429,6 +433,7 @@ CudaNdarray* corrMM(CudaNdarray *const bottom,
"This could be a known bug in CUDA, please see the "
"This could be a known bug in CUDA, please see the "
"GpuCorrMM() documentation.\n",
"GpuCorrMM() documentation.\n",
cublasGetErrorString(status));
cublasGetErrorString(status));
Py_DECREF(col);
return NULL;
return NULL;
}
}
// col2im back to the data
// col2im back to the data
...
@@ -441,6 +446,7 @@ CudaNdarray* corrMM(CudaNdarray *const bottom,
...
@@ -441,6 +446,7 @@ CudaNdarray* corrMM(CudaNdarray *const bottom,
"This could be a known bug in CUDA, please see the "
"This could be a known bug in CUDA, please see the "
"GpuCorrMM() documentation.\n",
"GpuCorrMM() documentation.\n",
cudaGetErrorString(err));
cudaGetErrorString(err));
Py_DECREF(col);
return NULL;
return NULL;
}
}
}
}
...
...
theano/sandbox/cuda/opt.py
浏览文件 @
73b55c61
...
@@ -26,7 +26,8 @@ from theano.sandbox.cuda.basic_ops import (
...
@@ -26,7 +26,8 @@ from theano.sandbox.cuda.basic_ops import (
from
theano.sandbox.cuda.type
import
CudaNdarrayType
from
theano.sandbox.cuda.type
import
CudaNdarrayType
from
theano.sandbox.cuda.blas
import
(
gpu_dot22
,
gpu_dot22scalar
,
from
theano.sandbox.cuda.blas
import
(
gpu_dot22
,
gpu_dot22scalar
,
gpu_gemm_inplace
,
gpu_gemm_no_inplace
,
GpuConv
,
gpu_gemm_inplace
,
gpu_gemm_no_inplace
,
GpuConv
,
GpuCorrMM
,
GpuCorrMM_gradInputs
,
GpuCorrMM_gradWeights
)
GpuCorrMM
,
GpuCorrMM_gradInputs
,
GpuCorrMM_gradWeights
,
GpuCorr3dMM
,
GpuCorr3dMM_gradInputs
,
GpuCorr3dMM_gradWeights
)
from
theano.sandbox.cuda.blas
import
gpu_gemv_inplace
from
theano.sandbox.cuda.blas
import
gpu_gemv_inplace
from
theano.sandbox.cuda.blas
import
gpu_gemv_no_inplace
from
theano.sandbox.cuda.blas
import
gpu_gemv_no_inplace
from
theano.sandbox.cuda.blas
import
gpu_ger_inplace
from
theano.sandbox.cuda.blas
import
gpu_ger_inplace
...
@@ -1338,6 +1339,74 @@ def local_convtransp3d_fft(node):
...
@@ -1338,6 +1339,74 @@ def local_convtransp3d_fft(node):
gpu_optimizer
.
register
(
"convtransp3d_fft"
,
local_convtransp3d_fft
)
gpu_optimizer
.
register
(
"convtransp3d_fft"
,
local_convtransp3d_fft
)
@local_optimizer([Conv3D])
def local_conv3d_gemm(node):
    """Replace a Conv3D node by the GEMM-based GpuCorr3dMM op.

    Conv3D uses the (b, 0, 1, t, c) data layout while GpuCorr3dMM
    expects (b, c, 0, 1, t), so the signal, the filters and the result
    are dimshuffled around the op.  Returns ``False`` (no replacement)
    when the subsampling is not a compile-time constant.
    """
    if not isinstance(node.op, Conv3D):
        return
    # GpuCorr3dMM needs the strides as Python ints at graph-rewrite
    # time; bail out if they are symbolic.
    try:
        sx = tensor.get_scalar_constant_value(node.inputs[3][0])
        sy = tensor.get_scalar_constant_value(node.inputs[3][1])
        sz = tensor.get_scalar_constant_value(node.inputs[3][2])
    except tensor.NotScalarConstantError:
        return False
    # NOTE: the original re-checked isinstance(node.op, Conv3D) here,
    # which is always true after the guard above; the redundant test
    # has been removed.
    # Shuffle inputs signal from (b, 0, 1, t, c) to (b, c, 0, 1, t)
    x = node.inputs[0]
    x = x.dimshuffle(0, 4, 1, 2, 3)
    # Shuffle filters from (oc, 0, 1, t, ic) to (oc, ic, 0, 1, t)
    f = node.inputs[1]
    f = f.dimshuffle(0, 4, 1, 2, 3)
    rval = GpuCorr3dMM(border_mode='valid',
                       subsample=(sx, sy, sz))(x, f)
    # Shuffle result from (b, oc, 0, 1, t) back to (b, 0, 1, t, oc)
    # and add the bias (node.inputs[2]).
    return [rval.dimshuffle(0, 2, 3, 4, 1) + node.inputs[2]]
gpu_optimizer.register("conv3d_gemm", local_conv3d_gemm)
@local_optimizer([ConvGrad3D])
def local_convgrad3d_gemm(node):
    """Replace a ConvGrad3D node by GpuCorr3dMM_gradWeights.

    Converts between ConvGrad3D's (b, 0, 1, t, c) layout and the
    (b, c, 0, 1, t) layout expected by the GEMM-based op.  Returns
    ``False`` (no replacement) when the subsampling is not a
    compile-time constant.
    """
    # Early guard for consistency with local_conv3d_gemm.
    if not isinstance(node.op, ConvGrad3D):
        return
    # The strides must be known Python ints at graph-rewrite time.
    try:
        sx = tensor.get_scalar_constant_value(node.inputs[1][0])
        sy = tensor.get_scalar_constant_value(node.inputs[1][1])
        sz = tensor.get_scalar_constant_value(node.inputs[1][2])
    except tensor.NotScalarConstantError:
        return False
    # Shuffle inputs signal from (b, 0, 1, t, c) to (b, c, 0, 1, t)
    x = node.inputs[0]
    x = gpu_contiguous(x.dimshuffle(0, 4, 1, 2, 3))
    # Shuffle dCdH from (b, 0, 1, t, oc) to (b, oc, 0, 1, t)
    # (the original comment claimed (oc, b, 0, 1, t), but
    # dimshuffle(0, 4, 1, 2, 3) keeps the batch axis first)
    f = node.inputs[3]
    f = gpu_contiguous(f.dimshuffle(0, 4, 1, 2, 3))
    rval = GpuCorr3dMM_gradWeights(subsample=(sx, sy, sz))(
        x, f, shape=node.inputs[2][1:4])
    # Shuffle the weight gradient back to the (oc, 0, 1, t, ic) layout
    return [rval.dimshuffle(0, 2, 3, 4, 1)]
gpu_optimizer.register("convgrad3d_gemm", local_convgrad3d_gemm)
@local_optimizer([ConvTransp3D])
def local_convtransp3d_gemm(node):
    """Replace a ConvTransp3D node by GpuCorr3dMM_gradInputs.

    Only applies when the subsampling is the compile-time constant
    (1, 1, 1); returns ``False`` when the strides are symbolic and
    ``None`` (no replacement) for other strides.
    """
    if not isinstance(node.op, ConvTransp3D):
        return
    # The strides must be known Python ints at graph-rewrite time.
    try:
        sx = tensor.get_scalar_constant_value(node.inputs[2][0])
        sy = tensor.get_scalar_constant_value(node.inputs[2][1])
        sz = tensor.get_scalar_constant_value(node.inputs[2][2])
    except tensor.NotScalarConstantError:
        return False
    if (sx, sy, sz) != (1, 1, 1):
        # The GEMM-based input gradient is only substituted for unit
        # strides here.
        return
    # Shuffle filters from (oc, 0, 1, t, ic) to (oc, ic, 0, 1, t)
    # (the original comment claimed (ic, oc, 0, 1, t), but
    # dimshuffle(0, 4, 1, 2, 3) keeps the first axis in place)
    x = node.inputs[0]
    x = gpu_contiguous(x.dimshuffle(0, 4, 1, 2, 3))
    # Shuffle dCdH from (b, 0, 1, t, oc) to (b, oc, 0, 1, t)
    f = node.inputs[3]
    f = gpu_contiguous(f.dimshuffle(0, 4, 1, 2, 3))
    rval = GpuCorr3dMM_gradInputs(subsample=(sx, sy, sz))(
        kern=x, topgrad=f)
    # Shuffle the result back to (b, 0, 1, t, ic) and add the bias
    # (node.inputs[1]).
    return [rval.dimshuffle(0, 2, 3, 4, 1) + node.inputs[1]]
gpu_optimizer.register("convtransp3d_gemm", local_convtransp3d_gemm)
import
theano.tensor.signal.downsample
as
downsample
import
theano.tensor.signal.downsample
as
downsample
...
...
theano/sandbox/cuda/tests/test_gemmcorr3d.py
0 → 100644
浏览文件 @
73b55c61
import
unittest
import
numpy
import
theano
from
theano.tests
import
unittest_tools
as
utt
# Skip tests if cuda_ndarray is not available.
from
nose.plugins.skip
import
SkipTest
import
theano.sandbox.cuda
as
cuda_ndarray
if
not
cuda_ndarray
.
cuda_available
:
raise
SkipTest
(
'Optional package cuda not available'
)
from
theano.sandbox.cuda
import
float32_shared_constructor
as
shared
from
theano.sandbox.cuda.blas
import
GpuCorr3dMM
,
GpuCorr3dMM_gradWeights
,
GpuCorr3dMM_gradInputs
from
theano.sandbox.cuda.basic_ops
import
gpu_contiguous
if
theano
.
config
.
mode
==
'FAST_COMPILE'
:
mode_with_gpu
=
theano
.
compile
.
mode
.
get_mode
(
'FAST_RUN'
)
.
including
(
'gpu'
)
else
:
mode_with_gpu
=
theano
.
compile
.
mode
.
get_default_mode
()
.
including
(
'gpu'
)
class
TestCorr3DMM
(
unittest
.
TestCase
):
def
run_conv_valid
(
self
,
inputs_shape
,
filters_shape
,
subsample
=
(
1
,
1
,
1
)):
inputs_val
=
numpy
.
random
.
random
(
inputs_shape
)
.
astype
(
'float32'
)
filters_val
=
numpy
.
random
.
random
(
filters_shape
)
.
astype
(
'float32'
)
inputs
=
shared
(
inputs_val
)
filters
=
shared
(
filters_val
)
bias
=
shared
(
numpy
.
zeros
(
filters_shape
[
0
])
.
astype
(
'float32'
))
conv_ref
=
theano
.
tensor
.
nnet
.
conv3D
(
V
=
inputs
,
W
=
filters
,
b
=
bias
,
d
=
subsample
)
conv
=
GpuCorr3dMM
(
border_mode
=
"valid"
,
subsample
=
subsample
)(
inputs
.
dimshuffle
(
0
,
4
,
1
,
2
,
3
),
filters
.
dimshuffle
(
0
,
4
,
1
,
2
,
3
))
conv
=
conv
.
dimshuffle
(
0
,
2
,
3
,
4
,
1
)
f_ref
=
theano
.
function
([],
conv_ref
)
f
=
theano
.
function
([],
conv
,
mode
=
mode_with_gpu
)
res_ref
=
f_ref
()
res
=
f
()
utt
.
assert_allclose
(
res_ref
,
res
)
def
test_valid
(
self
):
self
.
run_conv_valid
(
inputs_shape
=
(
16
,
20
,
12
,
16
,
1
),
filters_shape
=
(
10
,
6
,
12
,
4
,
1
))
self
.
run_conv_valid
(
inputs_shape
=
(
16
,
20
,
12
,
15
,
1
),
filters_shape
=
(
10
,
6
,
12
,
4
,
1
),
subsample
=
(
2
,
2
,
2
))
self
.
run_conv_valid
(
inputs_shape
=
(
16
,
20
,
12
,
15
,
1
),
filters_shape
=
(
10
,
6
,
12
,
4
,
1
),
subsample
=
(
2
,
2
,
2
))
self
.
run_conv_valid
(
inputs_shape
=
(
16
,
20
,
12
,
15
,
1
),
filters_shape
=
(
10
,
6
,
12
,
4
,
1
),
subsample
=
(
3
,
3
,
3
))
self
.
run_conv_valid
(
inputs_shape
=
(
16
,
20
,
12
,
15
,
1
),
filters_shape
=
(
10
,
6
,
12
,
4
,
1
),
subsample
=
(
3
,
3
,
3
))
self
.
run_conv_valid
(
inputs_shape
=
(
16
,
20
,
12
,
15
,
1
),
filters_shape
=
(
10
,
6
,
12
,
4
,
1
),
subsample
=
(
3
,
2
,
1
))
self
.
run_conv_valid
(
inputs_shape
=
(
16
,
20
,
12
,
15
,
1
),
filters_shape
=
(
10
,
6
,
12
,
4
,
1
),
subsample
=
(
1
,
2
,
3
))
def
run_gradweight
(
self
,
inputs_shape
,
filters_shape
,
dCdH_shape
,
subsample
=
(
1
,
1
,
1
)):
inputs_val
=
numpy
.
random
.
random
(
inputs_shape
)
.
astype
(
'float32'
)
dCdH_val
=
numpy
.
random
.
random
(
dCdH_shape
)
.
astype
(
'float32'
)
inputs
=
shared
(
inputs_val
)
dCdH
=
shared
(
dCdH_val
)
conv
=
theano
.
tensor
.
nnet
.
convGrad3D
(
V
=
inputs
,
dCdH
=
dCdH
,
WShape
=
filters_shape
,
d
=
subsample
)
img
=
gpu_contiguous
(
inputs
.
dimshuffle
(
0
,
4
,
1
,
2
,
3
))
topgrad
=
gpu_contiguous
(
dCdH
.
dimshuffle
(
0
,
4
,
1
,
2
,
3
))
if
(
subsample
==
(
1
,
1
,
1
)):
conv_gemm
=
GpuCorr3dMM_gradWeights
(
subsample
=
subsample
)(
img
,
topgrad
)
else
:
conv_gemm
=
GpuCorr3dMM_gradWeights
(
subsample
=
subsample
)(
img
,
topgrad
,
shape
=
filters_shape
[
1
:
4
])
conv_gemm
=
conv_gemm
.
dimshuffle
(
0
,
2
,
3
,
4
,
1
)
f_ref
=
theano
.
function
([],
conv
)
f
=
theano
.
function
([],
conv_gemm
)
res_ref
=
f_ref
()
res
=
f
()
utt
.
assert_allclose
(
res_ref
,
res
)
def
test_gradweight
(
self
):
self
.
run_gradweight
(
inputs_shape
=
(
16
,
10
,
12
,
16
,
1
),
filters_shape
=
(
10
,
6
,
12
,
4
,
1
),
dCdH_shape
=
(
16
,
5
,
1
,
13
,
10
),
subsample
=
(
1
,
1
,
1
))
self
.
run_gradweight
(
inputs_shape
=
(
16
,
20
,
10
,
16
,
1
),
filters_shape
=
(
10
,
6
,
4
,
4
,
1
),
dCdH_shape
=
(
16
,
8
,
4
,
7
,
10
),
subsample
=
(
2
,
2
,
2
))
self
.
run_gradweight
(
inputs_shape
=
(
16
,
20
,
10
,
16
,
1
),
filters_shape
=
(
10
,
6
,
3
,
4
,
1
),
dCdH_shape
=
(
16
,
5
,
3
,
5
,
10
),
subsample
=
(
3
,
3
,
3
))
self
.
run_gradweight
(
inputs_shape
=
(
16
,
20
,
12
,
16
,
1
),
filters_shape
=
(
10
,
6
,
12
,
4
,
1
),
dCdH_shape
=
(
16
,
8
,
1
,
5
,
10
),
subsample
=
(
2
,
1
,
3
))
def
run_gradinput
(
self
,
inputs_shape
,
filters_shape
,
subsample
=
(
1
,
1
,
1
)):
inputs_val
=
numpy
.
random
.
random
(
inputs_shape
)
.
astype
(
'float32'
)
filters_val
=
numpy
.
random
.
random
(
filters_shape
)
.
astype
(
'float32'
)
inputs
=
shared
(
inputs_val
)
filters
=
shared
(
filters_val
)
bias
=
shared
(
numpy
.
zeros
(
filters_shape
[
4
])
.
astype
(
'float32'
))
conv
=
theano
.
tensor
.
nnet
.
convTransp3D
(
W
=
filters
,
b
=
bias
,
d
=
subsample
,
H
=
inputs
)
f_ref
=
theano
.
function
([],
conv
)
res_ref
=
f_ref
()
### Get bottom shape using convTransp3D
bottom_shape
=
res_ref
.
shape
bottom_val
=
numpy
.
random
.
random
(
bottom_shape
)
.
astype
(
'float32'
)
bottom
=
shared
(
bottom_val
)
weight
=
gpu_contiguous
(
filters
.
dimshuffle
(
0
,
4
,
1
,
2
,
3
))
top
=
gpu_contiguous
(
inputs
.
dimshuffle
(
0
,
4
,
1
,
2
,
3
))
if
(
subsample
==
(
1
,
1
,
1
)):
conv_gemm
=
GpuCorr3dMM_gradInputs
(
subsample
=
subsample
)(
kern
=
weight
,
topgrad
=
top
)
else
:
conv_gemm
=
GpuCorr3dMM_gradInputs
(
subsample
=
subsample
)(
kern
=
weight
,
topgrad
=
top
,
shape
=
bottom
.
shape
[
1
:
4
])
conv_gemm
=
conv_gemm
.
dimshuffle
(
0
,
2
,
3
,
4
,
1
)
f
=
theano
.
function
([],
conv_gemm
)
res
=
f
()
utt
.
assert_allclose
(
res_ref
,
res
)
def
test_gradinput
(
self
):
self
.
run_gradinput
(
inputs_shape
=
(
16
,
15
,
12
,
12
,
10
),
filters_shape
=
(
10
,
6
,
12
,
4
,
1
))
self
.
run_gradinput
(
inputs_shape
=
(
16
,
15
,
12
,
12
,
10
),
filters_shape
=
(
10
,
6
,
12
,
4
,
1
),
subsample
=
(
2
,
2
,
2
))
self
.
run_gradinput
(
inputs_shape
=
(
16
,
15
,
12
,
12
,
10
),
filters_shape
=
(
10
,
6
,
12
,
4
,
1
),
subsample
=
(
3
,
3
,
3
))
self
.
run_gradinput
(
inputs_shape
=
(
16
,
15
,
12
,
12
,
10
),
filters_shape
=
(
10
,
6
,
12
,
4
,
1
),
subsample
=
(
3
,
1
,
2
))
def
test_opt_conv3d_gemm
(
self
):
inputs_shape
=
(
16
,
20
,
32
,
16
,
1
)
filters_shape
=
(
10
,
6
,
12
,
4
,
1
)
inputs_val
=
numpy
.
random
.
random
(
inputs_shape
)
.
astype
(
'float32'
)
filters_val
=
numpy
.
random
.
random
(
filters_shape
)
.
astype
(
'float32'
)
inputs
=
shared
(
inputs_val
)
filters
=
shared
(
filters_val
)
bias
=
shared
(
numpy
.
zeros
(
filters_shape
[
0
])
.
astype
(
'float32'
))
conv
=
theano
.
tensor
.
nnet
.
conv3D
(
V
=
inputs
,
W
=
filters
,
b
=
bias
,
d
=
(
1
,
1
,
1
))
mode
=
mode_with_gpu
.
including
(
'conv3d_gemm'
)
f_ref
=
theano
.
function
([],
conv
)
f_gemm
=
theano
.
function
([],
conv
,
mode
=
mode
)
# make sure we inserted the gemm trickery
topo
=
f_gemm
.
maker
.
fgraph
.
toposort
()
assert
sum
(
isinstance
(
n
.
op
,
GpuCorr3dMM
)
for
n
in
topo
)
>
0
res_ref
=
f_ref
()
res_gemm
=
f_gemm
()
utt
.
assert_allclose
(
res_ref
,
res_gemm
)
def
test_opt_convgrad3d_gemm
(
self
):
inputs_shape
=
(
16
,
10
,
12
,
16
,
1
)
filters_shape
=
(
10
,
6
,
12
,
4
,
1
)
dCdH_shape
=
(
16
,
5
,
1
,
13
,
10
)
inputs_val
=
numpy
.
random
.
random
(
inputs_shape
)
.
astype
(
'float32'
)
dCdH_val
=
numpy
.
random
.
random
(
dCdH_shape
)
.
astype
(
'float32'
)
inputs
=
shared
(
inputs_val
)
dCdH
=
shared
(
dCdH_val
)
conv
=
theano
.
tensor
.
nnet
.
convGrad3D
(
V
=
inputs
,
dCdH
=
dCdH
,
WShape
=
filters_shape
,
d
=
(
1
,
1
,
1
))
mode
=
mode_with_gpu
.
including
(
'convgrad3d_gemm'
)
f_ref
=
theano
.
function
([],
conv
)
f_gemm
=
theano
.
function
([],
conv
,
mode
=
mode
)
# make sure we inserted the gemm trickery
topo
=
f_gemm
.
maker
.
fgraph
.
toposort
()
assert
sum
(
isinstance
(
n
.
op
,
GpuCorr3dMM_gradWeights
)
for
n
in
topo
)
>
0
res_ref
=
f_ref
()
res_gemm
=
f_gemm
()
utt
.
assert_allclose
(
res_ref
,
res_gemm
)
def
test_opt_convtransp3d_gemm
(
self
):
inputs_shape
=
(
16
,
15
,
12
,
12
,
10
)
filters_shape
=
(
10
,
6
,
12
,
4
,
1
)
inputs_val
=
numpy
.
random
.
random
(
inputs_shape
)
.
astype
(
'float32'
)
filters_val
=
numpy
.
random
.
random
(
filters_shape
)
.
astype
(
'float32'
)
bias
=
shared
(
numpy
.
zeros
(
filters_shape
[
4
])
.
astype
(
'float32'
))
inputs
=
shared
(
inputs_val
)
filters
=
shared
(
filters_val
)
conv
=
theano
.
tensor
.
nnet
.
convTransp3D
(
W
=
filters
,
b
=
bias
,
d
=
(
1
,
1
,
1
),
H
=
inputs
)
mode
=
mode_with_gpu
.
including
(
'convtransp3d_gemm'
)
f_ref
=
theano
.
function
([],
conv
)
f_gemm
=
theano
.
function
([],
conv
,
mode
=
mode
)
# make sure we inserted the gemm trickery
topo
=
f_gemm
.
maker
.
fgraph
.
toposort
()
assert
sum
(
isinstance
(
n
.
op
,
GpuCorr3dMM_gradInputs
)
for
n
in
topo
)
>
0
res_ref
=
f_ref
()
res_gemm
=
f_gemm
()
utt
.
assert_allclose
(
res_ref
,
res_gemm
)
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论