Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
1f4f4ff6
提交
1f4f4ff6
authored
9月 19, 2014
作者:
Arnaud Bergeron
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Outsource the creation of the cudnn convolution descriptor.
上级
1a1dbe60
隐藏空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
195 行增加
和
122 行删除
+195
-122
dnn.py
theano/sandbox/cuda/dnn.py
+195
-122
没有找到文件。
theano/sandbox/cuda/dnn.py
浏览文件 @
1f4f4ff6
...
@@ -2,8 +2,8 @@ import copy
...
@@ -2,8 +2,8 @@ import copy
import
os
import
os
import
theano
import
theano
from
theano
import
Apply
from
theano
import
Apply
,
tensor
from
theano
import
tensor
from
theano
.gof.type
import
CDataType
from
theano.compat.six
import
StringIO
from
theano.compat.six
import
StringIO
from
theano.sandbox.cuda.type
import
CudaNdarrayType
from
theano.sandbox.cuda.type
import
CudaNdarrayType
from
theano.sandbox.cuda
import
GpuOp
from
theano.sandbox.cuda
import
GpuOp
...
@@ -12,6 +12,7 @@ from theano.sandbox.cuda.basic_ops import (as_cuda_ndarray_variable,
...
@@ -12,6 +12,7 @@ from theano.sandbox.cuda.basic_ops import (as_cuda_ndarray_variable,
from
theano.sandbox.cuda.blas
import
GpuConv
from
theano.sandbox.cuda.blas
import
GpuConv
from
theano.compat
import
PY3
from
theano.compat
import
PY3
from
theano.sandbox.cuda.nvcc_compiler
import
NVCC_compiler
class
DnnBase
(
GpuOp
):
class
DnnBase
(
GpuOp
):
"""
"""
...
@@ -46,24 +47,108 @@ if ((err = cudnnCreate(&_handle)) != CUDNN_STATUS_SUCCESS) {
...
@@ -46,24 +47,108 @@ if ((err = cudnnCreate(&_handle)) != CUDNN_STATUS_SUCCESS) {
}"""
%
(
error_out
,)]
}"""
%
(
error_out
,)]
class
GpuDnnConv
Base
(
DnnBase
):
class
GpuDnnConv
Desc
(
GpuOp
):
__props__
=
(
'border_mode'
,
'subsample'
,
'conv_mode'
)
__props__
=
(
'border_mode'
,
'subsample'
,
'conv_mode'
)
def
c_headers
(
self
):
return
[
'cudnn.h'
,
'cudnn_helper.h'
]
def
c_header_dirs
(
self
):
return
[
os
.
path
.
dirname
(
__file__
)]
def
c_libraries
(
self
):
return
[
'cudnn'
]
def
c_compiler
(
self
):
return
NVCC_compiler
def
__init__
(
self
,
border_mode
,
subsample
=
(
1
,
1
),
conv_mode
=
'conv'
):
def
__init__
(
self
,
border_mode
,
subsample
=
(
1
,
1
),
conv_mode
=
'conv'
):
assert
border_mode
in
(
'valid'
,
'full'
)
assert
border_mode
in
(
'valid'
,
'full'
)
self
.
border_mode
=
border_mode
self
.
border_mode
=
border_mode
assert
len
(
subsample
)
==
2
self
.
subsample
=
subsample
self
.
subsample
=
subsample
assert
conv_mode
in
(
'conv'
,
'cross'
)
assert
conv_mode
in
(
'conv'
,
'cross'
)
self
.
conv_mode
=
conv_mode
self
.
conv_mode
=
conv_mode
def
__setstate__
(
self
,
props
):
def
make_node
(
self
,
img_shape
,
kern_shape
):
self
.
__dict__
.
update
(
props
)
if
img_shape
.
type
.
ndim
!=
1
and
img_shape
.
type
.
dtype
!=
numpy
.
int64
:
if
not
hasattr
(
self
,
'conv_mode'
):
raise
TypeError
(
'img must be 1D shape tensor'
)
self
.
conv_mode
=
'conv'
if
kern_shape
.
type
.
ndim
!=
1
and
kern_shape
.
type
.
dtype
!=
numpy
.
int64
:
if
not
hasattr
(
self
,
'subsample'
):
raise
TypeError
(
'kern must be 1D shape tensor'
)
self
.
subsample
=
(
1
,
1
)
return
Apply
(
self
,
[
img_shape
,
kern_shape
],
[
CDataType
(
"cudnnConvolutionDescriptor_t"
)()])
def
c_code
(
self
,
node
,
name
,
inputs
,
outputs
,
sub
):
img_shape
,
kern_shape
=
inputs
desc
,
=
outputs
if
self
.
border_mode
==
"valid"
:
bmode
=
1
else
:
assert
self
.
border_mode
==
"full"
bmode
=
0
if
self
.
conv_mode
==
'conv'
:
conv_flag
=
'CUDNN_CONVOLUTION'
else
:
conv_flag
=
'CUDNN_CROSS_CORRELATION'
return
"""
{
cudnnStatus_t err;
int pad_h
%(name)
s;
int pad_w
%(name)
s;
if ((err = cudnnCreateConvolutionDescriptor(&
%(desc)
s)) != CUDNN_STATUS_SUCCESS) {
PyErr_Format(PyExc_MemoryError, "could not allocate convolution "
"descriptor:
%%
s", cudnnGetErrorString(err));
%(fail)
s
}
if (
%(bmode)
d == 1) {
pad_h
%(name)
s = 0;
pad_w
%(name)
s = 0;
} else if (
%(bmode)
d == 0) {
pad_h
%(name)
s = *(npy_int64 *)PyArray_GETPTR1(
%(kern_shape)
s, 2) - 1;
pad_w
%(name)
s = *(npy_int64 *)PyArray_GETPTR1(
%(kern_shape)
s, 3) - 1;
} else {
PyErr_SetString(PyExc_ValueError, "bad border mode");
%(fail)
s
}
err = cudnnSetConvolutionDescriptorEx(
%(desc)
s,
*(npy_int64 *)PyArray_GETPTR1(
%(img_shape)
s, 0),
*(npy_int64 *)PyArray_GETPTR1(
%(img_shape)
s, 1),
*(npy_int64 *)PyArray_GETPTR1(
%(img_shape)
s, 2),
*(npy_int64 *)PyArray_GETPTR1(
%(img_shape)
s, 3),
*(npy_int64 *)PyArray_GETPTR1(
%(kern_shape)
s, 0),
*(npy_int64 *)PyArray_GETPTR1(
%(kern_shape)
s, 2),
*(npy_int64 *)PyArray_GETPTR1(
%(kern_shape)
s, 3),
pad_h
%(name)
s,
pad_w
%(name)
s,
%(subsx)
d,
%(subsy)
d, 1, 1,
%(conv_flag)
s
);
if (err != CUDNN_STATUS_SUCCESS) {
PyErr_Format(PyExc_RuntimeError, "could not set op descriptor:
%%
s",
cudnnGetErrorString(err));
%(fail)
s
}
}
"""
%
dict
(
name
=
name
,
img_shape
=
img_shape
,
kern_shape
=
kern_shape
,
desc
=
desc
,
bmode
=
bmode
,
conv_flag
=
conv_flag
,
fail
=
sub
[
'fail'
],
subsx
=
self
.
subsample
[
0
],
subsy
=
self
.
subsample
[
1
])
def
c_code_cache_version
(
self
):
return
(
1
,)
def
make_node
(
self
,
img
,
kern
):
class
GpuDnnConvBase
(
DnnBase
):
__props__
=
()
def
make_node
(
self
,
img
,
kern
,
desc
):
if
img
.
type
.
ndim
!=
4
:
if
img
.
type
.
ndim
!=
4
:
raise
TypeError
(
'img must be 4D tensor'
)
raise
TypeError
(
'img must be 4D tensor'
)
if
kern
.
type
.
ndim
!=
4
:
if
kern
.
type
.
ndim
!=
4
:
...
@@ -73,50 +158,45 @@ class GpuDnnConvBase(DnnBase):
...
@@ -73,50 +158,45 @@ class GpuDnnConvBase(DnnBase):
kern
.
type
.
broadcastable
[
0
],
kern
.
type
.
broadcastable
[
0
],
False
,
False
)
False
,
False
)
return
Apply
(
self
,
[
img
,
kern
],
[
CudaNdarrayType
(
broadcastable
)()])
return
Apply
(
self
,
[
img
,
kern
,
desc
],
[
CudaNdarrayType
(
broadcastable
)()])
def
c_support_code_struct
(
self
,
node
,
struct_id
):
def
c_support_code_struct
(
self
,
node
,
struct_id
):
types
=
[
'cudnn'
+
d
.
capitalize
()
+
'Descriptor_t'
return
"""
for
d
in
self
.
descriptors
]
cudnnTensor4dDescriptor_t input
%(id)
d;
elems
=
[
t
+
' param
%
d_
%
d;'
%
(
i
,
struct_id
)
cudnnTensor4dDescriptor_t output
%(id)
d;
for
i
,
t
in
enumerate
(
types
)]
cudnnFilterDescriptor_t kerns
%(id)
d;
return
(
"cudnnConvolutionDescriptor_t op
%
d;
\n
"
%
(
struct_id
,)
+
"""
%
dict
(
id
=
struct_id
)
'
\n
'
.
join
(
elems
))
def
c_init_code_struct
(
self
,
node
,
struct_id
,
sub
):
def
c_init_code_struct
(
self
,
node
,
struct_id
,
sub
):
vnames
=
[
'param
%
d_
%
d'
%
(
i
,
struct_id
)
return
"""
for
i
,
t
in
enumerate
(
self
.
descriptors
)]
cudnnStatus_t err
%(id)
d;
inits
=
[
vname
+
'= NULL;'
for
vname
in
vnames
]
input
%(id)
d = NULL;
creates
=
[]
output
%(id)
d = NULL;
for
d
,
var
in
zip
(
self
.
descriptors
,
vnames
):
kerns
%(id)
d = NULL;
creates
.
append
(
"""
if ((err
%(id)
d = cudnnCreateTensor4dDescriptor(&input
%(id)
d)) != CUDNN_STATUS_SUCCESS) {
if ((err
%(id)
d = cudnnCreate
%(d)
sDescriptor(&
%(var)
s)) != CUDNN_STATUS_SUCCESS) {
PyErr_Format(PyExc_MemoryError, "could not allocate tensor4d descriptor "
PyErr_Format(PyExc_MemoryError, "could not allocate tensor4d descriptor "
"(inp):
%%
s", cudnnGetErrorString(err
%(id)
d));
"(inp):
%%
s", cudnnGetErrorString(err
%(id)
d));
%(fail)
s
%(fail)
s
}
}
"""
%
dict
(
id
=
struct_id
,
d
=
d
.
capitalize
(),
var
=
var
,
fail
=
sub
[
'fail'
]))
if ((err
%(id)
d = cudnnCreateTensor4dDescriptor(&output
%(id)
d)) != CUDNN_STATUS_SUCCESS) {
PyErr_Format(PyExc_MemoryError, "could not allocate tensor4d descriptor "
return
"""
"(out):
%%
s", cudnnGetErrorString(err
%(id)
d));
%(init)
s
cudnnStatus_t err
%(id)
d;
%(create)
s
if ((err
%(id)
d = cudnnCreateConvolutionDescriptor(&op
%(id)
d)) != CUDNN_STATUS_SUCCESS) {
PyErr_Format(PyExc_MemoryError, "could not allocate convolution "
"descriptor:
%%
s", cudnnGetErrorString(err
%(id)
d));
%(fail)
s
%(fail)
s
}
}
"""
%
dict
(
id
=
struct_id
,
fail
=
sub
[
'fail'
],
init
=
'
\n
'
.
join
(
inits
),
if ((err
%(id)
d = cudnnCreateFilterDescriptor(&kerns
%(id)
d)) != CUDNN_STATUS_SUCCESS) {
create
=
'
\n
'
.
join
(
creates
))
PyErr_Format(PyExc_MemoryError, "could not allocate filter descriptor:
%%
s",
cudnnGetErrorString(err
%(id)
d));
%(fail)
s
}
"""
%
dict
(
id
=
struct_id
,
fail
=
sub
[
'fail'
])
def
c_cleanup_code_struct
(
self
,
node
,
struct_id
):
def
c_cleanup_code_struct
(
self
,
node
,
struct_id
):
cleanups
=
[
'cudnnDestroy
%
sDescriptor(param
%
d_
%
d);'
%
(
d
.
capitalize
(),
i
,
struct_id
)
for
i
,
d
in
enumerate
(
self
.
descriptors
)]
return
"""
return
"""
%(cleanup)
s
cudnnDestroyTensor4dDescriptor(input
%(id)
d);
cudnnDestroyConvolutionDescriptor(op
%(id)
d);
cudnnDestroyTensor4dDescriptor(output
%(id)
d);
"""
%
dict
(
id
=
struct_id
,
cleanup
=
'
\n
'
.
join
(
cleanups
))
cudnnDestroyFilterDescriptor(kerns
%(id)
d);
"""
%
dict
(
id
=
struct_id
)
def
c_set_tensor4d
(
self
,
var
,
desc
,
err
,
fail
):
def
c_set_tensor4d
(
self
,
var
,
desc
,
err
,
fail
):
return
"""
return
"""
...
@@ -155,25 +235,11 @@ if (%(err)s != CUDNN_STATUS_SUCCESS) {
...
@@ -155,25 +235,11 @@ if (%(err)s != CUDNN_STATUS_SUCCESS) {
"""
%
dict
(
var
=
var
,
desc
=
desc
,
err
=
err
,
fail
=
fail
)
"""
%
dict
(
var
=
var
,
desc
=
desc
,
err
=
err
,
fail
=
fail
)
def
c_code
(
self
,
node
,
name
,
inputs
,
outputs
,
sub
):
def
c_code
(
self
,
node
,
name
,
inputs
,
outputs
,
sub
):
param0
,
param1
=
inputs
desc
=
inputs
[
2
]
out
,
=
outputs
out
,
=
outputs
if
self
.
border_mode
==
"valid"
:
bmode
=
1
else
:
assert
self
.
border_mode
==
"full"
bmode
=
0
if
self
.
conv_mode
==
'conv'
:
conv_flag
=
'CUDNN_CONVOLUTION'
else
:
conv_flag
=
'CUDNN_CROSS_CORRELATION'
vnames
=
[
'param
%
d_
%
d'
%
(
i
,
sub
[
'struct_id'
])
for
i
,
t
in
enumerate
(
self
.
descriptors
)]
checks
=
[]
checks
=
[]
for
v
in
(
param0
,
param1
)
:
for
v
in
inputs
[:
2
]
:
checks
.
append
(
"""
checks
.
append
(
"""
if (!CudaNdarray_is_c_contiguous(
%
s)) {
if (!CudaNdarray_is_c_contiguous(
%
s)) {
PyErr_SetString(PyExc_ValueError, "Only contiguous inputs are supported.");
PyErr_SetString(PyExc_ValueError, "Only contiguous inputs are supported.");
...
@@ -182,70 +248,57 @@ if (!CudaNdarray_is_c_contiguous(%s)) {
...
@@ -182,70 +248,57 @@ if (!CudaNdarray_is_c_contiguous(%s)) {
"""
%
(
v
,
sub
[
'fail'
]))
"""
%
(
v
,
sub
[
'fail'
]))
sets
=
[]
sets
=
[]
for
p
,
v
,
d
in
zip
((
param0
,
param1
),
vnames
[:
-
1
],
for
p
,
v
,
d
in
zip
(
inputs
[:
2
],
self
.
conv_inputs
,
self
.
conv_types
[:
2
]):
self
.
descriptors
[:
-
1
]):
sets
.
append
(
getattr
(
self
,
'c_set_'
+
d
)(
p
,
v
+
str
(
sub
[
'struct_id'
]),
sets
.
append
(
getattr
(
self
,
'c_set_'
+
d
)(
p
,
v
,
'err'
+
name
,
'err'
+
name
,
sub
[
'fail'
]))
sub
[
'fail'
]))
set_out
=
getattr
(
self
,
'c_set_'
+
self
.
descriptors
[
-
1
])(
set_out
=
getattr
(
self
,
'c_set_'
+
self
.
conv_types
[
2
])(
out
,
vnames
[
-
1
],
'err'
+
name
,
sub
[
'fail'
])
out
,
self
.
conv_output
+
str
(
sub
[
'struct_id'
]),
'err'
+
name
,
sub
[
'fail'
])
return
"""
return
"""
cudnnStatus_t err
%(name)
s;
cudnnStatus_t err
%(name)
s;
int pad_w
%(name)
s;
int pad_h
%(name)
s;
%(checks)
s
%(checks)
s
%(sets)
s
%(sets)
s
if (
%(bmode)
d == 1) {
pad_h
%(name)
s = 0;
pad_w
%(name)
s = 0;
} else if (
%(bmode)
d == 0) {
pad_h
%(name)
s = CudaNdarray_HOST_DIMS(
%(param1)
s)[2] - 1;
pad_w
%(name)
s = CudaNdarray_HOST_DIMS(
%(param1)
s)[3] - 1;
} else {
PyErr_SetString(PyExc_ValueError, "bad border mode");
%(fail)
s
}
err
%(name)
s = cudnnSetConvolutionDescriptor(
op
%(id)
d, param0_
%(id)
d, param1_
%(id)
d,
pad_h
%(name)
s,
pad_w
%(name)
s,
%(subsx)
d,
%(subsy)
d, 1, 1,
%(conv_flag)
s
);
if (err
%(name)
s != CUDNN_STATUS_SUCCESS) {
PyErr_Format(PyExc_RuntimeError, "could not set op descriptor:
%%
s",
cudnnGetErrorString(err
%(name)
s));
%(fail)
s
}
{
{
int out_dims[4];
int out_dims[4];
err
%(name)
s = cudnnGetOutputTensor4dDim(
err
%(name)
s = cudnnGetOutputTensor4dDim(
op
%(id)
d,
%(path)
s,
%(desc)
s,
%(path)
s,
&out_dims[0], &out_dims[1],
&out_dims[0], &out_dims[1],
&out_dims[2], &out_dims[3]
&out_dims[2], &out_dims[3]
);
);
if (err
%(name)
s != CUDNN_STATUS_SUCCESS) {
if (err
%(name)
s != CUDNN_STATUS_SUCCESS) {
PyErr_Format(PyExc_RuntimeError, "could not get output sizes:
%%
s",
PyErr_Format(PyExc_RuntimeError, "could not get output sizes:
%%
s",
cudnnGetErrorString(err
%(name)
s));
cudnnGetErrorString(err
%(name)
s));
%(fail)
s
%(fail)
s
}
}
if (CudaNdarray_prep_output(&
%(out)
s, 4, out_dims) != 0) {
// workaround for cudnn R1 bug
%(fail)
s
if (
%(path)
s == CUDNN_CONVOLUTION_WEIGHT_GRAD &&
}
(out_dims[0] != CudaNdarray_HOST_DIMS(
%(input2)
s)[1] ||
out_dims[1] != CudaNdarray_HOST_DIMS(
%(input1)
s)[1])) {
out_dims[0] = CudaNdarray_HOST_DIMS(
%(input2)
s)[1];
out_dims[1] = CudaNdarray_HOST_DIMS(
%(input1)
s)[1];
// This is a horrible hack that is unfortulately necessary
int *dd = (int *)
%(desc)
s;
out_dims[2] = dd[5];
out_dims[3] = dd[6];
}
if (CudaNdarray_prep_output(&
%(out)
s, 4, out_dims) != 0) {
%(fail)
s
}
}
}
%(set_out)
s
%(set_out)
s
err
%(name)
s =
%(method)
s(
err
%(name)
s =
%(method)
s(
_handle,
_handle,
param0_
%(id)
d, CudaNdarray_DEV_DATA(
%(param0
)
s),
%(input1_desc)
s, CudaNdarray_DEV_DATA(
%(input1
)
s),
param1_
%(id)
d, CudaNdarray_DEV_DATA(
%(param1
)
s),
%(input2_desc)
s, CudaNdarray_DEV_DATA(
%(input2
)
s),
op
%(id)
d
,
%(desc)
s
,
param2_
%(id)
d
, CudaNdarray_DEV_DATA(
%(out)
s),
%(output_desc)
s
, CudaNdarray_DEV_DATA(
%(out)
s),
CUDNN_RESULT_NO_ACCUMULATE
CUDNN_RESULT_NO_ACCUMULATE
);
);
if (err
%(name)
s != CUDNN_STATUS_SUCCESS) {
if (err
%(name)
s != CUDNN_STATUS_SUCCESS) {
...
@@ -253,41 +306,53 @@ if (err%(name)s != CUDNN_STATUS_SUCCESS) {
...
@@ -253,41 +306,53 @@ if (err%(name)s != CUDNN_STATUS_SUCCESS) {
cudnnGetErrorString(err
%(name)
s));
cudnnGetErrorString(err
%(name)
s));
%(fail)
s
%(fail)
s
}
}
"""
%
dict
(
param0
=
param0
,
param1
=
param1
,
out
=
out
,
bmode
=
bmode
,
"""
%
dict
(
out
=
out
,
desc
=
desc
,
fail
=
sub
[
'fail'
],
id
=
sub
[
'struct_id'
],
conv_flag
=
conv_flag
,
fail
=
sub
[
'fail'
],
id
=
sub
[
'struct_id'
],
name
=
name
,
checks
=
'
\n
'
.
join
(
checks
),
sets
=
'
\n
'
.
join
(
sets
),
name
=
name
,
checks
=
'
\n
'
.
join
(
checks
),
sets
=
'
\n
'
.
join
(
sets
),
subsx
=
self
.
subsample
[
0
],
subsy
=
self
.
subsample
[
1
],
set_out
=
set_out
,
input1
=
inputs
[
0
],
input2
=
inputs
[
1
],
set_out
=
set_out
,
method
=
self
.
conv_op
,
path
=
self
.
path_flag
)
input1_desc
=
self
.
conv_inputs
[
0
]
+
str
(
sub
[
'struct_id'
]),
input2_desc
=
self
.
conv_inputs
[
1
]
+
str
(
sub
[
'struct_id'
]),
output_desc
=
self
.
conv_output
+
str
(
sub
[
'struct_id'
]),
method
=
self
.
conv_op
,
path
=
self
.
path_flag
)
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
return
(
6
,)
return
(
7
,)
class
GpuDnnConv
(
GpuDnnConvBase
):
class
GpuDnnConv
(
GpuDnnConvBase
):
descriptors
=
(
'tensor4d'
,
'filter'
,
'tensor4d'
)
conv_inputs
=
'input'
,
'kerns'
conv_output
=
'output'
conv_types
=
'tensor4d'
,
'filter'
,
'tensor4d'
conv_op
=
'cudnnConvolutionForward'
path_flag
=
'CUDNN_CONVOLUTION_FWD'
path_flag
=
'CUDNN_CONVOLUTION_FWD'
conv_op
=
'cudnnConvolutionForward'
def
grad
(
self
,
inp
,
grads
):
def
grad
(
self
,
inp
,
grads
):
img
,
kerns
=
inp
img
,
kerns
,
desc
=
inp
top
,
=
grads
top
,
=
grads
d_img
=
GpuDnnConvGradI
(
self
.
border_mode
,
self
.
subsample
,
top
=
gpu_contiguous
(
top
)
self
.
conv_mode
)(
kerns
,
top
)
d_
kerns
=
GpuDnnConvGradW
(
self
.
border_mode
,
self
.
subsample
,
d_
img
=
GpuDnnConvGradI
()(
kerns
,
top
,
desc
)
self
.
conv_mode
)(
img
,
top
)
d_kerns
=
GpuDnnConvGradW
()(
img
,
top
,
desc
)
return
d_img
,
d_kerns
return
d_img
,
d_kerns
,
theano
.
gradient
.
DisconnectedType
()()
def
connection_pattern
(
self
,
node
):
# not connected to desc
return
[[
1
],
[
1
],
[
0
]]
class
GpuDnnConvGradW
(
GpuDnnConvBase
):
class
GpuDnnConvGradW
(
GpuDnnConvBase
):
descriptors
=
(
'tensor4d'
,
'tensor4d'
,
'filter'
)
conv_inputs
=
'input'
,
'output'
,
conv_output
=
'kerns'
conv_types
=
'tensor4d'
,
'tensor4d'
,
'filter'
path_flag
=
'CUDNN_CONVOLUTION_WEIGHT_GRAD'
path_flag
=
'CUDNN_CONVOLUTION_WEIGHT_GRAD'
conv_op
=
'cudnnConvolutionBackwardFilter'
conv_op
=
'cudnnConvolutionBackwardFilter'
class
GpuDnnConvGradI
(
GpuDnnConvBase
):
class
GpuDnnConvGradI
(
GpuDnnConvBase
):
descriptors
=
(
'filter'
,
'tensor4d'
,
'tensor4d'
)
conv_inputs
=
'kerns'
,
'output'
,
conv_output
=
'input'
conv_types
=
'filter'
,
'tensor4d'
,
'tensor4d'
path_flag
=
'CUDNN_CONVOLUTION_DATA_GRAD'
path_flag
=
'CUDNN_CONVOLUTION_DATA_GRAD'
conv_op
=
'cudnnConvolutionBackwardData'
conv_op
=
'cudnnConvolutionBackwardData'
...
@@ -295,6 +360,14 @@ class GpuDnnConvGradI(GpuDnnConvBase):
...
@@ -295,6 +360,14 @@ class GpuDnnConvGradI(GpuDnnConvBase):
from
theano.sandbox.cuda.opt
import
(
local_optimizer
,
gpu_contiguous
,
from
theano.sandbox.cuda.opt
import
(
local_optimizer
,
gpu_contiguous
,
gpu_optimizer
)
gpu_optimizer
)
def
dnn_conv
(
img
,
kerns
,
border_mode
=
'valid'
,
subsample
=
(
1
,
1
),
conv_mode
=
'conv'
):
img
=
gpu_contiguous
(
img
)
kerns
=
gpu_contiguous
(
kerns
)
desc
=
GpuDnnConvDesc
(
border_mode
=
border_mode
,
subsample
=
subsample
,
conv_mode
=
conv_mode
)(
img
.
shape
,
kerns
.
shape
)
return
GpuDnnConv
()(
img
,
kerns
,
desc
)
@local_optimizer
([
GpuConv
])
@local_optimizer
([
GpuConv
])
def
local_conv_dnn
(
node
):
def
local_conv_dnn
(
node
):
if
isinstance
(
node
.
op
,
GpuConv
):
if
isinstance
(
node
.
op
,
GpuConv
):
...
@@ -303,7 +376,7 @@ def local_conv_dnn(node):
...
@@ -303,7 +376,7 @@ def local_conv_dnn(node):
img
,
kern
=
node
.
inputs
img
,
kern
=
node
.
inputs
border_mode
=
node
.
op
.
border_mode
border_mode
=
node
.
op
.
border_mode
subsample
=
node
.
op
.
subsample
subsample
=
node
.
op
.
subsample
return
[
GpuDnnConv
(
border_mode
,
subsample
)(
gpu_contiguous
(
img
),
return
[
dnn_conv
(
gpu_contiguous
(
img
),
gpu_contiguous
(
kern
),
gpu_contiguous
(
kern
)
)]
border_mode
=
border_mode
,
subsample
=
subsample
)]
gpu_optimizer
.
register
(
"conv_cudnn"
,
local_conv_dnn
,
'cudnn'
)
gpu_optimizer
.
register
(
"conv_cudnn"
,
local_conv_dnn
,
'cudnn'
)
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论