提交 a346913f authored 作者: notoraptor's avatar notoraptor

Skip runtime algos when there is no algo that supports
the given data type configuration for the given ndim.
上级 60fa63e1
...@@ -726,14 +726,17 @@ class BaseTestDnnConv(object): ...@@ -726,14 +726,17 @@ class BaseTestDnnConv(object):
def test_fwd(self): def test_fwd(self):
for dtype, precision in self.dtype_configs: for dtype, precision in self.dtype_configs:
algos = (algo for algo in self.fwd_algorithms algos = [algo for algo in self.fwd_algorithms
if cudnn.fwd_algo_supports_dtype_config(algo, dtype, precision, self.ndim)) if cudnn.fwd_algo_supports_dtype_config(algo, dtype, precision, self.ndim)]
for algo in algos: for algo in algos:
for parameters in cudnn_conv_case_generator.fwd(algo, self.ndim, dtype, precision).get_cases(): for parameters in cudnn_conv_case_generator.fwd(algo, self.ndim, dtype, precision).get_cases():
yield (self.run_conv_fwd, algo, dtype, precision, parameters) yield (self.run_conv_fwd, algo, dtype, precision, parameters)
for algo in SUPPORTED_DNN_CONV_ALGO_RUNTIME: if algos:
for parameters in self.get_cases(): # Some algorithms support current data type configuration for current ndim.
yield (self.run_conv_fwd, algo, dtype, precision, parameters) # So, an algorithm can be chosen at runtime.
for algo in SUPPORTED_DNN_CONV_ALGO_RUNTIME:
for parameters in self.get_cases():
yield (self.run_conv_fwd, algo, dtype, precision, parameters)
for dnn_case in self.special_cases: for dnn_case in self.special_cases:
if dnn_case.is_fwd(): if dnn_case.is_fwd():
if dnn_case.should_fail: if dnn_case.should_fail:
...@@ -743,14 +746,17 @@ class BaseTestDnnConv(object): ...@@ -743,14 +746,17 @@ class BaseTestDnnConv(object):
def test_gradinput(self): def test_gradinput(self):
for dtype, precision in self.dtype_configs: for dtype, precision in self.dtype_configs:
algos = (algo for algo in self.bwd_data_algorithms algos = [algo for algo in self.bwd_data_algorithms
if cudnn.bwd_data_algo_supports_dtype_config(algo, dtype, precision, self.ndim)) if cudnn.bwd_data_algo_supports_dtype_config(algo, dtype, precision, self.ndim)]
for algo in algos: for algo in algos:
for parameters in cudnn_conv_case_generator.gi(algo, self.ndim, dtype, precision).get_cases(): for parameters in cudnn_conv_case_generator.gi(algo, self.ndim, dtype, precision).get_cases():
yield (self.run_conv_gradinput, algo, dtype, precision, parameters) yield (self.run_conv_gradinput, algo, dtype, precision, parameters)
for algo in SUPPORTED_DNN_CONV_ALGO_RUNTIME: if algos:
for parameters in self.get_cases(): # Some algorithms support current data type configuration for current ndim.
yield (self.run_conv_gradinput, algo, dtype, precision, parameters) # So, an algorithm can be chosen at runtime.
for algo in SUPPORTED_DNN_CONV_ALGO_RUNTIME:
for parameters in self.get_cases():
yield (self.run_conv_gradinput, algo, dtype, precision, parameters)
for dnn_case in self.special_cases: for dnn_case in self.special_cases:
if dnn_case.is_bwd_data(): if dnn_case.is_bwd_data():
if dnn_case.should_fail: if dnn_case.should_fail:
...@@ -760,14 +766,17 @@ class BaseTestDnnConv(object): ...@@ -760,14 +766,17 @@ class BaseTestDnnConv(object):
def test_gradweight(self): def test_gradweight(self):
for dtype, precision in self.dtype_configs: for dtype, precision in self.dtype_configs:
algos = (algo for algo in self.bwd_filter_algorithms algos = [algo for algo in self.bwd_filter_algorithms
if cudnn.bwd_filter_algo_supports_dtype_config(algo, dtype, precision, self.ndim)) if cudnn.bwd_filter_algo_supports_dtype_config(algo, dtype, precision, self.ndim)]
for algo in algos: for algo in algos:
for parameters in cudnn_conv_case_generator.gw(algo, self.ndim, dtype, precision).get_cases(): for parameters in cudnn_conv_case_generator.gw(algo, self.ndim, dtype, precision).get_cases():
yield (self.run_conv_gradweight, algo, dtype, precision, parameters) yield (self.run_conv_gradweight, algo, dtype, precision, parameters)
for algo in SUPPORTED_DNN_CONV_ALGO_RUNTIME: if algos:
for parameters in self.get_cases(): # Some algorithms support current data type configuration for current ndim.
yield (self.run_conv_gradweight, algo, dtype, precision, parameters) # So, an algorithm can be chosen at runtime.
for algo in SUPPORTED_DNN_CONV_ALGO_RUNTIME:
for parameters in self.get_cases():
yield (self.run_conv_gradweight, algo, dtype, precision, parameters)
for dnn_case in self.special_cases: for dnn_case in self.special_cases:
if dnn_case.is_bwd_filter(): if dnn_case.is_bwd_filter():
if dnn_case.should_fail: if dnn_case.should_fail:
......
Markdown 格式
0%
您将 0 人添加到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论