Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
8ee47394
提交
8ee47394
authored
7月 28, 2017
作者:
notoraptor
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Allow to concatenate conv case generators.
Update cases for gradweight ff_tiling. Allow DOUBLE_CONFIG for all fft_tiling computations in cuDNN V6.
上级
ac21919b
显示空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
63 行增加
和
20 行删除
+63
-20
cudnn_defs.py
theano/gpuarray/cudnn_defs.py
+32
-8
check_dnn_conv.py
theano/gpuarray/tests/check_dnn_conv.py
+29
-10
run_dnn_conv.py
theano/gpuarray/tests/run_dnn_conv.py
+2
-2
没有找到文件。
theano/gpuarray/cudnn_defs.py
浏览文件 @
8ee47394
...
@@ -165,11 +165,6 @@ class CuDNNV51(object):
...
@@ -165,11 +165,6 @@ class CuDNNV51(object):
if
algo
==
algorithms
.
CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING
:
if
algo
==
algorithms
.
CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING
:
if
ndim
==
2
:
if
ndim
==
2
:
return
is_pseudo_half_config
(
dtype
,
precision
)
or
is_float_config
(
dtype
,
precision
)
return
is_pseudo_half_config
(
dtype
,
precision
)
or
is_float_config
(
dtype
,
precision
)
# NB: For cuDNN V6:
# " Data Type Config Support: PSEUDO_HALF_CONFIG, FLOAT_CONFIG
# (DOUBLE_CONFIG is also supported when the task can be handled by 1D FFT,
# ie, one of the filter dimension, width or height is 1)"
# Could be checked only when being in C code.
if
ndim
==
3
:
if
ndim
==
3
:
return
not
is_true_half_config
(
dtype
,
precision
)
return
not
is_true_half_config
(
dtype
,
precision
)
if
algo
==
algorithms
.
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD
:
if
algo
==
algorithms
.
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD
:
...
@@ -210,9 +205,6 @@ class CuDNNV51(object):
...
@@ -210,9 +205,6 @@ class CuDNNV51(object):
if
algo
==
algorithms
.
CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING
:
if
algo
==
algorithms
.
CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING
:
if
ndim
==
2
:
if
ndim
==
2
:
return
is_pseudo_half_config
(
dtype
,
precision
)
or
is_float_config
(
dtype
,
precision
)
return
is_pseudo_half_config
(
dtype
,
precision
)
or
is_float_config
(
dtype
,
precision
)
# NB: For cuDNN V6: "(DOUBLE_CONFIG is also supported when the task can be handled by 1D FFT,
# ie, one of the filter dimension, width or height is 1)"
# Could be checked only when being in C code.
if
ndim
==
3
:
if
ndim
==
3
:
return
not
is_true_half_config
(
dtype
,
precision
)
return
not
is_true_half_config
(
dtype
,
precision
)
if
algo
==
algorithms
.
CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD
:
if
algo
==
algorithms
.
CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD
:
...
@@ -265,6 +257,22 @@ class CuDNNV6(CuDNNV51):
...
@@ -265,6 +257,22 @@ class CuDNNV6(CuDNNV51):
(
'CUDNN_REDUCE_TENSOR_NORM2'
,
'norm2'
),
(
'CUDNN_REDUCE_TENSOR_NORM2'
,
'norm2'
),
ctype
=
'cudnnReduceTensorOp_t'
)
ctype
=
'cudnnReduceTensorOp_t'
)
def
fwd_algo_supports_dtype_config
(
self
,
algo
,
dtype
,
precision
,
ndim
):
is_supported
=
super
(
CuDNNV6
,
self
)
.
fwd_algo_supports_dtype_config
(
algo
,
dtype
,
precision
,
ndim
)
if
not
is_supported
:
algorithms
=
self
.
cudnnConvolutionFwdAlgo_t
algo
=
algorithms
.
fromalias
(
algo
)
if
algo
==
algorithms
.
CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING
:
# NB: For cuDNN V6:
# "Data Type Config Support: PSEUDO_HALF_CONFIG, FLOAT_CONFIG
# (DOUBLE_CONFIG is also supported when the task can be handled by 1D FFT,
# ie, one of the filter dimension, width or height is 1)"
# Could be checked only in C code. By default, let's allow DOUBLE_CONFIG.
return
ndim
==
2
and
(
is_pseudo_half_config
(
dtype
,
precision
)
or
is_float_config
(
dtype
,
precision
)
or
is_double_config
(
dtype
,
precision
))
return
is_supported
def
bwd_filter_algo_supports_dtype_config
(
self
,
algo
,
dtype
,
precision
,
ndim
):
def
bwd_filter_algo_supports_dtype_config
(
self
,
algo
,
dtype
,
precision
,
ndim
):
is_supported
=
super
(
CuDNNV6
,
self
)
.
bwd_filter_algo_supports_dtype_config
(
algo
,
dtype
,
precision
,
ndim
)
is_supported
=
super
(
CuDNNV6
,
self
)
.
bwd_filter_algo_supports_dtype_config
(
algo
,
dtype
,
precision
,
ndim
)
if
not
is_supported
:
if
not
is_supported
:
...
@@ -276,6 +284,22 @@ class CuDNNV6(CuDNNV51):
...
@@ -276,6 +284,22 @@ class CuDNNV6(CuDNNV51):
is_double_config
(
dtype
,
precision
))
is_double_config
(
dtype
,
precision
))
return
is_supported
return
is_supported
def
bwd_data_algo_supports_dtype_config
(
self
,
algo
,
dtype
,
precision
,
ndim
):
is_supported
=
super
(
CuDNNV6
,
self
)
.
bwd_data_algo_supports_dtype_config
(
algo
,
dtype
,
precision
,
ndim
)
if
not
is_supported
:
algorithms
=
self
.
cudnnConvolutionBwdDataAlgo_t
algo
=
algorithms
.
fromalias
(
algo
)
if
algo
==
algorithms
.
CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING
:
# NB: For cuDNN V6:
# "Data Type Config Support: PSEUDO_HALF_CONFIG, FLOAT_CONFIG
# (DOUBLE_CONFIG is also supported when the task can be handled by 1D FFT,
# ie, one of the filter dimension, width or height is 1)"
# Could be checked only in C code. By default, let's allow DOUBLE_CONFIG.
return
ndim
==
2
and
(
is_pseudo_half_config
(
dtype
,
precision
)
or
is_float_config
(
dtype
,
precision
)
or
is_double_config
(
dtype
,
precision
))
return
is_supported
class
CuDNNV7
(
CuDNNV6
):
class
CuDNNV7
(
CuDNNV6
):
version
=
7
version
=
7
...
...
theano/gpuarray/tests/check_dnn_conv.py
浏览文件 @
8ee47394
...
@@ -16,7 +16,7 @@
...
@@ -16,7 +16,7 @@
from
__future__
import
absolute_import
,
print_function
,
division
from
__future__
import
absolute_import
,
print_function
,
division
import
sys
import
sys
from
itertools
import
product
from
itertools
import
product
,
chain
import
nose
import
nose
import
numpy
as
np
import
numpy
as
np
...
@@ -316,6 +316,19 @@ class ConvCaseGenerator:
...
@@ -316,6 +316,19 @@ class ConvCaseGenerator:
all_border_modes
,
all_conv_modes
,
all_alphas
,
all_betas
))
all_border_modes
,
all_conv_modes
,
all_alphas
,
all_betas
))
class
ConvCaseGeneratorChain
:
"""
Help class to concatenate many conv case generators.
"""
def
__init__
(
self
,
*
conv_case_generators
):
assert
all
(
isinstance
(
g
,
ConvCaseGenerator
)
for
g
in
conv_case_generators
)
self
.
generators
=
conv_case_generators
def
get_cases
(
self
,
filter
=
None
):
return
chain
(
*
[
generator
.
get_cases
(
filter
)
for
generator
in
self
.
generators
])
class
CuDNNV51ConvCaseGenerator
(
object
):
class
CuDNNV51ConvCaseGenerator
(
object
):
"""
"""
Helper class to generate specific test cases for every algorithm supported by cuDNN V5.1.
Helper class to generate specific test cases for every algorithm supported by cuDNN V5.1.
...
@@ -430,14 +443,18 @@ class CuDNNV6ConvCaseGenerator(CuDNNV51ConvCaseGenerator):
...
@@ -430,14 +443,18 @@ class CuDNNV6ConvCaseGenerator(CuDNNV51ConvCaseGenerator):
def
_fwd_fft_tiling
(
self
,
ndim
):
def
_fwd_fft_tiling
(
self
,
ndim
):
if
ndim
==
2
:
if
ndim
==
2
:
filters_sizes
=
[(
32
,
5
),
(
256
,
1
),
(
10
,
10
),
(
5
,
1
)]
subsamples
=
[(
1
,
1
)]
subsamples
=
[(
1
,
1
)]
borders
=
[(
1
,
1
),
(
2
,
1
)]
# wDesc's filter height must be greater than convDesc's zero-padding height
return
ConvCaseGenerator
(
ndim
=
ndim
,
# wDesc's filter width must be greater than convDesc's zero-padding width
filters_sizes
=
filters_sizes
,
filters_sizes
=
[(
32
,
5
),
(
10
,
10
)]
subsamples
=
subsamples
,
borders
=
[(
1
,
1
),
(
6
,
4
)]
borders
=
borders
,
generator1
=
ConvCaseGenerator
(
ndim
=
ndim
,
dilations
=
self
.
_dilations
(
ndim
),
subsamples
=
subsamples
,
dilations
=
self
.
_dilations
(
ndim
))
filters_sizes
=
filters_sizes
,
borders
=
borders
)
filters_sizes
=
[(
256
,
1
),
(
5
,
1
)]
borders
=
[(
1
,
0
),
(
2
,
0
)]
generator2
=
ConvCaseGenerator
(
ndim
=
ndim
,
dilations
=
self
.
_dilations
(
ndim
),
subsamples
=
subsamples
,
filters_sizes
=
filters_sizes
,
borders
=
borders
)
return
ConvCaseGeneratorChain
(
generator1
,
generator2
)
if
ndim
==
3
:
if
ndim
==
3
:
return
super
(
CuDNNV6ConvCaseGenerator
,
self
)
.
_fwd_fft_tiling
(
ndim
)
return
super
(
CuDNNV6ConvCaseGenerator
,
self
)
.
_fwd_fft_tiling
(
ndim
)
...
@@ -445,10 +462,10 @@ class CuDNNV6ConvCaseGenerator(CuDNNV51ConvCaseGenerator):
...
@@ -445,10 +462,10 @@ class CuDNNV6ConvCaseGenerator(CuDNNV51ConvCaseGenerator):
return
self
.
_fwd_none
(
ndim
)
return
self
.
_fwd_none
(
ndim
)
def
_gw_fft_tiling
(
self
,
ndim
):
def
_gw_fft_tiling
(
self
,
ndim
):
inputs_sizes
=
[(
2
56
,
1
),
(
20
,
1
)]
inputs_sizes
=
[(
2
47
,
1
),
(
20
,
1
)]
filters_sizes
=
[(
3
,
1
),
(
10
,
1
)]
filters_sizes
=
[(
3
,
1
),
(
10
,
1
)]
subsamples
=
[(
1
,)
*
ndim
]
subsamples
=
[(
1
,)
*
ndim
]
borders
=
[(
1
,
1
),
(
2
,
1
)]
borders
=
[(
1
,
0
),
(
2
,
0
)]
return
ConvCaseGenerator
(
ndim
=
ndim
,
return
ConvCaseGenerator
(
ndim
=
ndim
,
inputs_sizes
=
inputs_sizes
,
inputs_sizes
=
inputs_sizes
,
filters_sizes
=
filters_sizes
,
filters_sizes
=
filters_sizes
,
...
@@ -467,6 +484,8 @@ class CuDNNV6ConvCaseGenerator(CuDNNV51ConvCaseGenerator):
...
@@ -467,6 +484,8 @@ class CuDNNV6ConvCaseGenerator(CuDNNV51ConvCaseGenerator):
def
gw
(
self
,
algo
,
ndim
):
def
gw
(
self
,
algo
,
ndim
):
if
algo
==
self
.
NONE
:
if
algo
==
self
.
NONE
:
return
self
.
_gw_none
(
ndim
)
return
self
.
_gw_none
(
ndim
)
if
algo
==
self
.
FFT_TILING
:
return
self
.
_gw_fft_tiling
(
ndim
)
return
super
(
CuDNNV6ConvCaseGenerator
,
self
)
.
gw
(
algo
,
ndim
)
return
super
(
CuDNNV6ConvCaseGenerator
,
self
)
.
gw
(
algo
,
ndim
)
def
gi
(
self
,
algo
,
ndim
):
def
gi
(
self
,
algo
,
ndim
):
...
...
theano/gpuarray/tests/run_dnn_conv.py
浏览文件 @
8ee47394
...
@@ -126,8 +126,8 @@ if args.algo not in SUPPORTED_DNN_CONV_ALGO_RUNTIME:
...
@@ -126,8 +126,8 @@ if args.algo not in SUPPORTED_DNN_CONV_ALGO_RUNTIME:
if
test
==
BWD_DATA
:
if
test
==
BWD_DATA
:
check_config
=
cudnn
.
bwd_data_algo_supports_dtype_config
(
args
.
algo
,
args
.
dtype
,
args
.
precision
,
ndim
)
check_config
=
cudnn
.
bwd_data_algo_supports_dtype_config
(
args
.
algo
,
args
.
dtype
,
args
.
precision
,
ndim
)
if
not
check_config
:
if
not
check_config
:
raise
ValueError
(
'
%
s computation does not
support configuration (
%
s,
%
s) for algo
%
s.'
%
(
print
(
'Warning:
%
s computation does not normally
support configuration (
%
s,
%
s) for algo
%
s.'
%
(
test
,
args
.
dtype
,
args
.
precision
,
args
.
algo
))
test
,
args
.
dtype
,
args
.
precision
,
args
.
algo
)
,
file
=
sys
.
stderr
)
algo
=
args
.
algo
algo
=
args
.
algo
dtype
=
args
.
dtype
dtype
=
args
.
dtype
precision
=
args
.
precision
precision
=
args
.
precision
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论