Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
46e0dbdd
提交
46e0dbdd
authored
8月 16, 2017
作者:
Boris Fomitchev
提交者:
notoraptor
8月 18, 2017
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Addressing code review comments
上级
c7e02f24
显示空白字符变更
内嵌
并排
正在显示
8 个修改的文件
包含
298 行增加
和
282 行删除
+298
-282
cudnn_helper.h
theano/gpuarray/c_code/cudnn_helper.h
+10
-1
dnn_conv_base.c
theano/gpuarray/c_code/dnn_conv_base.c
+202
-0
dnn_conv_find.c
theano/gpuarray/c_code/dnn_conv_find.c
+0
-1
dnn_conv_find.h
theano/gpuarray/c_code/dnn_conv_find.h
+0
-22
dnn_fwd.c
theano/gpuarray/c_code/dnn_fwd.c
+13
-122
dnn_gi.c
theano/gpuarray/c_code/dnn_gi.c
+37
-59
dnn_gw.c
theano/gpuarray/c_code/dnn_gw.c
+33
-74
dnn.py
theano/gpuarray/dnn.py
+3
-3
没有找到文件。
theano/gpuarray/c_code/cudnn_helper.h
浏览文件 @
46e0dbdd
...
@@ -11,6 +11,15 @@ static inline int cudnnGetVersion() {
...
@@ -11,6 +11,15 @@ static inline int cudnnGetVersion() {
}
}
#endif
#endif
#if CUDNN_MAJOR < 7
enum
cudnnMathType_t
{
CUDNN_DEFAULT_MATH
=
0
,
CUDNN_TENSOR_OP_MATH
=
1
};
#endif
/* a common struct for all 3 CUDNN enums */
struct
AlgoRec
{
int
algo
;
cudnnDataType_t
dataType
;
size_t
wsSize
;
cudnnMathType_t
mathType
;
};
#endif
#endif
theano/gpuarray/c_code/dnn_conv_base.c
浏览文件 @
46e0dbdd
...
@@ -50,3 +50,205 @@ if (APPLY_SPECIFIC(output) != NULL)
...
@@ -50,3 +50,205 @@ if (APPLY_SPECIFIC(output) != NULL)
cudnnDestroyTensorDescriptor
(
APPLY_SPECIFIC
(
output
));
cudnnDestroyTensorDescriptor
(
APPLY_SPECIFIC
(
output
));
if
(
APPLY_SPECIFIC
(
kerns
)
!=
NULL
)
if
(
APPLY_SPECIFIC
(
kerns
)
!=
NULL
)
cudnnDestroyFilterDescriptor
(
APPLY_SPECIFIC
(
kerns
));
cudnnDestroyFilterDescriptor
(
APPLY_SPECIFIC
(
kerns
));
#section support_code
#include <sstream>
#include <vector>
#include <string>
#if __cplusplus < 201103L
#include <tr1/unordered_map>
typedef
std
::
tr1
::
unordered_map
<
std
::
string
,
AlgoRec
>
AlgoCache
;
#else
#include <unordered_map>
typedef
std
::
unordered_map
<
std
::
string
,
AlgoRec
>
AlgoCache
;
#endif
#include "pthread.h"
#line 69 "dnn_conv_base.c"
using
std
::
vector
;
using
std
::
string
;
pthread_mutex_t
algoMutex
;
AlgoCache
algoCache
;
static
cudnnStatus_t
checkCudnnStatus
(
cudnnStatus_t
err
)
{
if
(
err
!=
CUDNN_STATUS_SUCCESS
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"CUDNN Error: %s"
,
cudnnGetErrorString
(
err
));
}
return
err
;
}
static
int
c_get_largest_free_block_size
(
PyGpuContextObject
*
c
)
{
size_t
free
=
0
;
int
err2
=
gpucontext_property
(
c
->
ctx
,
GA_CTX_PROP_LARGEST_MEMBLOCK
,
&
free
);
if
(
err2
!=
GA_NO_ERROR
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"Error when trying to find the "
"memory information on the GPU"
);
}
// Guess 4Mb if the info is not available
if
(
free
==
0
)
free
=
4
*
1024
*
1024
;
return
free
;
}
static
std
::
string
shape
(
int
*
res
,
int
size
)
{
std
::
stringstream
s
;
if
(
size
>
0
)
{
s
<<
res
[
0
];
for
(
int
i
=
1
;
i
<
size
;
++
i
)
s
<<
','
<<
res
[
i
];
}
return
std
::
string
(
s
.
str
().
c_str
());
}
static
std
::
string
shape
(
cudnnTensorDescriptor_t
t
)
{
std
::
vector
<
int
>
res
;
std
::
vector
<
int
>
stride
;
int
nbDims
;
cudnnDataType_t
type
;
checkCudnnStatus
(
cudnnGetTensorNdDescriptor
(
t
,
0
,
&
type
,
&
nbDims
,
0
,
0
));
res
.
resize
(
nbDims
);
stride
.
resize
(
nbDims
);
checkCudnnStatus
(
cudnnGetTensorNdDescriptor
(
t
,
nbDims
,
&
type
,
&
nbDims
,
res
.
data
(),
stride
.
data
()));
return
shape
(
&
res
[
0
],
nbDims
)
+
shape
(
&
stride
[
0
],
nbDims
);
};
static
std
::
string
shape
(
cudnnFilterDescriptor_t
t
,
cudnnDataType_t
*
type
)
{
cudnnTensorFormat_t
format
;
int
sizes
=
8
;
std
::
vector
<
int
>
res
(
sizes
);
int
outDims
;
checkCudnnStatus
(
cudnnGetFilterNdDescriptor
(
t
,
sizes
,
type
,
&
format
,
&
outDims
,
res
.
data
()));
return
shape
(
&
res
[
0
],
outDims
);
};
static
std
::
string
shape
(
cudnnConvolutionDescriptor_t
convDesc
)
{
const
int
maxDim
=
5
;
int
nDim
=
0
;
cudnnConvolutionMode_t
mode
;
cudnnDataType_t
computeType
;
int
padA
[
maxDim
];
int
strideA
[
maxDim
];
int
dilationA
[
maxDim
];
checkCudnnStatus
(
cudnnGetConvolutionNdDescriptor
(
convDesc
,
maxDim
,
&
nDim
,
&
padA
[
0
],
&
strideA
[
0
],
&
dilationA
[
0
],
&
mode
,
&
computeType
));
return
std
::
string
(
"-mode "
)
+
(((
int
)
mode
==
0
)
?
"conv"
:
"corr"
)
+
" -padA"
+
shape
(
padA
,
nDim
)
+
" -convStrideA "
+
shape
(
strideA
,
nDim
)
+
" -dilationA "
+
shape
(
dilationA
,
nDim
);
}
static
bool
all_aligned
(
cudnnDataType_t
type
,
void
*
in
,
void
*
out
,
void
*
filter
)
{
size_t
alignMask
=
(
type
==
CUDNN_DATA_HALF
)
?
0x7F
:
0xFF
;
// there have to be entries for both aligned and not
if
(((
size_t
)
in
|
(
size_t
)
out
|
(
size_t
)
filter
)
&
alignMask
)
{
return
false
;
}
return
true
;
}
static
std
::
string
dnn_conv_shape
(
cudnnTensorDescriptor_t
inputDesc
,
PyGpuArrayObject
*
input
,
cudnnFilterDescriptor_t
filterDesc
,
PyGpuArrayObject
*
filter
,
cudnnConvolutionDescriptor_t
convDesc
,
PyGpuArrayObject
*
output
,
int
groups
)
{
cudnnDataType_t
dType
;
std
::
stringstream
s
;
int
expected_output_dims
[
5
]
=
{
0
};
cudnnStatus_t
err
=
cudnnGetConvolutionNdForwardOutputDim
(
convDesc
,
inputDesc
,
filterDesc
,
PyGpuArray_NDIM
(
filter
),
expected_output_dims
);
if
(
err
!=
CUDNN_STATUS_SUCCESS
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"error computing convolution output dim: %s"
,
cudnnGetErrorString
(
err
));
return
""
;
}
if
(
PyGpuArray_NDIM
(
filter
)
==
4
)
{
if
((
PyGpuArray_DIMS
(
output
)[
0
]
!=
expected_output_dims
[
0
])
||
(
PyGpuArray_DIMS
(
output
)[
1
]
/
groups
!=
expected_output_dims
[
1
])
||
(
PyGpuArray_DIMS
(
output
)[
2
]
!=
expected_output_dims
[
2
])
||
(
PyGpuArray_DIMS
(
output
)[
3
]
!=
expected_output_dims
[
3
]))
{
PyErr_Format
(
PyExc_ValueError
,
"impossible convolution output dim: expected %ldx%ldx%ldx%ld"
" but received gradient with shape %dx%dx% dx%d"
,
expected_output_dims
[
0
],
expected_output_dims
[
1
]
/
groups
,
expected_output_dims
[
2
],
expected_output_dims
[
3
],
PyGpuArray_DIMS
(
output
)[
0
],
PyGpuArray_DIMS
(
output
)[
1
],
PyGpuArray_DIMS
(
output
)[
2
],
PyGpuArray_DIMS
(
output
)[
3
]);
return
""
;
}
}
else
if
(
PyGpuArray_NDIM
(
filter
)
==
5
)
{
if
((
PyGpuArray_DIMS
(
output
)[
0
]
!=
expected_output_dims
[
0
])
||
(
PyGpuArray_DIMS
(
output
)[
1
]
!=
expected_output_dims
[
1
])
||
(
PyGpuArray_DIMS
(
output
)[
2
]
!=
expected_output_dims
[
2
])
||
(
PyGpuArray_DIMS
(
output
)[
3
]
!=
expected_output_dims
[
3
])
||
(
PyGpuArray_DIMS
(
output
)[
4
]
!=
expected_output_dims
[
4
]))
{
PyErr_Format
(
PyExc_ValueError
,
"impossible convolution output dim: expected %ldx%ldx%ldx%ldx%ld"
" but received gradient with shape %ldx%ldx%ldx%ldx%ld"
,
expected_output_dims
[
0
],
expected_output_dims
[
1
],
expected_output_dims
[
2
],
expected_output_dims
[
3
],
expected_output_dims
[
4
],
PyGpuArray_DIMS
(
output
)[
0
],
PyGpuArray_DIMS
(
output
)[
1
],
PyGpuArray_DIMS
(
output
)[
2
],
PyGpuArray_DIMS
(
output
)[
3
],
PyGpuArray_DIMS
(
output
)[
4
]);
return
""
;
}
}
s
<<
"-g"
<<
groups
<<
" -dimA"
<<
shape
(
inputDesc
)
<<
" -filtA"
<<
shape
(
filterDesc
,
&
dType
)
<<
shape
(
convDesc
);
// there have to be entries for both aligned and not
if
(
!
all_aligned
(
dType
,
PyGpuArray_DEV_DATA
(
input
),
PyGpuArray_DEV_DATA
(
output
),
PyGpuArray_DEV_DATA
(
filter
)))
{
s
<<
" [unaligned] "
;
}
return
std
::
string
(
s
.
str
().
c_str
());
}
static
void
dnn_conv_update_cache
(
const
std
::
string
&
hash
,
const
AlgoRec
&
rec
)
{
pthread_mutex_lock
(
&
algoMutex
);
algoCache
[
hash
]
=
rec
;
pthread_mutex_unlock
(
&
algoMutex
);
}
static
const
AlgoRec
*
dnn_conv_check_cache
(
const
std
::
string
&
hash
)
{
pthread_mutex_lock
(
&
algoMutex
);
bool
cacheHit
=
false
;
const
AlgoRec
*
ret
=
0
;
// cout << "dnn_conv_check_cache: "<< hash << endl;
AlgoCache
::
iterator
hit
=
algoCache
.
find
(
hash
);
if
(
hit
!=
algoCache
.
end
())
ret
=
&
hit
->
second
;
pthread_mutex_unlock
(
&
algoMutex
);
return
ret
;
}
theano/gpuarray/c_code/dnn_conv_find.c
浏览文件 @
46e0dbdd
...
@@ -3,7 +3,6 @@
...
@@ -3,7 +3,6 @@
#include <sstream>
#include <sstream>
#include <vector>
#include <vector>
#include <string>
#include <string>
#include "dnn_conv_find.h"
#if __cplusplus < 201103L
#if __cplusplus < 201103L
#include <tr1/unordered_map>
#include <tr1/unordered_map>
typedef
std
::
tr1
::
unordered_map
<
std
::
string
,
AlgoRec
>
AlgoCache
;
typedef
std
::
tr1
::
unordered_map
<
std
::
string
,
AlgoRec
>
AlgoCache
;
...
...
theano/gpuarray/c_code/dnn_conv_find.h
浏览文件 @
46e0dbdd
#pragma once
#pragma once
#include <string>
#include <cuda.h>
#include <cuda.h>
#include <cudnn.h>
#include <cudnn.h>
#if CUDNN_MAJOR < 7
enum
cudnnMathType_t
{
CUDNN_DEFAULT_MATH
=
0
,
CUDNN_TENSOR_OP_MATH
=
1
};
#endif
inline
cudnnStatus_t
checkCudnnStatus
(
cudnnStatus_t
err
)
{
if
(
err
!=
CUDNN_STATUS_SUCCESS
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"CUDNN Error: %s"
,
cudnnGetErrorString
(
err
));
}
return
err
;
}
/* a common struct for all 3 CUDNN enums */
struct
AlgoRec
{
int
algo
;
cudnnDataType_t
dataType
;
size_t
wsSize
;
cudnnMathType_t
mathType
;
};
theano/gpuarray/c_code/dnn_fwd.c
浏览文件 @
46e0dbdd
...
@@ -3,18 +3,14 @@ reuse_algo = 0;
...
@@ -3,18 +3,14 @@ reuse_algo = 0;
prev_algo
.
algo
=
PARAMS
->
conv_algo
;
prev_algo
.
algo
=
PARAMS
->
conv_algo
;
prev_algo
.
mathType
=
CUDNN_DEFAULT_MATH
;
prev_algo
.
mathType
=
CUDNN_DEFAULT_MATH
;
prev_algo
.
dataType
=
CUDNN_DATA_FLOAT
;
prev_algo
.
dataType
=
CUDNN_DATA_FLOAT
;
hash_prefix
=
std
::
string
(
"FW| GPU#"
);
memset
(
prev_img_dims
,
0
,
sizeof
(
prev_img_dims
));
memset
(
prev_kern_dims
,
0
,
sizeof
(
prev_kern_dims
));
#section support_code_struct
#section support_code_struct
#line 12 "dnn_fwd.c"
#line 12 "dnn_fwd.c"
#include "dnn_conv_find.h"
int
reuse_algo
;
int
reuse_algo
;
bool
use_cached
;
bool
use_cached
;
AlgoRec
prev_algo
;
AlgoRec
prev_algo
;
size_t
prev_img_dims
[
5
];
std
::
string
hash_prefix
;
size_t
prev_kern_dims
[
5
];
int
int
APPLY_SPECIFIC
(
conv_fwd
)(
PyGpuArrayObject
*
input
,
PyGpuArrayObject
*
kerns
,
APPLY_SPECIFIC
(
conv_fwd
)(
PyGpuArrayObject
*
input
,
PyGpuArrayObject
*
kerns
,
...
@@ -100,19 +96,11 @@ APPLY_SPECIFIC(conv_fwd)(PyGpuArrayObject *input, PyGpuArrayObject *kerns,
...
@@ -100,19 +96,11 @@ APPLY_SPECIFIC(conv_fwd)(PyGpuArrayObject *input, PyGpuArrayObject *kerns,
char
algorithm_name
[
128
];
char
algorithm_name
[
128
];
#endif
#endif
size_t
free
=
c_get_largest_free_block_size
(
c
);
cuda_enter
(
c
->
ctx
);
cuda_enter
(
c
->
ctx
);
if
(
params
->
choose_algo
)
{
if
(
params
->
choose_algo
)
{
if
(
!
params
->
choose_once
)
{
reuse_algo
=
1
;
for
(
unsigned
int
i
=
0
;
i
<
PyGpuArray_NDIM
(
input
);
i
++
)
{
reuse_algo
=
(
reuse_algo
&&
PyGpuArray_DIM
(
input
,
i
)
==
prev_img_dims
[
i
]);
reuse_algo
=
(
reuse_algo
&&
PyGpuArray_DIM
(
kerns
,
i
)
==
prev_kern_dims
[
i
]);
}
}
if
(
!
reuse_algo
)
{
if
(
!
reuse_algo
)
{
char
pci_id
[
16
];
char
pci_id
[
16
];
...
@@ -120,7 +108,7 @@ APPLY_SPECIFIC(conv_fwd)(PyGpuArrayObject *input, PyGpuArrayObject *kerns,
...
@@ -120,7 +108,7 @@ APPLY_SPECIFIC(conv_fwd)(PyGpuArrayObject *input, PyGpuArrayObject *kerns,
hashkey
=
dnn_conv_shape
(
APPLY_SPECIFIC
(
input
),
input
,
APPLY_SPECIFIC
(
kerns
),
kerns
,
desc
,
*
output
,
groups
);
hashkey
=
dnn_conv_shape
(
APPLY_SPECIFIC
(
input
),
input
,
APPLY_SPECIFIC
(
kerns
),
kerns
,
desc
,
*
output
,
groups
);
if
(
hashkey
.
empty
())
if
(
hashkey
.
empty
())
return
1
;
return
1
;
hashkey
=
std
::
string
(
"F| GPU#"
)
+
pci_id
+
hashkey
;
hashkey
=
hash_prefix
+
pci_id
+
hashkey
;
// check out cache
// check out cache
const
AlgoRec
*
cached
=
dnn_conv_check_cache
(
hashkey
);
const
AlgoRec
*
cached
=
dnn_conv_check_cache
(
hashkey
);
if
(
cached
)
{
if
(
cached
)
{
...
@@ -129,19 +117,11 @@ APPLY_SPECIFIC(conv_fwd)(PyGpuArrayObject *input, PyGpuArrayObject *kerns,
...
@@ -129,19 +117,11 @@ APPLY_SPECIFIC(conv_fwd)(PyGpuArrayObject *input, PyGpuArrayObject *kerns,
}
}
}
}
if
(
!
(
reuse_algo
||
use_cached
))
{
if
(
reuse_algo
||
use_cached
)
{
size_t
free
;
algo
=
(
cudnnConvolutionFwdAlgo_t
)
prev_algo
.
algo
;
int
err2
=
gpucontext_property
(
c
->
ctx
,
GA_CTX_PROP_LARGEST_MEMBLOCK
,
&
free
);
worksize
=
prev_algo
.
wsSize
;
if
(
err2
!=
GA_NO_ERROR
)
{
mathtype
=
prev_algo
.
mathType
;
PyErr_Format
(
PyExc_RuntimeError
,
"Error when trying to find the "
}
else
{
"memory information on the GPU"
);
cuda_exit
(
c
->
ctx
);
return
1
;
}
// Guess 4Mb if the info is not available
if
(
free
==
0
)
free
=
4
*
1024
*
1024
;
if
(
params
->
choose_time
)
{
if
(
params
->
choose_time
)
{
int
count
;
int
count
;
cudnnConvolutionFwdAlgoPerf_t
choice
;
cudnnConvolutionFwdAlgoPerf_t
choice
;
...
@@ -175,8 +155,9 @@ APPLY_SPECIFIC(conv_fwd)(PyGpuArrayObject *input, PyGpuArrayObject *kerns,
...
@@ -175,8 +155,9 @@ APPLY_SPECIFIC(conv_fwd)(PyGpuArrayObject *input, PyGpuArrayObject *kerns,
algo
=
choice
.
algo
;
algo
=
choice
.
algo
;
prev_algo
.
algo
=
(
int
)
algo
;
prev_algo
.
algo
=
(
int
)
algo
;
prev_algo
.
wsSize
=
worksize
=
choice
.
memory
;
prev_algo
.
wsSize
=
worksize
=
choice
.
memory
;
#if CUDNN_MAJOR >= 7
prev_algo
.
mathType
=
mathtype
=
choice
.
mathType
;
prev_algo
.
mathType
=
mathtype
=
choice
.
mathType
;
#endif
// Add to the cache
// Add to the cache
dnn_conv_update_cache
(
hashkey
,
prev_algo
);
dnn_conv_update_cache
(
hashkey
,
prev_algo
);
...
@@ -209,90 +190,8 @@ APPLY_SPECIFIC(conv_fwd)(PyGpuArrayObject *input, PyGpuArrayObject *kerns,
...
@@ -209,90 +190,8 @@ APPLY_SPECIFIC(conv_fwd)(PyGpuArrayObject *input, PyGpuArrayObject *kerns,
prev_algo
.
mathType
=
mathtype
=
CUDNN_DEFAULT_MATH
;
prev_algo
.
mathType
=
mathtype
=
CUDNN_DEFAULT_MATH
;
// fprintf(stderr, "(cudnnGetConvolutionForwardAlgorithm: (err:%d), algo: %d\n", err, algo);
// fprintf(stderr, "(cudnnGetConvolutionForwardAlgorithm: (err:%d), algo: %d\n", err, algo);
}
}
}
else
{
algo
=
(
cudnnConvolutionFwdAlgo_t
)
prev_algo
.
algo
;
worksize
=
prev_algo
.
wsSize
;
mathtype
=
prev_algo
.
mathType
;
}
}
else
{
/* choose_algo */
/* Only these algos are supported for 3d conv with cuDNN >= V5.1. */
if
(
PyGpuArray_NDIM
(
input
)
==
5
&&
!
(
algo
==
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM
||
algo
==
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
||
algo
==
CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING
))
{
#ifdef DEBUG
if
(
0
!=
theano_enum_to_string_cudnnConvolutionFwdAlgo_t
(
algo
,
algorithm_name
))
return
1
;
fprintf
(
stderr
,
"(%s unsupported for 3D: fallback to CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM)
\n
"
,
algorithm_name
);
#endif
algo
=
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM
;
}
// Algo `small` does not work for a batch size > 2^16, with cuDNN >= V5.1.
// Issue should be resolved for cuDNN > V6.0.
// NB: In cuDNN V7, issue is resolved for 2D convolutionss only.
if
((
cudnnGetVersion
()
<
6100
||
PyGpuArray_NDIM
(
input
)
==
5
)
&&
algo
==
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
&&
PyGpuArray_DIM
(
input
,
0
)
>
65536
)
{
#ifdef DEBUG
fprintf
(
stderr
,
"(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM "
"will fail with batch size > 2^16, fallback to CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM)
\n
"
);
#endif
algo
=
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM
;
}
// The FFT implementation does not support strides, 1x1 filters or inputs
// with a spatial dimension larger than 1024. The tiled-FFT implementation
// does not support strides.
// If the chosen implementation is FFT or tiled-FFT, validate that it can
// be used on the current data and default to a safe implementation if it
// can't.
// The following code is 2d-specific but it is fine as FFT and tiled-FFT are
// defined only for 2d filters
/* NB:
TODO: These checkings seems outdated for FFT algorithms with cuDNN >= 5.1.
New conditions apply and may depend on number of dimensions (2D or 3D)
e.g. for FFT_TILING.
TODO: More globally, how to handle CUDNN_STATUS_NOT_SUPPORTED with unsupported algorithms?
*/
if
((
algo
==
CUDNN_CONVOLUTION_FWD_ALGO_FFT
||
algo
==
CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING
)
&&
PyGpuArray_NDIM
(
input
)
==
4
)
{
// Extract the properties of the convolution descriptor
int
nd
;
int
pad
[
2
];
int
stride
[
2
];
int
dilation
[
2
];
cudnnConvolutionMode_t
mode
;
cudnnDataType_t
data_type
;
err
=
cudnnGetConvolutionNdDescriptor
(
desc
,
2
,
&
nd
,
pad
,
stride
,
dilation
,
&
mode
,
&
data_type
);
if
(
err
!=
CUDNN_STATUS_SUCCESS
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"error getting convolution properties: %s"
,
cudnnGetErrorString
(
err
));
cuda_exit
(
c
->
ctx
);
return
1
;
}
if
(
algo
==
CUDNN_CONVOLUTION_FWD_ALGO_FFT
)
{
if
(
stride
[
0
]
!=
1
||
stride
[
1
]
!=
1
||
PyGpuArray_DIM
(
input
,
2
)
>
1024
||
PyGpuArray_DIM
(
input
,
3
)
>
1024
||
(
PyGpuArray_DIM
(
kerns
,
2
)
==
1
&&
PyGpuArray_DIM
(
kerns
,
3
)
==
1
))
{
algo
=
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM
;
}
}
else
{
// algo == CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING
if
(
stride
[
0
]
!=
1
||
stride
[
1
]
!=
1
)
{
algo
=
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM
;
}
}
}
}
}
}
/* choose_algo */
// if FindEx was used (choose_time), workspace size is set.
// if FindEx was used (choose_time), workspace size is set.
if
(
!
(
reuse_algo
||
use_cached
||
params
->
choose_time
))
if
(
!
(
reuse_algo
||
use_cached
||
params
->
choose_time
))
...
@@ -304,18 +203,16 @@ APPLY_SPECIFIC(conv_fwd)(PyGpuArrayObject *input, PyGpuArrayObject *kerns,
...
@@ -304,18 +203,16 @@ APPLY_SPECIFIC(conv_fwd)(PyGpuArrayObject *input, PyGpuArrayObject *kerns,
APPLY_SPECIFIC
(
output
),
APPLY_SPECIFIC
(
output
),
algo
,
algo
,
&
worksize
);
&
worksize
);
if
(
err
==
CUDNN_STATUS_NOT_SUPPORTED
)
{
if
(
err
==
CUDNN_STATUS_NOT_SUPPORTED
)
{
// Fallback to none algo if not supported
// Fallback to none algo if not supported
#ifdef DEBUG
#ifdef DEBUG
if
(
0
!=
theano_enum_to_string_cudnnConvolutionFwdAlgo_t
(
algo
,
algorithm_name
))
if
(
0
!=
theano_enum_to_string_cudnnConvolutionFwdAlgo_t
(
algo
,
algorithm_name
))
return
1
;
return
1
;
fprintf
(
stderr
,
"(%s error getting worksize: "
fprintf
(
stderr
,
"(%s error getting worksize: "
"fallback to CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM)
\n
"
,
algorithm_name
);
"fallback to CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM)
\n
"
,
algorithm_name
);
#endif
#endif
algo
=
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM
;
algo
=
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM
;
}
err
=
cudnnGetConvolutionForwardWorkspaceSize
(
params
->
handle
,
err
=
cudnnGetConvolutionForwardWorkspaceSize
(
params
->
handle
,
APPLY_SPECIFIC
(
input
),
APPLY_SPECIFIC
(
input
),
...
@@ -339,7 +236,6 @@ APPLY_SPECIFIC(conv_fwd)(PyGpuArrayObject *input, PyGpuArrayObject *kerns,
...
@@ -339,7 +236,6 @@ APPLY_SPECIFIC(conv_fwd)(PyGpuArrayObject *input, PyGpuArrayObject *kerns,
// Add to the cache
// Add to the cache
if
(
params
->
choose_algo
)
if
(
params
->
choose_algo
)
dnn_conv_update_cache
(
hashkey
,
prev_algo
);
dnn_conv_update_cache
(
hashkey
,
prev_algo
);
}
#ifdef DEBUG
#ifdef DEBUG
if
(
params
->
choose_algo
)
{
if
(
params
->
choose_algo
)
{
...
@@ -358,11 +254,6 @@ APPLY_SPECIFIC(conv_fwd)(PyGpuArrayObject *input, PyGpuArrayObject *kerns,
...
@@ -358,11 +254,6 @@ APPLY_SPECIFIC(conv_fwd)(PyGpuArrayObject *input, PyGpuArrayObject *kerns,
if
(
params
->
choose_once
)
{
if
(
params
->
choose_once
)
{
reuse_algo
=
1
;
reuse_algo
=
1
;
}
else
{
for
(
unsigned
int
i
=
0
;
i
<
PyGpuArray_NDIM
(
input
);
i
++
)
{
prev_img_dims
[
i
]
=
PyGpuArray_DIM
(
input
,
i
);
prev_kern_dims
[
i
]
=
PyGpuArray_DIM
(
kerns
,
i
);
}
}
}
{
{
...
...
theano/gpuarray/c_code/dnn_gi.c
浏览文件 @
46e0dbdd
...
@@ -3,17 +3,14 @@ prev_algo.algo = PARAMS->conv_algo;
...
@@ -3,17 +3,14 @@ prev_algo.algo = PARAMS->conv_algo;
prev_algo
.
mathType
=
CUDNN_DEFAULT_MATH
;
prev_algo
.
mathType
=
CUDNN_DEFAULT_MATH
;
prev_algo
.
dataType
=
CUDNN_DATA_FLOAT
;
prev_algo
.
dataType
=
CUDNN_DATA_FLOAT
;
reuse_algo
=
0
;
reuse_algo
=
0
;
memset
(
prev_kern_dims
,
0
,
sizeof
(
prev_kern_dims
));
hash_prefix
=
std
::
string
(
"GI| GPU#"
);
memset
(
prev_top_dims
,
0
,
sizeof
(
prev_top_dims
));
#section support_code_struct
#section support_code_struct
#include "dnn_conv_find.h"
#line 12 "dnn_gi.c"
#line 12 "dnn_gi.c"
int
reuse_algo
;
int
reuse_algo
;
bool
use_cached
;
bool
use_cached
;
AlgoRec
prev_algo
;
AlgoRec
prev_algo
;
size_t
prev_kern_dims
[
5
];
std
::
string
hash_prefix
;
size_t
prev_top_dims
[
5
];
int
int
APPLY_SPECIFIC
(
conv_gi
)(
PyGpuArrayObject
*
kerns
,
PyGpuArrayObject
*
output
,
APPLY_SPECIFIC
(
conv_gi
)(
PyGpuArrayObject
*
kerns
,
PyGpuArrayObject
*
output
,
...
@@ -98,27 +95,14 @@ APPLY_SPECIFIC(conv_gi)(PyGpuArrayObject *kerns, PyGpuArrayObject *output,
...
@@ -98,27 +95,14 @@ APPLY_SPECIFIC(conv_gi)(PyGpuArrayObject *kerns, PyGpuArrayObject *output,
std
::
string
hashkey
;
std
::
string
hashkey
;
if
(
params
->
choose_algo
)
{
if
(
params
->
choose_algo
&&
!
reuse_algo
)
{
if
(
!
params
->
choose_once
)
{
reuse_algo
=
1
;
for
(
unsigned
int
i
=
0
;
i
<
PyGpuArray_NDIM
(
kerns
);
i
++
)
{
reuse_algo
=
(
reuse_algo
&&
PyGpuArray_DIM
(
kerns
,
i
)
==
prev_kern_dims
[
i
]);
reuse_algo
=
(
reuse_algo
&&
PyGpuArray_DIM
(
output
,
i
)
==
prev_top_dims
[
i
]);
}
}
if
(
!
reuse_algo
)
{
char
pci_id
[
16
];
char
pci_id
[
16
];
gpucontext_property
(
c
->
ctx
,
GA_CTX_PROP_PCIBUSID
,
pci_id
);
gpucontext_property
(
c
->
ctx
,
GA_CTX_PROP_PCIBUSID
,
pci_id
);
// check out cache
// check out cache
hashkey
+
=
dnn_conv_shape
(
APPLY_SPECIFIC
(
input
),
*
input
,
APPLY_SPECIFIC
(
kerns
),
kerns
,
desc
,
output
,
groups
);
hashkey
=
dnn_conv_shape
(
APPLY_SPECIFIC
(
input
),
*
input
,
APPLY_SPECIFIC
(
kerns
),
kerns
,
desc
,
output
,
groups
);
if
(
hashkey
.
empty
())
if
(
hashkey
.
empty
())
return
1
;
return
1
;
hashkey
=
std
::
string
(
"GI| GPU#"
)
+
pci_id
+
hashkey
;
hashkey
=
hash_prefix
+
pci_id
+
hashkey
;
const
AlgoRec
*
cached
=
dnn_conv_check_cache
(
hashkey
);
const
AlgoRec
*
cached
=
dnn_conv_check_cache
(
hashkey
);
if
(
cached
)
{
if
(
cached
)
{
prev_algo
=
*
cached
;
prev_algo
=
*
cached
;
...
@@ -126,22 +110,11 @@ APPLY_SPECIFIC(conv_gi)(PyGpuArrayObject *kerns, PyGpuArrayObject *output,
...
@@ -126,22 +110,11 @@ APPLY_SPECIFIC(conv_gi)(PyGpuArrayObject *kerns, PyGpuArrayObject *output,
}
}
}
}
cuda_enter
(
c
->
ctx
);
size_t
free
=
c_get_largest_free_block_size
(
c
);
if
(
!
(
reuse_algo
||
use_cached
))
{
size_t
free
;
int
err2
=
gpucontext_property
(
c
->
ctx
,
GA_CTX_PROP_LARGEST_MEMBLOCK
,
&
free
);
if
(
err2
!=
GA_NO_ERROR
)
{
cuda_enter
(
c
->
ctx
);
PyErr_Format
(
PyExc_RuntimeError
,
"Error when trying to find the "
"memory information on the GPU"
);
cuda_exit
(
c
->
ctx
);
return
1
;
}
// Guess 4Mb if the info is not available
if
(
free
==
0
)
free
=
4
*
1024
*
1024
;
if
(
params
->
choose_algo
&&
!
(
reuse_algo
||
use_cached
))
{
if
(
params
->
choose_time
)
{
if
(
params
->
choose_time
)
{
int
count
;
int
count
;
cudnnConvolutionBwdDataAlgoPerf_t
choice
;
cudnnConvolutionBwdDataAlgoPerf_t
choice
;
...
@@ -170,8 +143,9 @@ APPLY_SPECIFIC(conv_gi)(PyGpuArrayObject *kerns, PyGpuArrayObject *output,
...
@@ -170,8 +143,9 @@ APPLY_SPECIFIC(conv_gi)(PyGpuArrayObject *kerns, PyGpuArrayObject *output,
algo
=
choice
.
algo
;
algo
=
choice
.
algo
;
prev_algo
.
algo
=
(
int
)
algo
;
prev_algo
.
algo
=
(
int
)
algo
;
prev_algo
.
wsSize
=
worksize
=
choice
.
memory
;
prev_algo
.
wsSize
=
worksize
=
choice
.
memory
;
#if CUDNN_MAJOR >= 7
prev_algo
.
mathType
=
mathtype
=
choice
.
mathType
;
prev_algo
.
mathType
=
mathtype
=
choice
.
mathType
;
#endif
// Add to the cache
// Add to the cache
dnn_conv_update_cache
(
hashkey
,
prev_algo
);
dnn_conv_update_cache
(
hashkey
,
prev_algo
);
...
@@ -203,7 +177,18 @@ APPLY_SPECIFIC(conv_gi)(PyGpuArrayObject *kerns, PyGpuArrayObject *output,
...
@@ -203,7 +177,18 @@ APPLY_SPECIFIC(conv_gi)(PyGpuArrayObject *kerns, PyGpuArrayObject *output,
prev_algo
.
mathType
=
mathtype
=
CUDNN_DEFAULT_MATH
;
prev_algo
.
mathType
=
mathtype
=
CUDNN_DEFAULT_MATH
;
}
}
}
}
}
else
{
/*choose_algo */
// if FindEx was used (choose_time), workspace size is set.
if
(
!
(
reuse_algo
||
use_cached
||
params
->
choose_time
))
{
err
=
cudnnGetConvolutionBackwardDataWorkspaceSize
(
params
->
handle
,
APPLY_SPECIFIC
(
kerns
),
APPLY_SPECIFIC
(
output
),
desc
,
APPLY_SPECIFIC
(
input
),
algo
,
&
worksize
);
if
(
err
!=
CUDNN_STATUS_SUCCESS
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"error getting worksize: %s"
,
cudnnGetErrorString
(
err
));
// The FFT implementation does not support strides, 1x1 filters or inputs
// The FFT implementation does not support strides, 1x1 filters or inputs
// with a spatial dimension larger than 1024. The tiled-FFT implementation
// with a spatial dimension larger than 1024. The tiled-FFT implementation
...
@@ -251,19 +236,12 @@ APPLY_SPECIFIC(conv_gi)(PyGpuArrayObject *kerns, PyGpuArrayObject *output,
...
@@ -251,19 +236,12 @@ APPLY_SPECIFIC(conv_gi)(PyGpuArrayObject *kerns, PyGpuArrayObject *output,
}
}
}
}
}
}
}
/* choose_algo */
// if FindEx was used (choose_time), workspace size is set.
if
(
!
(
reuse_algo
||
use_cached
||
params
->
choose_time
))
{
err
=
cudnnGetConvolutionBackwardDataWorkspaceSize
(
err
=
cudnnGetConvolutionBackwardDataWorkspaceSize
(
params
->
handle
,
APPLY_SPECIFIC
(
kerns
),
APPLY_SPECIFIC
(
output
),
desc
,
params
->
handle
,
APPLY_SPECIFIC
(
kerns
),
APPLY_SPECIFIC
(
output
),
desc
,
APPLY_SPECIFIC
(
input
),
algo
,
&
worksize
);
APPLY_SPECIFIC
(
input
),
algo
,
&
worksize
);
}
if
(
err
!=
CUDNN_STATUS_SUCCESS
)
{
if
(
err
!=
CUDNN_STATUS_SUCCESS
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"error getting worksize: %s"
,
cudnnGetErrorString
(
err
));
cuda_exit
(
c
->
ctx
);
cuda_exit
(
c
->
ctx
);
return
1
;
return
1
;
}
}
...
@@ -273,26 +251,26 @@ APPLY_SPECIFIC(conv_gi)(PyGpuArrayObject *kerns, PyGpuArrayObject *output,
...
@@ -273,26 +251,26 @@ APPLY_SPECIFIC(conv_gi)(PyGpuArrayObject *kerns, PyGpuArrayObject *output,
// Add to the cache
// Add to the cache
if
(
params
->
choose_algo
)
if
(
params
->
choose_algo
)
dnn_conv_update_cache
(
hashkey
,
prev_algo
);
dnn_conv_update_cache
(
hashkey
,
prev_algo
);
}
}
// !(reuse_algo || use_cached || params->choose_time)
#ifdef DEBUG
#ifdef DEBUG
char
algorithm_name
[
128
];
if
(
params
->
choose_algo
)
{
if
(
0
!=
theano_enum_to_string_cudnnConvolutionBwdDataAlgo_t
(
algo
,
algorithm_name
))
if
(
0
!=
theano_enum_to_string_cudnnConvolutionBwdDataAlgo_t
(
algo
,
algorithm_name
))
return
1
;
return
1
;
// NB: This is printed only when algorithm is chosen at runtime.
// NB: This is printed only when algorithm is chosen at runtime.
if
(
reuse_algo
)
fprintf
(
stderr
,
"%s%s algo: %d %s%s ws: %ld, tensor: %d hash:%s
\n
"
,
fprintf
(
stderr
,
"(reused %s)
\n
"
,
algorithm_name
);
params
->
choose_algo
?
"[A]"
:
""
,
else
params
->
choose_time
?
"[T]"
:
""
,
fprintf
(
stderr
,
"(using %s)
\n
"
,
algorithm_name
);
algo
,
// algorithm_name,
#endif
reuse_algo
?
"(reused)"
:
""
,
use_cached
?
"(cache)"
:
""
,
worksize
,
mathtype
,
hashkey
.
c_str
()
);
}
#endif
if
(
params
->
choose_once
)
{
if
(
params
->
choose_once
)
{
reuse_algo
=
1
;
reuse_algo
=
1
;
}
else
{
for
(
unsigned
int
i
=
0
;
i
<
PyGpuArray_NDIM
(
kerns
);
i
++
)
{
prev_kern_dims
[
i
]
=
PyGpuArray_DIM
(
kerns
,
i
);
prev_top_dims
[
i
]
=
PyGpuArray_DIM
(
output
,
i
);
}
}
}
gpudata
*
workspace
=
0
;
gpudata
*
workspace
=
0
;
...
...
theano/gpuarray/c_code/dnn_gw.c
浏览文件 @
46e0dbdd
...
@@ -3,17 +3,15 @@ prev_algo.algo = PARAMS->conv_algo;
...
@@ -3,17 +3,15 @@ prev_algo.algo = PARAMS->conv_algo;
prev_algo
.
mathType
=
CUDNN_DEFAULT_MATH
;
prev_algo
.
mathType
=
CUDNN_DEFAULT_MATH
;
prev_algo
.
dataType
=
CUDNN_DATA_FLOAT
;
prev_algo
.
dataType
=
CUDNN_DATA_FLOAT
;
reuse_algo
=
0
;
reuse_algo
=
0
;
memset
(
prev_img_dims
,
0
,
sizeof
(
prev_img_dims
));
hash_prefix
=
std
::
string
(
"GW| GPU#"
);
memset
(
prev_top_dims
,
0
,
sizeof
(
prev_top_dims
));
#section support_code_struct
#section support_code_struct
#line 11 "dnn_gw.c"
#line 11 "dnn_gw.c"
#include "dnn_conv_find.h"
int
reuse_algo
;
int
reuse_algo
;
bool
use_cached
;
bool
use_cached
;
AlgoRec
prev_algo
;
AlgoRec
prev_algo
;
size_t
prev_img_dims
[
5
];
std
::
string
hash_prefix
;
size_t
prev_top_dims
[
5
];
int
int
APPLY_SPECIFIC
(
conv_gw
)(
PyGpuArrayObject
*
input
,
PyGpuArrayObject
*
output
,
APPLY_SPECIFIC
(
conv_gw
)(
PyGpuArrayObject
*
input
,
PyGpuArrayObject
*
output
,
...
@@ -95,20 +93,13 @@ APPLY_SPECIFIC(conv_gw)(PyGpuArrayObject *input, PyGpuArrayObject *output,
...
@@ -95,20 +93,13 @@ APPLY_SPECIFIC(conv_gw)(PyGpuArrayObject *input, PyGpuArrayObject *output,
#endif
#endif
size_t
worksize
=
0
;
size_t
worksize
=
0
;
cudnnMathType_t
mathtype
=
CUDNN_DEFAULT_MATH
;
cudnnMathType_t
mathtype
=
CUDNN_DEFAULT_MATH
;
std
::
string
hashkey
;
std
::
string
hashkey
;
size_t
free
=
c_get_largest_free_block_size
(
c
);
cuda_enter
(
c
->
ctx
);
cuda_enter
(
c
->
ctx
);
if
(
params
->
choose_algo
)
{
if
(
params
->
choose_algo
)
{
if
(
!
params
->
choose_once
)
{
reuse_algo
=
1
;
for
(
unsigned
int
i
=
0
;
i
<
PyGpuArray_NDIM
(
input
);
i
++
)
{
reuse_algo
=
(
reuse_algo
&&
PyGpuArray_DIM
(
input
,
i
)
==
prev_img_dims
[
i
]);
reuse_algo
=
(
reuse_algo
&&
PyGpuArray_DIM
(
output
,
i
)
==
prev_top_dims
[
i
]);
}
}
if
(
!
reuse_algo
)
{
if
(
!
reuse_algo
)
{
char
pci_id
[
16
];
char
pci_id
[
16
];
...
@@ -116,7 +107,7 @@ APPLY_SPECIFIC(conv_gw)(PyGpuArrayObject *input, PyGpuArrayObject *output,
...
@@ -116,7 +107,7 @@ APPLY_SPECIFIC(conv_gw)(PyGpuArrayObject *input, PyGpuArrayObject *output,
hashkey
=
dnn_conv_shape
(
APPLY_SPECIFIC
(
input
),
input
,
APPLY_SPECIFIC
(
kerns
),
*
kerns
,
desc
,
output
,
groups
);
hashkey
=
dnn_conv_shape
(
APPLY_SPECIFIC
(
input
),
input
,
APPLY_SPECIFIC
(
kerns
),
*
kerns
,
desc
,
output
,
groups
);
if
(
hashkey
.
empty
())
if
(
hashkey
.
empty
())
return
1
;
return
1
;
hashkey
=
std
::
string
(
"GW| GPU#"
)
+
pci_id
+
hashkey
;
hashkey
=
hash_prefix
+
pci_id
+
hashkey
;
// check out cache
// check out cache
const
AlgoRec
*
cached
=
dnn_conv_check_cache
(
hashkey
);
const
AlgoRec
*
cached
=
dnn_conv_check_cache
(
hashkey
);
if
(
cached
)
{
if
(
cached
)
{
...
@@ -125,20 +116,11 @@ APPLY_SPECIFIC(conv_gw)(PyGpuArrayObject *input, PyGpuArrayObject *output,
...
@@ -125,20 +116,11 @@ APPLY_SPECIFIC(conv_gw)(PyGpuArrayObject *input, PyGpuArrayObject *output,
}
}
}
}
if
(
!
(
reuse_algo
||
use_cached
))
{
if
(
reuse_algo
||
use_cached
)
{
algo
=
(
cudnnConvolutionBwdFilterAlgo_t
)
prev_algo
.
algo
;
size_t
free
;
worksize
=
prev_algo
.
wsSize
;
mathtype
=
prev_algo
.
mathType
;
int
err2
=
gpucontext_property
(
c
->
ctx
,
GA_CTX_PROP_LARGEST_MEMBLOCK
,
&
free
);
}
else
{
if
(
err2
!=
GA_NO_ERROR
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"Error when trying to find the "
"memory information on the GPU"
);
cuda_exit
(
c
->
ctx
);
return
1
;
}
// Guess 4Mb if the info is not available
if
(
free
==
0
)
free
=
4
*
1024
*
1024
;
if
(
params
->
choose_time
)
{
if
(
params
->
choose_time
)
{
int
count
;
int
count
;
...
@@ -169,8 +151,9 @@ APPLY_SPECIFIC(conv_gw)(PyGpuArrayObject *input, PyGpuArrayObject *output,
...
@@ -169,8 +151,9 @@ APPLY_SPECIFIC(conv_gw)(PyGpuArrayObject *input, PyGpuArrayObject *output,
algo
=
choice
.
algo
;
algo
=
choice
.
algo
;
prev_algo
.
algo
=
(
int
)
algo
;
prev_algo
.
algo
=
(
int
)
algo
;
prev_algo
.
wsSize
=
worksize
=
choice
.
memory
;
prev_algo
.
wsSize
=
worksize
=
choice
.
memory
;
#if CUDNN_MAJOR >= 7
prev_algo
.
mathType
=
mathtype
=
choice
.
mathType
;
prev_algo
.
mathType
=
mathtype
=
choice
.
mathType
;
#endif
// Add to the cache
// Add to the cache
dnn_conv_update_cache
(
hashkey
,
prev_algo
);
dnn_conv_update_cache
(
hashkey
,
prev_algo
);
...
@@ -202,45 +185,8 @@ APPLY_SPECIFIC(conv_gw)(PyGpuArrayObject *input, PyGpuArrayObject *output,
...
@@ -202,45 +185,8 @@ APPLY_SPECIFIC(conv_gw)(PyGpuArrayObject *input, PyGpuArrayObject *output,
// no tensor_op returned from Get()
// no tensor_op returned from Get()
prev_algo
.
mathType
=
mathtype
=
CUDNN_DEFAULT_MATH
;
prev_algo
.
mathType
=
mathtype
=
CUDNN_DEFAULT_MATH
;
}
}
}
else
{
algo
=
(
cudnnConvolutionBwdFilterAlgo_t
)
prev_algo
.
algo
;
worksize
=
prev_algo
.
wsSize
;
mathtype
=
prev_algo
.
mathType
;
}
}
else
{
// The FFT implementation does not support strides, 1x1 filters or inputs
// with a spatial dimension larger than 1024.
// If the chosen implementation is FFT, validate that it can
// be used on the current data and default to a safe implementation if it
// can't.
// The following code is 2d-specific but it is fine as FFT and tiled-FFT are
// defined only for 2d filters
if
(
algo
==
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT
&&
PyGpuArray_NDIM
(
input
)
==
4
)
{
// Extract the properties of the convolution descriptor
int
nd
;
int
pad
[
2
];
int
stride
[
2
];
int
upscale
[
2
];
cudnnConvolutionMode_t
mode
;
cudnnDataType_t
data_type
;
err
=
cudnnGetConvolutionNdDescriptor
(
desc
,
2
,
&
nd
,
pad
,
stride
,
upscale
,
&
mode
,
&
data_type
);
if
(
err
!=
CUDNN_STATUS_SUCCESS
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"error getting convolution properties: %s"
,
cudnnGetErrorString
(
err
));
cuda_exit
(
c
->
ctx
);
return
1
;
}
if
(
stride
[
0
]
!=
1
||
stride
[
1
]
!=
1
||
PyGpuArray_DIM
(
input
,
2
)
>
1024
||
PyGpuArray_DIM
(
input
,
3
)
>
1024
||
(
PyGpuArray_DIM
(
*
kerns
,
2
)
==
1
&&
PyGpuArray_DIM
(
*
kerns
,
3
)
==
1
))
{
algo
=
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0
;
}
}
}
}
/* choose_algo */
}
/* choose_algo */
// if FindEx was used (choose_time), workspace size is set.
// if FindEx was used (choose_time), workspace size is set.
if
(
!
(
reuse_algo
||
use_cached
||
params
->
choose_time
))
if
(
!
(
reuse_algo
||
use_cached
||
params
->
choose_time
))
...
@@ -251,11 +197,27 @@ APPLY_SPECIFIC(conv_gw)(PyGpuArrayObject *input, PyGpuArrayObject *output,
...
@@ -251,11 +197,27 @@ APPLY_SPECIFIC(conv_gw)(PyGpuArrayObject *input, PyGpuArrayObject *output,
APPLY_SPECIFIC
(
kerns
),
algo
,
&
worksize
);
APPLY_SPECIFIC
(
kerns
),
algo
,
&
worksize
);
if
(
err
!=
CUDNN_STATUS_SUCCESS
)
{
if
(
err
!=
CUDNN_STATUS_SUCCESS
)
{
#ifdef DEBUG
if
(
0
!=
theano_enum_to_string_cudnnConvolutionBwdFilterAlgo_t
(
algo
,
algorithm_name
))
return
1
;
fprintf
(
stderr
,
"(%s error getting worksize:%s, falling back to CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0"
,
algorithm_name
,
cudnnGetErrorString
(
err
));
#endif
algo
=
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0
;
err
=
cudnnGetConvolutionBackwardFilterWorkspaceSize
(
params
->
handle
,
APPLY_SPECIFIC
(
input
),
APPLY_SPECIFIC
(
output
),
desc
,
APPLY_SPECIFIC
(
kerns
),
algo
,
&
worksize
);
if
(
err
!=
CUDNN_STATUS_SUCCESS
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"error getting worksize: %s"
,
PyErr_Format
(
PyExc_RuntimeError
,
"error getting worksize: %s"
,
cudnnGetErrorString
(
err
));
cudnnGetErrorString
(
err
));
cuda_exit
(
c
->
ctx
);
cuda_exit
(
c
->
ctx
);
return
1
;
return
1
;
}
}
}
// save worksize for next time/cache
// save worksize for next time/cache
prev_algo
.
wsSize
=
worksize
;
prev_algo
.
wsSize
=
worksize
;
...
@@ -265,6 +227,7 @@ APPLY_SPECIFIC(conv_gw)(PyGpuArrayObject *input, PyGpuArrayObject *output,
...
@@ -265,6 +227,7 @@ APPLY_SPECIFIC(conv_gw)(PyGpuArrayObject *input, PyGpuArrayObject *output,
}
}
#ifdef DEBUG
#ifdef DEBUG
if
(
params
->
choose_algo
)
{
if
(
0
!=
theano_enum_to_string_cudnnConvolutionBwdFilterAlgo_t
(
algo
,
algorithm_name
))
if
(
0
!=
theano_enum_to_string_cudnnConvolutionBwdFilterAlgo_t
(
algo
,
algorithm_name
))
return
1
;
return
1
;
// NB: This is printed only when algorithm is chosen at runtime.
// NB: This is printed only when algorithm is chosen at runtime.
...
@@ -276,15 +239,11 @@ APPLY_SPECIFIC(conv_gw)(PyGpuArrayObject *input, PyGpuArrayObject *output,
...
@@ -276,15 +239,11 @@ APPLY_SPECIFIC(conv_gw)(PyGpuArrayObject *input, PyGpuArrayObject *output,
use_cached
?
"(cache)"
:
""
,
use_cached
?
"(cache)"
:
""
,
worksize
,
mathtype
,
hashkey
.
c_str
()
worksize
,
mathtype
,
hashkey
.
c_str
()
);
);
}
#endif
#endif
if
(
params
->
choose_once
)
{
if
(
params
->
choose_once
)
{
reuse_algo
=
1
;
reuse_algo
=
1
;
}
else
{
for
(
unsigned
int
i
=
0
;
i
<
PyGpuArray_NDIM
(
input
);
i
++
)
{
prev_img_dims
[
i
]
=
PyGpuArray_DIM
(
input
,
i
);
prev_top_dims
[
i
]
=
PyGpuArray_DIM
(
output
,
i
);
}
}
}
gpudata
*
workspace
=
0
;
gpudata
*
workspace
=
0
;
...
...
theano/gpuarray/dnn.py
浏览文件 @
46e0dbdd
...
@@ -567,7 +567,7 @@ class GpuDnnConv(DnnBase):
...
@@ -567,7 +567,7 @@ class GpuDnnConv(DnnBase):
num_groups
=
int_t
)
num_groups
=
int_t
)
def
__init__
(
self
,
algo
=
None
,
inplace
=
False
,
num_groups
=
1
):
def
__init__
(
self
,
algo
=
None
,
inplace
=
False
,
num_groups
=
1
):
DnnBase
.
__init__
(
self
,
[
"c_code/dnn_conv_base.c"
,
"c_code/dnn_
conv_find.c"
,
"c_code/dnn_
fwd.c"
],
DnnBase
.
__init__
(
self
,
[
"c_code/dnn_conv_base.c"
,
"c_code/dnn_fwd.c"
],
"APPLY_SPECIFIC(conv_fwd)"
)
"APPLY_SPECIFIC(conv_fwd)"
)
if
algo
is
None
:
if
algo
is
None
:
...
@@ -710,7 +710,7 @@ class GpuDnnConvGradW(DnnBase):
...
@@ -710,7 +710,7 @@ class GpuDnnConvGradW(DnnBase):
num_groups
=
int_t
)
num_groups
=
int_t
)
def
__init__
(
self
,
inplace
=
False
,
algo
=
None
,
num_groups
=
1
):
def
__init__
(
self
,
inplace
=
False
,
algo
=
None
,
num_groups
=
1
):
DnnBase
.
__init__
(
self
,
[
"c_code/dnn_conv_base.c"
,
"c_code/dnn_
conv_find.c"
,
"c_code/dnn_
gw.c"
],
DnnBase
.
__init__
(
self
,
[
"c_code/dnn_conv_base.c"
,
"c_code/dnn_gw.c"
],
"APPLY_SPECIFIC(conv_gw)"
)
"APPLY_SPECIFIC(conv_gw)"
)
self
.
inplace
=
bool
(
inplace
)
self
.
inplace
=
bool
(
inplace
)
if
self
.
inplace
:
if
self
.
inplace
:
...
@@ -846,7 +846,7 @@ class GpuDnnConvGradI(DnnBase):
...
@@ -846,7 +846,7 @@ class GpuDnnConvGradI(DnnBase):
num_groups
=
int_t
)
num_groups
=
int_t
)
def
__init__
(
self
,
inplace
=
False
,
algo
=
None
,
num_groups
=
1
):
def
__init__
(
self
,
inplace
=
False
,
algo
=
None
,
num_groups
=
1
):
DnnBase
.
__init__
(
self
,
[
"c_code/dnn_conv_base.c"
,
"c_code/dnn_
conv_find.c"
,
"c_code/dnn_
gi.c"
],
DnnBase
.
__init__
(
self
,
[
"c_code/dnn_conv_base.c"
,
"c_code/dnn_gi.c"
],
"APPLY_SPECIFIC(conv_gi)"
)
"APPLY_SPECIFIC(conv_gi)"
)
self
.
inplace
=
bool
(
inplace
)
self
.
inplace
=
bool
(
inplace
)
if
self
.
inplace
:
if
self
.
inplace
:
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论