Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
8ce2474e
提交
8ce2474e
authored
3月 07, 2014
作者:
Marc-Alexandre Cote
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Cleaned code and unit tests
上级
65f5d0c7
隐藏空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
101 行增加
和
198 行删除
+101
-198
extra_ops.py
theano/sandbox/cuda/extra_ops.py
+52
-73
test_extra_ops.py
theano/sandbox/cuda/tests/test_extra_ops.py
+42
-125
extra_ops.py
theano/tensor/extra_ops.py
+4
-0
test_extra_ops.py
theano/tensor/tests/test_extra_ops.py
+3
-0
没有找到文件。
theano/sandbox/cuda/extra_ops.py
浏览文件 @
8ce2474e
...
@@ -13,21 +13,27 @@ if cuda_available:
...
@@ -13,21 +13,27 @@ if cuda_available:
class
GpuCumsum
(
CumsumOp
,
GpuOp
):
class
GpuCumsum
(
CumsumOp
,
GpuOp
):
def
__init__
(
self
,
axis
=
None
):
SUPPORTED_NDIMS
=
2
def
__init__
(
self
,
axis
):
"""
``axis`` can not be None. If you want the array flatten, do it before.
"""
self
.
axis
=
axis
self
.
axis
=
axis
self
.
max_threads_dim0
=
None
self
.
max_threads_dim0
=
None
def
make_node
(
self
,
x
):
def
make_node
(
self
,
x
):
assert
x
.
dtype
==
'float32'
assert
x
.
dtype
==
'float32'
if
not
isinstance
(
x
.
type
,
CudaNdarrayType
):
if
not
isinstance
(
x
.
type
,
CudaNdarrayType
):
raise
TypeError
(
'x must be
cudandarray
'
,
x
)
raise
TypeError
(
'x must be
a CudaNdarrayType
'
,
x
)
out_type
=
x
.
type
()
if
x
.
ndim
>
GpuCumsum
.
SUPPORTED_NDIMS
:
raise
NotImplementedError
(
'Only cumsum on 1D and 2D array are supported right now!'
)
if
self
.
axis
is
None
and
x
.
ndim
>
1
:
if
self
.
axis
>=
x
.
ndim
:
out_type
=
CudaNdarrayType
(
broadcastable
=
(
False
,),
dtype
=
x
.
dtype
)(
)
raise
ValueError
(
'axis(={1}) out of bounds'
.
format
(
self
.
axis
)
)
return
theano
.
Apply
(
self
,
[
x
],
[
out_type
])
return
theano
.
Apply
(
self
,
[
x
],
[
x
.
type
()
])
def
make_thunk
(
self
,
node
,
storage_map
,
compute_map
,
no_recycling
):
def
make_thunk
(
self
,
node
,
storage_map
,
compute_map
,
no_recycling
):
node_
=
copy
.
copy
(
node
)
node_
=
copy
.
copy
(
node
)
...
@@ -55,7 +61,6 @@ class GpuCumsum(CumsumOp, GpuOp):
...
@@ -55,7 +61,6 @@ class GpuCumsum(CumsumOp, GpuOp):
return
()
return
()
def
c_support_code_apply
(
self
,
node
,
nodename
):
def
c_support_code_apply
(
self
,
node
,
nodename
):
axis
=
self
.
axis
return
"""
return
"""
__device__
__device__
void k_reductionPhase_
%(nodename)
s(float* partialCumSum) {
void k_reductionPhase_
%(nodename)
s(float* partialCumSum) {
...
@@ -244,7 +249,7 @@ class GpuCumsum(CumsumOp, GpuOp):
...
@@ -244,7 +249,7 @@ class GpuCumsum(CumsumOp, GpuOp):
def
c_code
(
self
,
node
,
nodename
,
inames
,
onames
,
sub
):
def
c_code
(
self
,
node
,
nodename
,
inames
,
onames
,
sub
):
x
,
=
inames
x
,
=
inames
z
,
=
onames
z
,
=
onames
axis
=
self
.
axis
axis
=
self
.
axis
if
self
.
axis
is
not
None
else
0
fail
=
sub
[
'fail'
]
fail
=
sub
[
'fail'
]
sub
=
sub
.
copy
()
sub
=
sub
.
copy
()
...
@@ -257,89 +262,63 @@ class GpuCumsum(CumsumOp, GpuOp):
...
@@ -257,89 +262,63 @@ class GpuCumsum(CumsumOp, GpuOp):
"related to the selected GPU."
)
"related to the selected GPU."
)
sub
.
update
(
locals
())
sub
.
update
(
locals
())
if
self
.
axis
is
None
or
(
self
.
axis
==
0
and
node
.
inputs
[
0
]
.
ndim
==
1
):
code
=
"""
code
=
"""
const int* shape = CudaNdarray_HOST_DIMS(
%(x)
s);
const int* shape = CudaNdarray_HOST_DIMS(
%(x)
s);
bool needAllocation = !
%(z)
s || CudaNdarray_NDIM(
%(x)
s) != CudaNdarray_NDIM(
%(z)
s);
if(! (
%(z)
s && CudaNdarray_HOST_DIMS(
%(z)
s)[0] == shape[0]) ) {
Py_XDECREF(
%(z)
s);
%(z)
s = (CudaNdarray*) CudaNdarray_NewDims(1, shape);
}
if (!
%(z)
s) {
%(fail)
s;
}
{ // Namespace for kernel calls //
// If output is already allocated, check if its shape matches the input's one.
cumSum_
%(nodename)
s(
%(x)
s,
%(z)
s,
%(max_threads_dim0)
s, 0,
%(max_grid_size1)
s);
if (!needAllocation) {
for (int i= 0; i < CudaNdarray_NDIM(
%(x)
s); ++i) {
cudaError_t sts = cudaGetLastError();
if (CudaNdarray_HOST_DIMS(
%(x)
s)[i] == CudaNdarray_HOST_DIMS(
%(z)
s)[i]) {
if (cudaSuccess != sts)
needAllocation = true;
{
PyErr_Format(PyExc_RuntimeError,
"Cuda error:
%%
s:
%%
s.
\\
n",
"cumSum_1D_
%(nodename)
s",
cudaGetErrorString(sts));
%(fail)
s;
}
}
"""
%
locals
()
elif
node
.
inputs
[
0
]
.
ndim
==
2
:
code
=
"""
const int* shape = CudaNdarray_HOST_DIMS(
%(x)
s);
bool needAllocation = !
%(z)
s || CudaNdarray_NDIM(
%(x)
s) != CudaNdarray_NDIM(
%(z)
s);
// If output is already allocated, check if its shape matches the input's one.
if (!needAllocation) {
for (int i= 0; i < CudaNdarray_NDIM(
%(x)
s); ++i) {
if (CudaNdarray_HOST_DIMS(
%(x)
s)[i] == CudaNdarray_HOST_DIMS(
%(z)
s)[i]) {
needAllocation = true;
}
}
}
}
}
}
if (needAllocation){
if (needAllocation){
Py_XDECREF(
%(z)
s);
Py_XDECREF(
%(z)
s);
%(z)
s = (CudaNdarray*) CudaNdarray_NewDims(CudaNdarray_NDIM(
%(x)
s), shape);
%(z)
s = (CudaNdarray*) CudaNdarray_NewDims(CudaNdarray_NDIM(
%(x)
s), shape);
}
}
if (!
%(z)
s) {
if (!
%(z)
s) {
%(fail)
s;
%(fail)
s;
}
}
{ // Namespace for kernel calls //
{ // Namespace for kernel calls //
cumSum_
%(nodename)
s(
%(x)
s,
%(z)
s,
%(max_threads_dim0)
s,
%(axis)
s,
%(max_grid_size1)
s);
cumSum_
%(nodename)
s(
%(x)
s,
%(z)
s,
%(max_threads_dim0)
s,
%(axis)
s,
%(max_grid_size1)
s);
cudaError_t sts = cudaGetLastError();
cudaError_t sts = cudaGetLastError();
if (cudaSuccess != sts)
if (cudaSuccess != sts)
{
{
PyErr_Format(PyExc_RuntimeError,
PyErr_Format(PyExc_RuntimeError,
"Cuda error:
%%
s:
%%
s.
\\
n",
"Cuda error:
%%
s:
%%
s.
\\
n",
"cumSum_2D_axis1_
%(nodename)
s",
"cumSum_
%(nodename)
s",
cudaGetErrorString(sts));
cudaGetErrorString(sts));
%(fail)
s;
%(fail)
s;
}
}
}
"""
%
locals
()
}
else
:
"""
%
locals
()
raise
NotImplementedError
(
'Only 1D case and 2D (axis=1) are supported right now!'
)
return
code
return
code
def
gpu_cumsum
(
x
,
axis
=
None
):
return
GpuCumsum
(
axis
)(
x
)
from
theano.sandbox.cuda
import
GpuFlatten
from
theano.sandbox.cuda
import
GpuFlatten
@local_optimizer
([
CumsumOp
])
@local_optimizer
([
CumsumOp
])
def
use_gpu_cumsum
(
node
):
def
use_gpu_cumsum
(
node
):
if
type
(
node
.
op
)
is
CumsumOp
and
node
.
inputs
[
0
]
.
dtype
==
'float32'
:
if
node
.
inputs
[
0
]
.
ndim
>
GpuCumsum
.
SUPPORTED_NDIMS
:
return
None
if
type
(
node
.
op
)
is
CumsumOp
and
node
.
inputs
[
0
]
.
dtype
==
'float32'
:
x
=
gpu_from_host
(
node
.
inputs
[
0
])
x
=
gpu_from_host
(
node
.
inputs
[
0
])
if
node
.
op
.
axis
is
None
and
x
.
ndim
>
1
:
axis
=
node
.
op
.
axis
if
axis
is
None
and
x
.
ndim
>
1
:
x
=
GpuFlatten
()(
x
)
x
=
GpuFlatten
()(
x
)
return
[
host_from_gpu
(
gpu_cumsum
(
x
,
axis
=
node
.
op
.
axis
))]
# ``gpu_cumsum`` assume array has been flattened if needed.
if
axis
is
None
:
axis
=
0
return
[
host_from_gpu
(
GpuCumsum
(
axis
)(
x
))]
if
cuda_available
:
if
cuda_available
:
register_gpu_opt
()(
use_gpu_cumsum
)
register_gpu_opt
()(
use_gpu_cumsum
)
theano/sandbox/cuda/tests/test_extra_ops.py
浏览文件 @
8ce2474e
...
@@ -17,157 +17,75 @@ from theano import tensor as T
...
@@ -17,157 +17,75 @@ from theano import tensor as T
import
numpy
as
np
import
numpy
as
np
import
theano
import
theano
from
theano
import
config
from
theano
import
config
from
theano.tensor.extra_ops
import
cumsum
,
diff
from
theano.tensor.extra_ops
import
cumsum
from
mlpython.misc.utils
import
Timer
class
TestGpuCumsum
(
theano
.
tensor
.
tests
.
test_extra_ops
.
TestCumsumOp
):
class
TestGpuCumsum
(
theano
.
tensor
.
tests
.
test_extra_ops
.
TestCumsumOp
):
mode
=
mode_with_gpu
mode
=
mode_with_gpu
op
=
GpuCumsum
op
=
GpuCumsum
dtypes
=
[
'float32'
]
dtypes
=
[
'float32'
]
def
test_benchmark_1D_vs_2D
(
self
):
def
setUp
(
self
):
print
"
\n
Benchmark:"
super
(
TestGpuCumsum
,
self
)
.
setUp
()
from
theano
import
sandbox
,
Out
import
time
vlen
=
40
*
1024
*
2048
# 10 x # cores x # threads per core
iters
=
25
x
=
theano
.
shared
(
np
.
ones
((
vlen
,),
dtype
=
config
.
floatX
),
borrow
=
False
)
res
=
Out
(
sandbox
.
cuda
.
basic_ops
.
gpu_from_host
(
cumsum
(
x
)),
borrow
=
True
)
f
=
theano
.
function
([],
res
)
print
f
.
maker
.
fgraph
.
toposort
()
t0
=
time
.
time
()
for
i
in
xrange
(
iters
):
r
=
f
()
t1
=
time
.
time
()
print
'Looping
%
d times took'
%
iters
,
t1
-
t0
,
'seconds'
print
'Result is'
,
r
print
'Numpy result is'
,
np
.
asarray
(
r
)
# x = theano.shared(np.ones((1,vlen), dtype=config.floatX), borrow=True)
# f = theano.function([], Out(sandbox.cuda.basic_ops.gpu_from_host(cumsum(x,axis=1)), borrow=True))
# print f.maker.fgraph.toposort()
# Fetch some useful properties on the device
# t0 = time.time()
cuda
=
theano
.
sandbox
.
cuda
# for i in xrange(iters):
device_id
=
cuda
.
use
.
device_number
# r = f()
cuda_ndarray
=
theano
.
sandbox
.
cuda
.
cuda_ndarray
.
cuda_ndarray
# t1 = time.time()
prop
=
cuda_ndarray
.
device_properties
(
device_id
)
self
.
max_threads_dim0
=
prop
[
'maxThreadsDim0'
]
# print 'Looping %d times took' % iters, t1 - t0, 'seconds'
self
.
max_grid_size1
=
prop
[
'maxGridSize1'
]
# print 'Result is', r
# print 'Numpy result is', np.asarray(r)
# print 'Used the', config.device
def
test_GpuCumsum1D
(
self
):
block_max_size
=
self
.
max_threads_dim0
*
2
def
test_GpuCumsum
(
self
):
### Test 1D case ###
x
=
T
.
vector
(
'x'
)
x
=
T
.
vector
(
'x'
)
f
=
theano
.
function
([
x
],
cumsum
(
x
))
f
=
theano
.
function
([
x
],
cumsum
(
x
))
# Even number of elements
a
=
np
.
random
.
random
((
18
,))
.
astype
(
config
.
floatX
)
print
f
(
a
)
print
np
.
cumsum
(
a
)
assert
np
.
allclose
(
np
.
cumsum
(
a
),
f
(
a
))
# Odd number of elements
a
=
np
.
random
.
random
((
7
,))
.
astype
(
config
.
floatX
)
assert
np
.
allclose
(
np
.
cumsum
(
a
),
f
(
a
))
# Use multiple GPU threadblocks
# Extensive testing for the first 1k sizes
a
=
np
.
random
.
random
((
2048
+
2
,))
.
astype
(
config
.
floatX
)
a
=
np
.
ones
((
int
(
1e3
),),
dtype
=
config
.
floatX
)
assert
np
.
allclose
(
np
.
cumsum
(
a
),
f
(
a
))
for
i
in
xrange
(
a
.
shape
[
0
]):
assert
np
.
allclose
(
np
.
cumsum
(
a
[:
i
]),
f
(
a
[:
i
]))
# Use multiple GPU threadblocks
# Use multiple GPU threadblocks
a
=
np
.
random
.
random
((
2048
*
75
+
2
,))
.
astype
(
config
.
floatX
)
a
=
np
.
random
.
random
((
block_max_size
+
2
,))
.
astype
(
config
.
floatX
)
assert
np
.
allclose
(
np
.
cumsum
(
a
),
f
(
a
))
assert
np
.
allclose
(
np
.
cumsum
(
a
),
f
(
a
))
# Use
multiple GPU gridblocks
# Use
recursive cumsum
a
=
np
.
ones
((
2048
*
2048
+
2
,))
.
astype
(
config
.
floatX
)
a
=
np
.
ones
((
block_max_size
*
(
block_max_size
+
1
)
+
2
,))
.
astype
(
config
.
floatX
)
assert
np
.
allclose
(
np
.
cumsum
(
a
),
f
(
a
))
assert
np
.
allclose
(
np
.
cumsum
(
a
),
f
(
a
))
def
test_GpuCumsum2D
(
self
):
block_max_size
=
self
.
max_threads_dim0
*
2
# Extensive testing
for
i
in
xrange
(
int
(
1e3
)
*
5
):
a
=
np
.
ones
((
i
,),
dtype
=
config
.
floatX
)
fa
=
f
(
a
)
npa
=
np
.
cumsum
(
a
)
if
not
np
.
allclose
(
npa
,
fa
):
print
i
,
np
.
allclose
(
npa
,
fa
)
# Test axis=None
print
fa
print
npa
assert
False
if
i
%
1000
==
0
:
print
i
#for axis in xrange(2):
for
axis
in
xrange
(
2
):
for
axis
in
xrange
(
2
):
### Test 2D case - axis=1 ###
x
=
T
.
matrix
(
'x'
)
x
=
T
.
matrix
(
'x'
)
f
=
theano
.
function
([
x
],
cumsum
(
x
,
axis
=
axis
))
f
=
theano
.
function
([
x
],
cumsum
(
x
,
axis
=
axis
))
# Even number of elements
print
"
\n
# Even number of elements (axis={0})"
.
format
(
axis
)
a
=
np
.
random
.
random
((
18
,
18
))
.
astype
(
config
.
floatX
)
assert
np
.
allclose
(
np
.
cumsum
(
a
,
axis
=
axis
),
f
(
a
))
# Odd number of elements
# Extensive testing for the first 1k sizes
print
"
\n
# Odd number of elements (axis={0})"
.
format
(
axis
)
a_shape
=
[
11
,
11
]
a
=
np
.
random
.
random
((
21
,
21
))
.
astype
(
config
.
floatX
)
a_shape
[
axis
]
=
int
(
1e3
)
assert
np
.
allclose
(
np
.
cumsum
(
a
,
axis
=
axis
),
f
(
a
))
a
=
np
.
ones
(
a_shape
,
dtype
=
config
.
floatX
)
slices
=
[
slice
(
None
),
slice
(
None
)]
# Use two GPU threadblocks
for
i
in
xrange
(
a
.
shape
[
axis
]):
print
"
\n
# Use two GPU threadblocks (axis={0})"
.
format
(
axis
)
slices
[
axis
]
=
slice
(
i
)
a
=
np
.
random
.
random
((
2048
+
2
,
2048
+
2
))
.
astype
(
config
.
floatX
)
fa
=
f
(
a
[
slices
])
assert
np
.
allclose
(
np
.
cumsum
(
a
,
axis
=
axis
),
f
(
a
))
npa
=
np
.
cumsum
(
a
[
slices
],
axis
=
axis
)
assert
np
.
allclose
(
npa
,
fa
)
# Use multiple GPU threadblocks
# Use multiple GPU threadblocks
print
"
\n
# Use multiple GPU threadblocks (axis={0})"
.
format
(
axis
)
a_shape
=
[
11
,
11
]
a
=
np
.
ones
((
10
,
2048
*
75
+
3
))
.
astype
(
config
.
floatX
)
a_shape
[
axis
]
=
block_max_size
+
2
assert
np
.
allclose
(
np
.
cumsum
(
a
,
axis
=
axis
),
f
(
a
))
a
=
np
.
ones
(
a_shape
,
dtype
=
config
.
floatX
)
a
=
np
.
ones
((
2048
*
75
+
3
,
10
))
.
astype
(
config
.
floatX
)
assert
np
.
allclose
(
np
.
cumsum
(
a
,
axis
=
axis
),
f
(
a
))
assert
np
.
allclose
(
np
.
cumsum
(
a
,
axis
=
axis
),
f
(
a
))
# Use multiple GPU gridblocks
# Use multiple GPU gridblocks
print
"
\n
# Use multiple GPU gridblocks (axis={0})"
.
format
(
axis
)
a_shape
=
[
11
,
11
]
a
=
np
.
ones
((
11
,
2048
*
2048
+
3
))
.
astype
(
config
.
floatX
)
a_shape
[
1
-
axis
]
=
self
.
max_grid_size1
+
1
assert
np
.
allclose
(
np
.
cumsum
(
a
,
axis
=
axis
),
f
(
a
))
a
=
np
.
ones
(
a_shape
,
dtype
=
config
.
floatX
)
a
=
np
.
ones
((
2048
*
2048
+
3
,
11
))
.
astype
(
config
.
floatX
)
assert
np
.
allclose
(
np
.
cumsum
(
a
,
axis
=
axis
),
f
(
a
))
assert
np
.
allclose
(
np
.
cumsum
(
a
,
axis
=
axis
),
f
(
a
))
# Extensive testing for the first 10k sizes
# Use recursive cumsum
for
i
in
xrange
(
int
(
1e3
)
*
5
):
a_shape
=
[
11
,
11
]
a
=
np
.
ones
((
11
,
i
),
dtype
=
config
.
floatX
)
a_shape
[
axis
]
=
block_max_size
*
(
block_max_size
+
1
)
+
2
fa
=
f
(
a
)
a
=
np
.
ones
(
a_shape
,
dtype
=
config
.
floatX
)
npa
=
np
.
cumsum
(
a
,
axis
=
axis
)
assert
np
.
allclose
(
np
.
cumsum
(
a
,
axis
=
axis
),
f
(
a
))
if
not
np
.
allclose
(
npa
,
fa
):
print
i
,
np
.
allclose
(
npa
,
fa
)
# Test axis=None
print
fa
print
npa
assert
False
a
=
np
.
ones
((
i
,
11
),
dtype
=
config
.
floatX
)
fa
=
f
(
a
)
npa
=
np
.
cumsum
(
a
,
axis
=
axis
)
if
not
np
.
allclose
(
npa
,
fa
):
print
i
,
np
.
allclose
(
npa
,
fa
)
# Test axis=None
print
fa
print
npa
assert
False
if
i
%
1000
==
0
:
print
i
\ No newline at end of file
theano/tensor/extra_ops.py
浏览文件 @
8ce2474e
...
@@ -28,6 +28,8 @@ class CumsumOp(theano.Op):
...
@@ -28,6 +28,8 @@ class CumsumOp(theano.Op):
if
self
.
axis
is
None
:
if
self
.
axis
is
None
:
out_type
=
theano
.
tensor
.
vector
(
dtype
=
x
.
dtype
)
# Flatten
out_type
=
theano
.
tensor
.
vector
(
dtype
=
x
.
dtype
)
# Flatten
elif
self
.
axis
>=
x
.
ndim
:
raise
ValueError
(
'axis(={0}) out of bounds'
.
format
(
self
.
axis
))
return
theano
.
Apply
(
self
,
[
x
],
[
out_type
])
return
theano
.
Apply
(
self
,
[
x
],
[
out_type
])
...
@@ -148,6 +150,8 @@ class CumprodOp(theano.Op):
...
@@ -148,6 +150,8 @@ class CumprodOp(theano.Op):
if
self
.
axis
is
None
:
if
self
.
axis
is
None
:
out_type
=
theano
.
tensor
.
vector
(
dtype
=
x
.
dtype
)
# Flatten
out_type
=
theano
.
tensor
.
vector
(
dtype
=
x
.
dtype
)
# Flatten
elif
self
.
axis
>=
x
.
ndim
:
raise
ValueError
(
'axis(={0}) out of bounds'
.
format
(
self
.
axis
))
return
theano
.
Apply
(
self
,
[
x
],
[
out_type
])
return
theano
.
Apply
(
self
,
[
x
],
[
out_type
])
...
...
theano/tensor/tests/test_extra_ops.py
浏览文件 @
8ce2474e
...
@@ -28,6 +28,9 @@ class TestCumsumOp(utt.InferShapeTester):
...
@@ -28,6 +28,9 @@ class TestCumsumOp(utt.InferShapeTester):
x
=
T
.
tensor3
(
'x'
)
x
=
T
.
tensor3
(
'x'
)
a
=
np
.
random
.
random
((
3
,
5
,
2
))
.
astype
(
config
.
floatX
)
a
=
np
.
random
.
random
((
3
,
5
,
2
))
.
astype
(
config
.
floatX
)
# Test axis out of bounds
self
.
assertRaises
(
ValueError
,
cumsum
,
x
,
axis
=
4
)
f
=
theano
.
function
([
x
],
cumsum
(
x
))
f
=
theano
.
function
([
x
],
cumsum
(
x
))
assert
np
.
allclose
(
np
.
cumsum
(
a
),
f
(
a
))
# Test axis=None
assert
np
.
allclose
(
np
.
cumsum
(
a
),
f
(
a
))
# Test axis=None
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论