提交 51369fb4 authored 作者: sentient07's avatar sentient07

added fd parameter to tcase

上级 c89d9d6f
...@@ -21,12 +21,13 @@ class TestDnnConv2d(test_abstract_conv.BaseTestConv2d): ...@@ -21,12 +21,13 @@ class TestDnnConv2d(test_abstract_conv.BaseTestConv2d):
# provide_shape is not used by the cuDNN impementation # provide_shape is not used by the cuDNN impementation
self.provide_shape = [False] self.provide_shape = [False]
def tcase(self, i, f, s, b, flip, provide_shape): def tcase(self, i, f, s, b, flip, provide_shape, fd=(1, 1)):
if not dnn_available(test_ctx_name): if not dnn_available(test_ctx_name):
raise SkipTest(dnn_available.msg) raise SkipTest(dnn_available.msg)
mode = mode_with_gpu mode = mode_with_gpu
if fd != (1, 1):
o = self.get_output_shape(i, f, s, b, (1, 1)) raise SkipTest("Doesn't have CUDNN implementation")
o = self.get_output_shape(i, f, s, b, fd)
self.run_fwd(inputs_shape=i, filters_shape=f, subsample=s, self.run_fwd(inputs_shape=i, filters_shape=f, subsample=s,
verify_grad=True, mode=mode, verify_grad=True, mode=mode,
provide_shape=provide_shape, border_mode=b, provide_shape=provide_shape, border_mode=b,
......
...@@ -296,7 +296,7 @@ class BaseTestConv2d(unittest.TestCase): ...@@ -296,7 +296,7 @@ class BaseTestConv2d(unittest.TestCase):
for b in self.border_modes: for b in self.border_modes:
try: try:
self.tcase(i, f, s, db, dflip, self.tcase(i, f, s, db, dflip,
dprovide_shape) dprovide_shape, fd)
except SkipTest as e: except SkipTest as e:
skipped = e skipped = e
for flip in self.filter_flip: for flip in self.filter_flip:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论