提交 a346913f authored 作者: notoraptor's avatar notoraptor

Skip runtime algos when there is no algo that supports

given data type configuration for given ndim.
上级 60fa63e1
...@@ -726,11 +726,14 @@ class BaseTestDnnConv(object): ...@@ -726,11 +726,14 @@ class BaseTestDnnConv(object):
def test_fwd(self): def test_fwd(self):
for dtype, precision in self.dtype_configs: for dtype, precision in self.dtype_configs:
algos = (algo for algo in self.fwd_algorithms algos = [algo for algo in self.fwd_algorithms
if cudnn.fwd_algo_supports_dtype_config(algo, dtype, precision, self.ndim)) if cudnn.fwd_algo_supports_dtype_config(algo, dtype, precision, self.ndim)]
for algo in algos: for algo in algos:
for parameters in cudnn_conv_case_generator.fwd(algo, self.ndim, dtype, precision).get_cases(): for parameters in cudnn_conv_case_generator.fwd(algo, self.ndim, dtype, precision).get_cases():
yield (self.run_conv_fwd, algo, dtype, precision, parameters) yield (self.run_conv_fwd, algo, dtype, precision, parameters)
if algos:
# Some algorithms support current data type configuration for current ndim.
# So, an algorithm can be chosen at runtime.
for algo in SUPPORTED_DNN_CONV_ALGO_RUNTIME: for algo in SUPPORTED_DNN_CONV_ALGO_RUNTIME:
for parameters in self.get_cases(): for parameters in self.get_cases():
yield (self.run_conv_fwd, algo, dtype, precision, parameters) yield (self.run_conv_fwd, algo, dtype, precision, parameters)
...@@ -743,11 +746,14 @@ class BaseTestDnnConv(object): ...@@ -743,11 +746,14 @@ class BaseTestDnnConv(object):
def test_gradinput(self): def test_gradinput(self):
for dtype, precision in self.dtype_configs: for dtype, precision in self.dtype_configs:
algos = (algo for algo in self.bwd_data_algorithms algos = [algo for algo in self.bwd_data_algorithms
if cudnn.bwd_data_algo_supports_dtype_config(algo, dtype, precision, self.ndim)) if cudnn.bwd_data_algo_supports_dtype_config(algo, dtype, precision, self.ndim)]
for algo in algos: for algo in algos:
for parameters in cudnn_conv_case_generator.gi(algo, self.ndim, dtype, precision).get_cases(): for parameters in cudnn_conv_case_generator.gi(algo, self.ndim, dtype, precision).get_cases():
yield (self.run_conv_gradinput, algo, dtype, precision, parameters) yield (self.run_conv_gradinput, algo, dtype, precision, parameters)
if algos:
# Some algorithms support current data type configuration for current ndim.
# So, an algorithm can be chosen at runtime.
for algo in SUPPORTED_DNN_CONV_ALGO_RUNTIME: for algo in SUPPORTED_DNN_CONV_ALGO_RUNTIME:
for parameters in self.get_cases(): for parameters in self.get_cases():
yield (self.run_conv_gradinput, algo, dtype, precision, parameters) yield (self.run_conv_gradinput, algo, dtype, precision, parameters)
...@@ -760,11 +766,14 @@ class BaseTestDnnConv(object): ...@@ -760,11 +766,14 @@ class BaseTestDnnConv(object):
def test_gradweight(self): def test_gradweight(self):
for dtype, precision in self.dtype_configs: for dtype, precision in self.dtype_configs:
algos = (algo for algo in self.bwd_filter_algorithms algos = [algo for algo in self.bwd_filter_algorithms
if cudnn.bwd_filter_algo_supports_dtype_config(algo, dtype, precision, self.ndim)) if cudnn.bwd_filter_algo_supports_dtype_config(algo, dtype, precision, self.ndim)]
for algo in algos: for algo in algos:
for parameters in cudnn_conv_case_generator.gw(algo, self.ndim, dtype, precision).get_cases(): for parameters in cudnn_conv_case_generator.gw(algo, self.ndim, dtype, precision).get_cases():
yield (self.run_conv_gradweight, algo, dtype, precision, parameters) yield (self.run_conv_gradweight, algo, dtype, precision, parameters)
if algos:
# Some algorithms support current data type configuration for current ndim.
# So, an algorithm can be chosen at runtime.
for algo in SUPPORTED_DNN_CONV_ALGO_RUNTIME: for algo in SUPPORTED_DNN_CONV_ALGO_RUNTIME:
for parameters in self.get_cases(): for parameters in self.get_cases():
yield (self.run_conv_gradweight, algo, dtype, precision, parameters) yield (self.run_conv_gradweight, algo, dtype, precision, parameters)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论