提交 113772d1 authored 作者: Frederic's avatar Frederic

Allow tests outside this repo to reuse our tester

上级 96d55935
...@@ -188,7 +188,11 @@ def safe_make_node(op, *inputs): ...@@ -188,7 +188,11 @@ def safe_make_node(op, *inputs):
def makeTester(name, op, expected, checks=None, good=None, bad_build=None, def makeTester(name, op, expected, checks=None, good=None, bad_build=None,
bad_runtime=None, grad=None, mode=None, grad_rtol=None, bad_runtime=None, grad=None, mode=None, grad_rtol=None,
eps=1e-10, skip=False, test_memmap=True): eps=1e-10, skip=False, test_memmap=True, check_name=True):
"""
:param check_name:
Use only for tester that aren't in Theano.
"""
if checks is None: if checks is None:
checks = {} checks = {}
if good is None: if good is None:
...@@ -206,12 +210,14 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None, ...@@ -206,12 +210,14 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None,
_bad_build, _bad_runtime, _grad = bad_build, bad_runtime, grad _bad_build, _bad_runtime, _grad = bad_build, bad_runtime, grad
_mode, _grad_rtol, _eps, skip_ = mode, grad_rtol, eps, skip _mode, _grad_rtol, _eps, skip_ = mode, grad_rtol, eps, skip
_test_memmap = test_memmap _test_memmap = test_memmap
_check_name = check_name
class Checker(unittest.TestCase): class Checker(unittest.TestCase):
op = staticmethod(_op) op = staticmethod(_op)
expected = staticmethod(_expected) expected = staticmethod(_expected)
checks = _checks checks = _checks
check_name = _check_name
good = _good good = _good
bad_build = _bad_build bad_build = _bad_build
bad_runtime = _bad_runtime bad_runtime = _bad_runtime
...@@ -223,6 +229,7 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None, ...@@ -223,6 +229,7 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None,
def setUp(self): def setUp(self):
# Verify that the test's name is correctly set. # Verify that the test's name is correctly set.
# Some tests reuse it outside this module. # Some tests reuse it outside this module.
if self.check_name:
eval(self.__class__.__module__ + '.' + self.__class__.__name__) eval(self.__class__.__module__ + '.' + self.__class__.__name__)
# We keep a list of temporary files created in add_memmap_values, # We keep a list of temporary files created in add_memmap_values,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论